diff --git a/.appveyor.yml b/.appveyor.yml index efc98f5559..f4f56fa159 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -1,3 +1,5 @@ +skip_branch_with_pr: true + environment: matrix: - LIB_TYPE: shared @@ -39,10 +41,11 @@ build_script: - bash -lc "cd /c/projects/blis && ./configure %CONFIGURE_OPTS% --enable-threading=%THREADING% --enable-arg-max-hack --prefix=/c/blis %CONFIG%" - bash -lc "cd /c/projects/blis && mingw32-make -j4 V=1" - bash -lc "cd /c/projects/blis && mingw32-make install" -- ps: Compress-Archive -Path C:\blis -DestinationPath C:\blis.zip +- 7z a C:\blis.zip C:\blis - ps: Push-AppveyorArtifact C:\blis.zip test_script: +# "make checkblas" does not work with shared linking Windows due to inability to override xerbla_ - if [%LIB_TYPE%]==[shared] set "TEST_TARGET=checkblis-fast" - if [%LIB_TYPE%]==[static] set "TEST_TARGET=check" - bash -lc "cd /c/projects/blis && mingw32-make %TEST_TARGET% -j4 V=1" diff --git a/.dir-locals.el b/.dir-locals.el new file mode 100644 index 0000000000..fccb205020 --- /dev/null +++ b/.dir-locals.el @@ -0,0 +1,9 @@ +;; First (minimal) attempt at configuring Emacs CC mode for the BLIS +;; layout requirements. +((c-mode . ((c-file-style . "stroustrup") + (c-basic-offset . 4) + (comment-start . "// ") + (comment-end . "") + (indent-tabs-mode . t) + (tab-width . 4) + (parens-require-spaces . nil)))) diff --git a/.gitignore b/.gitignore index de56af2a17..a24fe2b0ea 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ config.mk bli_config.h +bli_addon.h # -- monolithic headers -- @@ -43,7 +44,12 @@ include/*/*.h # -- misc. -- # BLIS testsuite output file -output.testsuite +output.testsuite.* # BLAS test output files out.* + +# GTAGS database +GPATH +GRTAGS +GTAGS diff --git a/.travis.yml b/.travis.yml index dbe3c41d81..6603ca2f35 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,83 +1,83 @@ language: c sudo: required -dist: trusty -env: - global: - secure: "Ty3PM1xGhXwxfJG6YyY9bUZyXzw98ekHxQEqU9VnrMXTZb28IxfocPCXHjL34r9HTGosO5Pmierhal1Cs3ZKE5ZAJqJhCfck+kwlH21Uay5CNYglDtSmy2qxtbbDG4AxpEZ1UKlIZr1pNh/x+pRemSmnMEnQp/E7QJqdkhm4+aMX2bWKyLPtrdL+B9QXLVT2nT6/Fw3i05aBhpcFJpSPfvYX2KoCZYdJOSKcKci4T8nAfP/c0olkz+jAkBZxZFgO9Ptrt/lvHtVPrkh5o29GvHg2i/4vucbsMltoxlV31/2eYpdr17Ngtt41MMVn2fHV4lVhLmENc04nlm084fBtg73T6b8hNy5JlcA44xI/UrPJsQAJ+0A0ds9BbBQKPxOmaF/O8WGXhwiwdKT6DGS9lj05f3S+yZfeNE3pQhLEcvwXLO5SW3VvKXMj0t/lZyG+XCkvFjD7KEPQV4g+BZc2zzD9TwDx3ydn8Uzd6zZlq1erQUzCnODP24wuwfrNP8nqxFYG0VtI8oZW62IC9U2hcnAF5QNXXW3yDYD65k3BHbigfI28gu9iO9G8RxOglR27J7Whdqkqw3AMRaqyHt2tdbz7tM2dLZ0EatT5m8esjC+LP4EshW9C59jP2U9vJ/94YEgOfwiqk8+e6fL/7dJvOumbwu1RclRI9DS88PPYb3Q=" +dist: focal +branches: + only: + - master + - dev + - amd matrix: include: - # full testsuite (all tests except for mixed datatype) + # full testsuite (all tests + mixed datatype (gemm_nn only) + salt + SDE + OOT) - os: linux compiler: gcc - env: OOT=0 TEST=1 SDE=0 THR="none" CONF="auto" - # mixed-datatype testsuite (gemm_nn only) - - os: linux - compiler: gcc - env: OOT=0 TEST=MD SDE=0 THR="none" CONF="auto" - # salt testsuite (fast set of operations+parameters) - - os: linux - compiler: gcc - env: OOT=0 TEST=SALT SDE=0 THR="none" CONF="auto" - # test x86_64 ukrs with SDE - - os: linux - compiler: gcc - env: OOT=0 TEST=0 SDE=1 THR="none" CONF="x86_64" + env: OOT=1 TEST=ALL SDE=1 THR="none" CONF="x86_64" \ + PACKAGES="gcc-9 binutils" # openmp build - os: linux compiler: gcc - env: OOT=0 TEST=0 SDE=0 THR="openmp" CONF="auto" + env: OOT=0 TEST=FAST SDE=0 THR="openmp" CONF="auto" \ + PACKAGES="gcc-9 binutils" # pthreads build - os: linux compiler: gcc - env: OOT=0 TEST=0 SDE=0 THR="pthreads" CONF="auto" - # out-of-tree build - - os: linux - compiler: gcc - env: OOT=1 TEST=0 SDE=0 THR="none" CONF="auto" + env: OOT=0 TEST=FAST SDE=0 THR="pthreads" CONF="auto" \ + PACKAGES="gcc-9 binutils" # clang build - os: linux compiler: clang - env: OOT=0 TEST=0 SDE=0 THR="none" CONF="auto" + env: OOT=0 TEST=FAST SDE=0 THR="none" CONF="auto" + # There seems to be some difficulty installing 2 Clang toolchains of different versions. + # Use the TravisCI default. + # PACKAGES="clang-8 binutils" # macOS with system compiler (clang) - os: osx compiler: clang - env: OOT=0 TEST=1 SDE=0 THR="none" CONF="auto" + env: OOT=0 TEST=FAST SDE=0 THR="none" CONF="auto" # cortexa15 build and fast testsuite (qemu) - os: linux compiler: arm-linux-gnueabihf-gcc env: OOT=0 TEST=FAST SDE=0 THR="none" CONF="cortexa15" \ - PACKAGES="gcc-arm-linux-gnueabihf qemu-system-arm qemu-user" \ + CC=arm-linux-gnueabihf-gcc CXX=arm-linux-gnueabihf-g++ \ + PACKAGES="gcc-arm-linux-gnueabihf g++-arm-linux-gnueabihf libc6-dev-armhf-cross qemu-system-arm qemu-user" \ TESTSUITE_WRAPPER="qemu-arm -cpu cortex-a15 -L /usr/arm-linux-gnueabihf/" # cortexa57 build and fast testsuite (qemu) - os: linux compiler: aarch64-linux-gnu-gcc env: OOT=0 TEST=FAST SDE=0 THR="none" CONF="cortexa57" \ - PACKAGES="gcc-aarch64-linux-gnu qemu-system-arm qemu-user" \ + CC=aarch64-linux-gnu-gcc CXX=aarch64-linux-gnu-g++ \ + PACKAGES="gcc-aarch64-linux-gnu g++-aarch64-linux-gnu libc6-dev-arm64-cross qemu-system-arm qemu-user" \ + TESTSUITE_WRAPPER="qemu-aarch64 -L /usr/aarch64-linux-gnu/" + # Apple M1 (firestorm) build and fast testsuite (qemu) + - os: linux + compiler: aarch64-linux-gnu-gcc + env: OOT=0 TEST=FAST SDE=0 THR="none" CONF="firestorm" \ + CC=aarch64-linux-gnu-gcc CXX=aarch64-linux-gnu-g++ \ + PACKAGES="gcc-aarch64-linux-gnu g++-aarch64-linux-gnu libc6-dev-arm64-cross qemu-system-arm qemu-user" \ TESTSUITE_WRAPPER="qemu-aarch64 -L /usr/aarch64-linux-gnu/" + # armsve build and fast testsuite (qemu) + - os: linux + compiler: aarch64-linux-gnu-gcc-10 + env: OOT=0 TEST=FAST SDE=0 THR="none" CONF="armsve" \ + CC=aarch64-linux-gnu-gcc-10 CXX=aarch64-linux-gnu-g++-10 \ + PACKAGES="gcc-10-aarch64-linux-gnu g++-10-aarch64-linux-gnu libc6-dev-arm64-cross qemu-system-arm qemu-user" \ + TESTSUITE_WRAPPER="qemu-aarch64 -cpu max,sve=true,sve512=true -L /usr/aarch64-linux-gnu/" install: -- if [ "$TRAVIS_OS_NAME" = "linux" ]; then sudo rm -f /usr/bin/as; fi -- if [ "$TRAVIS_OS_NAME" = "linux" ]; then sudo ln -s /usr/lib/binutils-2.26/bin/as /usr/bin/as; fi -- if [ "$TRAVIS_OS_NAME" = "linux" ]; then sudo rm -f /usr/bin/ld; fi -- if [ "$TRAVIS_OS_NAME" = "linux" ]; then sudo ln -s /usr/lib/binutils-2.26/bin/ld /usr/bin/ld; fi -- if [ "$CC" = "gcc" ] && [ "$TRAVIS_OS_NAME" = "linux" ]; then export CC="gcc-6"; fi -- if [ -n "$PACKAGES" ]; then sudo apt-get install -y $PACKAGES; fi -addons: - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - gcc-6 - - binutils-2.26 - - clang +- if [ "$CC" = "gcc" ] && [ "$TRAVIS_OS_NAME" = "linux" ]; then export CC="gcc-9"; fi +- if [ -n "$PACKAGES" ] && [ "$TRAVIS_OS_NAME" = "linux" ]; then sudo apt-get install -y $PACKAGES; fi script: - export DIST_PATH=. - pwd - if [ $OOT -eq 1 ]; then export DIST_PATH=`pwd`; mkdir ../oot; cd ../oot; chmod -R a-w $DIST_PATH; fi - pwd -- $DIST_PATH/configure -t $THR CC=$CC $CONF +- $DIST_PATH/configure -p `pwd`/../install -t $THR CC=$CC $CONF - pwd - ls -l - $CC --version - make -j 2 +- make install +- $DIST_PATH/travis/cxx/cxx-test.sh $DIST_PATH $(ls -1 include) +# Qemu SVE is failing sgemmt in some cases. Skip as this issue is not observed on real chip (A64fx). +- if [ "$CONF" = "armsve" ]; then sed -i 's/.*\.*/0/' $DIST_PATH/testsuite/input.operations.fast; fi - if [ "$TEST" != "0" ]; then travis_wait 30 $DIST_PATH/travis/do_testsuite.sh; fi -- if [ $SDE -eq 1 ] && [ "$TRAVIS_PULL_REQUEST" = "false" ] ; then travis_wait 30 $DIST_PATH/travis/do_sde.sh; fi +- if [ "$SDE" = "1" ]; then travis_wait 30 $DIST_PATH/travis/do_sde.sh; fi diff --git a/CHANGELOG b/CHANGELOG index 784c9f5fd5..13eaa52caa 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,10 +1,5516 @@ -commit e0408c3ca3d53bc8e6fedac46ea42c86e06c922d (HEAD -> master, tag: 0.5.1) +commit 8535b3e11d2297854991c4272932ce4974dda629 (HEAD -> master, tag: 0.8.1) +Author: Field G. Van Zee +Date: Mon Mar 22 17:42:33 2021 -0500 + + Version file update (0.8.1) + +commit e56d9f2d94ed247696dda2cbf94d2ca05c7fc089 (origin/master, origin/HEAD) +Author: Field G. Van Zee +Date: Mon Mar 22 17:40:50 2021 -0500 + + ReleaseNotes.md update in advance of next version. + +commit ca83f955d45814b7d84f53933cdb73323c0dea2c +Author: Field G. Van Zee +Date: Mon Mar 22 17:21:21 2021 -0500 + + CREDITS file update. + +commit 57ef61f6cdb86957f67212aa59407f2f8e7f3d1a +Merge: bf1b578e e7a4a8ed +Author: Field G. Van Zee +Date: Fri Mar 19 13:05:43 2021 -0500 + + Merge branch 'master' of github.com:flame/blis + +commit bf1b578ea32ea1c9dbf7cb3586969e8ae89aa5ef +Author: Field G. Van Zee +Date: Fri Mar 19 13:03:17 2021 -0500 + + Reduced KC on skx from 384 to 256. + + Details: + - Reduced the KC cache blocksize for double real on the skx subconfig + from 384 to 256. The maximum (extended) KC was also reduced + accordingly from 480 to 320. Thanks to Tze Meng Low for suggesting + this change. + +commit e7a4a8edc940942357e8e4c4594383a29a962f93 +Author: Nicholai Tukanov +Date: Wed Mar 17 19:43:31 2021 -0500 + + Fix calculation of new pb size (#487) + + Details: + - Added missing parentheses to the i8 and i4 instantiations of the + GENERIC_GEMM macro in sandbox/power10/generic_gemm.c. + +commit 4493cf516e01aba82642a43abe350943ba458fe2 +Author: Field G. Van Zee +Date: Mon Mar 15 13:12:49 2021 -0500 + + Redefined BLIS_NUM_ARCHS to update automatically. + + Details: + - Changed BLIS_NUM_ARCHS from a cpp macro definition to the last enum + value in the arch_t enum. This means that it no longer needs to get + updated manually whenever new subconfigurations are added to BLIS. + Also removed the explicit initial index assigment of 0 from the + first enum value, which was unnecessary due to how the C language + standard mandates indexing of enum values. Thanks to Devin Matthews + for originally submitting this as a PR in #446. + - Updated docs/ConfigurationHowTo.md to reflect the aforementioned + change. + +commit a4b73de84cdffcbe5cf71969a0f7f0f8202b3510 +Author: Field G. Van Zee +Date: Fri Mar 12 17:12:27 2021 -0600 + + Disabled _self() and _equal() in bli_pthread API. + + Details: + - Disabled the _self() and _equal() extensions to the bli_pthread API + introduced in d479654. These functions were disabled after I realized + that they aren't actually needed yet. Thanks to Devin Matthews for + helping me reason through the appropriate consumer code that will + appear in BLIS (eventually) in a future commit. (Also, I could never + get the Windows branch to link properly in clang builds in AppVeyor. + See the comment I left in the code, and #485, for more info.) + +commit f9d604679d8715bc3e79a8630268446889b51388 +Author: Field G. Van Zee +Date: Thu Mar 11 16:57:55 2021 -0600 + + Added _self() and _equal() to bli_pthread API. + + Details: + - Expanded the bli_pthread API to include equivalents to pthread_self() + and pthread_equal(). Implemented these two functions for all three cpp + branches present within bli_pthread.c: systemless, Windows, and + Linux/BSD. + +commit fa9b3c8f6b3d5717f19832362104413e1a86dfb0 +Author: Field G. Van Zee +Date: Thu Mar 11 15:13:51 2021 -0600 + + Shuffled code in Windows branch of bli_pthreads.c. + + Details: + - Reordered the definitions in the cpp branch in bli_pthreads.c that + defines the bli_pthreads API in terms of Windows API calls. Also added + missing comments that mark sections of the API, which brings the code + into harmony with other cpp branches (as well as bli_pthread.h). + +commit 95d4f3934d806b3563f6648d57a4e381d747caf5 +Author: Field G. Van Zee +Date: Thu Mar 11 13:50:40 2021 -0600 + + Moved cpp macro redef of strerror_r to bli_env.c. + + Details: + - Relocated the _MSC_VER-guarded cpp macro re-definition of strerror_r + (in terms of strerror_s) from bli_thread.h to bli_env.c. It was + likely left behind in bli_thread.h in a previous commit, when code + that now resides in bli_env.c was moved from bli_thread.c. (I couldn't + find any other instance of strerror_r being used in BLIS, so I moved + the #define directly to bli_env.c rather than place it in bli_env.h.) + The code that uses strerror_r is currently disabled, though, so this + commit should have no affect on BLIS. + +commit 8a3066c315358d45d4f5b710c54594455f9e8fc6 +Author: Field G. Van Zee +Date: Tue Mar 9 17:52:59 2021 -0600 + + Relocated gemmsup_ref general stride handling. + + Details: + - Moved the logic that checks for general stridedness in any of the + matrix operands in a gemmsup problem. The logic previously resided + near the top of bli_gemmsup_int(), which is the thread entry point + for the parallel region of the current gemmsup implementation. The + problem with this setup was that the code would attempt to reject + problems with any general-strided operands by returning BLIS_FAILURE, + and that return value was then being ignored by the l3_sup thread + decorator, which unconditionally returns BLIS_SUCCESS. To solve this + issue, rather than try to manage n return values, one from each of n + threads, I simply moved the logic into bli_gemmsup_ref(). I didn't + move it any higher (e.g. bli_gemmsup()) because I still want the + logic to be part of the current gemmsup handler implementation. That + is, perhaps someone else will create a different handler, and that + author wants to handle general stride differently. (We don't want to + force them into a particular way of handling general stride.) + - Removed the general stride handling from bli_gemmtsup_int(), even + though this function is inoperative for now. + - This commit addresses issue #484. Thanks to RuQing Xu for reporting + this issue. + +commit 670bc7b60f6065893e8ec1bebd2fc9e5ba710dff +Author: Nicholai Tukanov +Date: Fri Mar 5 13:53:43 2021 -0600 + + Add low-precision POWER10 gemm kernels (#467) + + Details: + - This commit adds a new BLIS sandbox that (1) provides implementations + based on low-precision gemm kernels, and (2) extends the BLIS typed + API for those new implementations. Currently, these new kernels can + only be used for the POWER10 microarchitecture; however, they may + provide a template for developing similar kernels for other + microarchitectures (even those beyond POWER), as changes would likely + be limited to select places in the microkernel and possibly the + packing routines. The new low-precision operations that are now + supported include: shgemm, sbgemm, i16gemm, i8gemm, i4gemm. For more + information, refer to the POWER10.md document that is included in + 'sandbox/power10'. + +commit b8dcc5bc75a746807d6f8fa22dc2123c98396bf5 (origin/dev, origin/amd, dev, amd) +Author: RuQing Xu +Date: Tue Mar 2 06:58:24 2021 +0800 + + Fixed typed API definition for gemmt (#476) + + Details: + - Fixed incorrect definition and prototype of bli_?gemmt() in + frame/3/bli_l3_tapi.c and .h, respectively. gemmt was previously + defined identically to gemm, which was wrong because it did not + take into account the uplo property of C. + - Fixed incorrect API documentation for her2k/syr2k in BLISTypedAPI.md. + Specifically, the document erroneously listed only a single transab + parameter instead of transa and transb. + +commit a0e4fe2340a93521e1b1a835a96d0f26dec8406a +Author: Ilknur +Date: Tue Mar 2 02:06:56 2021 +0400 + + Fixed double free() in level1v example (#482) + + Details: + - In exampls/tapi/00level1v.c, pointer 'z' was being freed twice and + pointer 'a' was not being freed at all. This commit correctly frees + each pointer exactly once. + +commit f5871c7e06a75799251d6b55a8a5fbfa1a92cf95 +Author: Field G. Van Zee +Date: Sun Feb 28 17:03:57 2021 -0600 + + Added complex asm packm kernels for 'haswell' set. + + Details: + - Implemented assembly-based packm kernels for single- and double- + precision complex domain (c and z) and housed them in the 'haswell' + kernel set. This means c3xk, c8xk, z3xk, and z4xk are now all + optimized. + - Registered the aforementioned packm kernels in the haswell, zen, + and zen2 subconfigs. + - Minor modifications to the corresponding s and d packm kernels that + were introduced in 426ad67. + - Thanks to AMD, who originally contributed the double-precision real + packm kernels (d6xk and d8xk), upon which these complex kernels are + partially based. + +commit 426ad679f55264e381eb57a372632b774320fb85 +Author: Field G. Van Zee +Date: Sat Feb 27 18:39:56 2021 -0600 + + Added assembly packm kernels for 'haswell' set. + + Details: + - Implemented assembly-based packm kernels for single- and double- + precision real domain (s and d) and housed them in the 'haswell' + kernel set. This means s6xk, s16xk, d6xk, and d8xk are now all + optimized. + - Registered the aforementioned packm kernels in the haswell, zen, + and zen2 subconfigs. + - Thanks to AMD, who originally contributed the double-precision real + packm kernels (d6xk and d8xk), which I have now tweaked and used to + create comparable single-precision real kernels (s6xk and s16xk). + +commit f50c1b7e5886d29efe134e1994d05af9949cd4b6 +Merge: 8f39aea1 b3953b93 +Author: Devin Matthews +Date: Mon Feb 1 11:55:51 2021 -0600 + + Merge pull request #473 from ajaypanyala/pkgconfig + + build: generate pkgconfig file + +commit 8f39aea11f80a805b66cff4b4dc5e72727ea461d +Merge: f8db9fb3 2a815d5b +Author: Field G. Van Zee +Date: Sat Jan 30 17:59:56 2021 -0600 + + Merge branch 'dev' + +commit f8db9fb33b48844d6b47fdef699625bd9197745a +Author: Field G. Van Zee +Date: Thu Jan 28 08:04:52 2021 -0600 + + Fixed missing parentheses in README.md Citations. + +commit b3953b938eee59f79b4a4162ba583a5cb59fa34e +Author: Ajay Panyala +Date: Tue Jan 12 17:07:04 2021 -0800 + + drop CFLAGS in the generated pkgconfig file + +commit b02d9376bac31c1a1c7916f44c4946277a1425e2 +Author: Ajay Panyala +Date: Mon Jan 11 20:50:01 2021 -0800 + + add datadir + +commit d8d8deeb6d8b84adb7ae5fdb88c6dd4f06624a76 +Author: Ajay Panyala +Date: Mon Jan 11 17:47:50 2021 -0800 + + generate pkgconfig file + +commit 8c65411c7c8737248a6f054ffa0ce008c95cb515 +Merge: 328b4f88 874c3f04 +Author: Devin Matthews +Date: Mon Jan 11 16:01:45 2021 -0600 + + Merge pull request #471 from flame/fix-470 + + Fix kernel-to-config mapping for intel64 + +commit 874c3f04ece9af4d8fdf0e2713e21a259c117656 +Author: Devin Matthews +Date: Fri Jan 8 13:56:30 2021 -0600 + + Update configure + + Choose last sub-config in the kernel-to-config map if the config list doesn't contain the name of the kernel set. E.g. for "zen: skx knl haswell" pick "haswell" instead of "skx" which was chosen previously. Fixes #470. + +commit 2a815d5b365d934cb351b2f2a8cd1366e997b2e1 +Author: Field G. Van Zee +Date: Mon Jan 4 18:03:39 2021 -0600 + + Support trsm pre-inversion in 1m, bb, ref kernels. + + Details: + - Expanded support for disabling trsm diagonal pre-inversion to other + microkernel types, including the reference microkernel as well as the + kernel implementations for 1m and the pre-broadcast B (bb) format used + by the power9 subconfig. This builds on the 'haswell' and 'penryn' + kernel support added in 7038bba. Thanks to Bhaskar Nallani for + reminding me, in #461 (post-closure), that 1m support was missing from + that commit. + - Removed cpp branch of ref_kernels/3/bli_trsm_ref.c that contained the + omp simd implementation after making a stripped-down copy in 'old'. + This code has been disabled for some time and it seemed better suited + to rot away out of sight rather than clutter up a file that is already + cluttered by the presence of lower and upper versions. + - Minor comment update to bli_ind_init(). + +commit c3ed2cbb9f60100fc9beb2a9d75476de9f711dc5 +Author: Field G. Van Zee +Date: Mon Jan 4 16:16:32 2021 -0600 + + Enable 1m only if real domain ukr is not reference. + + Details: + - Previously, BLIS would automatically enable use of the 1m method + for a given precision if the complex domain microkernel was a + reference kernel. This commit adds an additional constraint so that + 1m is only enabled if the corresponding real domain microkernel is + NOT reference. That is, BLIS now forgos use of 1m if both the real and + complex domain kernels are reference implementations. Note that this + does not prevent 1m from being enabled manually under those + conditions; it only means that 1m will not be enabled automatically + at initialization-time. + +commit ed50c947385ba3b0b5d550015f38f7f0a31755c0 +Merge: 0cef09aa 328b4f88 +Author: Field G. Van Zee +Date: Mon Jan 4 14:31:44 2021 -0600 + + Merge branch 'master' into dev + +commit 328b4f8872b4bca9a53d2de8c6e285f3eb13d196 +Author: Devin Matthews +Date: Wed Dec 30 17:54:18 2020 -0600 + + Shared object (dylib) was not built correctly for partial build. + + The SO build rule used $? instead of $^. Observed on macOS, not sure if it affected Linux or not. + +commit ae6ef66ef824da9bc6348bf9d1b588cd4f2ded9b +Author: Devin Matthews +Date: Wed Dec 30 17:34:55 2020 -0600 + + bli_diag_offset_with_trans had wrong return type. Fixes #468. + +commit ebcf197fb86fdd0a864ea928140752bc2462e8c6 +Merge: 472f138c 21aa67e1 +Author: Devin Matthews +Date: Sat Dec 5 22:26:27 2020 -0600 + + Merge pull request #466 from isuruf/patch-3 + + fix cc_vendor for crosstool-ng toolchains + +commit 21aa67e11cebbc5a6dd7c6353154256294df3c33 +Author: Isuru Fernando +Date: Sat Dec 5 21:59:13 2020 -0600 + + fix cc_vendor for crosstool-ng toolchains + +commit 472f138cb927b7259126ebb9c68919cfcc7a4ea3 +Author: Field G. Van Zee +Date: Sat Dec 5 14:13:52 2020 -0600 + + Fixed typo in README.md to CodingConventions.md. + +commit 0cef09aa92208441a656bf097f197ea8e22b533b +Author: Field G. Van Zee +Date: Fri Dec 4 16:40:59 2020 -0600 + + Consolidated code in level-3 _front() functions. + + Details: + - Reduced a code segment that appears in all of the bli_*_front() + functions except for bli_gemm_front(). Previously, the code looked + like this (taken from bli_herk_front()): + + if ( bli_cntx_method( cntx ) == BLIS_NAT ) + { + bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); + bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &ah_local ); + } + else // if ( bli_cntx_method( cntx ) != BLIS_NAT ) + { + pack_t schema_a = bli_cntx_schema_a_block( cntx ); + pack_t schema_b = bli_cntx_schema_b_panel( cntx ); + + bli_obj_set_pack_schema( schema_a, &a_local ); + bli_obj_set_pack_schema( schema_b, &ah_local ); + } + + This code segment is part of a sort-of-hack that allows us to + communicate the pack schemas into the level-3 thread decorator, which + needs them so that they can be passed into bli_l3_cntl_create_if(), + where the control tree is created. However, the first conditional case + above is unnecessary because the second case is fully generalized. + That is, even in the native case, the context contains correct, + queryable schemas. Thus, these code segments were reduced to something + like: + + pack_t schema_a = bli_cntx_schema_a_block( cntx ); + pack_t schema_b = bli_cntx_schema_b_panel( cntx ); + + bli_obj_set_pack_schema( schema_a, &a_local ); + bli_obj_set_pack_schema( schema_b, &ah_local ); + + There's always a small chance that the seemingly unnecessary code + in the first branch case has some special use that is not apparent to + me, but the testsuite's default input parameters seem to think this + commit will be fine. + +commit 7038bbaa05484141195822291cf3ba88cbce4980 +Author: Field G. Van Zee +Date: Fri Dec 4 16:08:15 2020 -0600 + + Optionally disable trsm diagonal pre-inversion. + + Details: + - Implemented a configure-time option, --disable-trsm-preinversion, that + optionally disables the pre-inversion of diagonal elements of the + triangular matrix in the trsm operation and instead uses division + instructions within the gemmtrsm microkernels. Pre-inversion is + enabled by default. When it is disabled, performance may suffer + slightly, but numerical robustness should improve for certain + pathological cases involving denormal (subnormal) numbers that would + otherwise result in overflow in the pre-inverted value. Thanks to + Bhaskar Nallani for reporting this issue via #461. + - Added preprocessor macro guards to bli_trsm_cntl.c as well as the + gemmtrsm microkernels for 'haswell' and 'penryn' kernel sets pursuant + to the aforementioned feature. + - Added macros to frame/include/bli_x86_asm_macros.h related to division + instructions. + +commit 78aee79452cce2691c40f05b3632bdfc122300af +Author: Field G. Van Zee +Date: Wed Dec 2 13:02:36 2020 -0600 + + Allow amaxv testsuite module to run with dim = 0. + + Details: + - Exit early from libblis_test_amaxv_check() when the vector dimension + (length) of x is 0. This allows the module to run when the testsuite + driver passes in a problem size of 0. Thanks to Meghana Vankadari for + alerting us to this issue via #459. + - Note: All other testsuite modules appear to work with problem sizes + of 0, except for the microkernel modules. I chose not to "fix" those + modules because a failure (or segmentation fault, as happens in this + case) is actually meaningful in that it alerts the developer that some + microkernels cannot be used with k = 0. Specifically, the 'haswell' + kernel set contains microkernels that preload elements of B. Those + microkernels would need to be restructured to avoid preloading in + order to support usage when k = 0. + +commit 92d2b12a44ee0990c22735472aeaf1c17deb2d9b +Author: Field G. Van Zee +Date: Wed Dec 2 13:02:00 2020 -0600 + + Fixed obscure testsuite gemmt dependency bug. + + Details: + - Fixed a bug in the gemmt testsuite module that only manifested when + testing of gemmt is enabled but testing of gemv is disabled. The bug + was due to a copy-paste error dating back to the introduction of gemmt + in 88ad841. + +commit b43dae9a5d2f078c9bbe07079031d6c00a68b7de +Author: Field G. Van Zee +Date: Tue Dec 1 16:44:38 2020 -0600 + + Fixed copy-paste bugs in edge-case sup kernels. + + Details: + - Fixed bugs in two sup kernels, bli_dgemmsup_rv_haswell_asm_1x6() and + bli_dgemmsup_rd_haswell_asm_1x4(), which involved extraneous assembly + instructions that were left over from when the kernels were first + written. These instructions would cause segmentation faults in some + situations where extra memory was not allocated beyond the end of + the matrix buffers. Thanks to Kiran Varaganti for reporting these + bugs and to Bhaskar Nallani for identifying the cause and solution. + +commit 11dfc176a3c422729f453f6c23204cf023e9954d +Author: Field G. Van Zee +Date: Tue Dec 1 19:51:27 2020 +0000 + + Reorganized thread auto-factorization logic. + + Details: + - Reorganized logic of bli_thread_partition_2x2() so that the primary + guts were factored out into "fast" and "slow" variants. Then added + logic to the "fast" variant that allows for more optimal thread + factorizations in some situations where there is at least one factor + of 2. + - Changed BLIS_THREAD_RATIO_M from 2 to 1 in bli_kernel_macro_defs.h and + added comments to that file describing BLIS_THREAD_RATIO_? and + BLIS_THREAD_MAX_?R. + - In bli_family_zen.h and bli_family_zen2.h, preprocessed out several + macros not used in vanilla BLIS and removed the unused macro + BLIS_ENABLE_ZEN_BLOCK_SIZES from the former file. + - Disabled AMD's small matrix handling entry points in bli_syrk_front.c + and bli_trsm_front.c. (These branches of small matrix handling have + not been reviewed by vanilla BLIS developers.) + - Added commented-out calls printf() to bli_rntm.c. + - Whitespace changes to bli_thread.c. + +commit 6d3bafacd7aa7ad198762b39490876c172bfbbcb +Author: Devin Matthews +Date: Sat Nov 28 17:17:56 2020 -0600 + + Update BuildSystem.md + + Add git version >= 1.8.5 requirement (see #462). + +commit 64856ea5a61b01d585750815788b6a775f729647 +Author: Field G. Van Zee +Date: Mon Nov 23 16:54:51 2020 -0600 + + Auto-reduce (by default) prime numbers of threads. + + Details: + - When requesting multithreaded parallelism by specifying the total + number of threads (whether it be via environment variable, globally at + runtime, or locally at runtime), reduce the number of threads actually + used by one if the original value (a) is prime and (b) exceeds a + minimum threshold defined by the macro BLIS_NT_MAX_PRIME, which is set + to 11 by default. If, when specifying the total number of threads (and + not the individual ways of parallelism for each loop), prime numbers + of threads are desired, this feature may be overridden by defining the + BLIS_ENABLE_AUTO_PRIME_NUM_THREADS macro in the bli_family_*.h that + corresponds to the configuration family targeted at configure-time. + (For now, there is no configure option(s) to control this feature.) + Thanks to Jeff Diamond for suggesting this change. + - Defined a new function in bli_thread.c, bli_is_prime(), that returns a + bool that determines whether an integer is prime. This function is + implemented in terms of existing functions in bli_thread.c. + - Updated docs/Multithreading.md to document the above feature, along + with unrelated minor edits. + +commit 55933b6ff6b9b8a12041715f42bba06273d84b74 +Author: Field G. Van Zee +Date: Fri Nov 20 10:39:32 2020 -0600 + + Added missing attribution to docs/ReleaseNotes.md. + +commit e310f57b4b29fbfee479e0f9fe2040851efdec4f +Author: Field G. Van Zee +Date: Thu Nov 19 13:33:37 2020 -0600 + + CHANGELOG update (0.8.0) + +commit 9b387f6d5a010969727ec583c0cdd067a5274ed8 (tag: 0.8.0) +Author: Field G. Van Zee +Date: Thu Nov 19 13:33:37 2020 -0600 + + Version file update (0.8.0) + +commit 2928ec750d3a3e1e5d55de5b57ddc04e9d0bd796 +Author: Field G. Van Zee +Date: Wed Nov 18 18:31:35 2020 -0600 + + ReleaseNotes.md update in advance of next version. + + Details: + - Updated docs/ReleaseNotes.md in preparation for next version. + +commit b9899bedff6854639468daa7a973bb14ca131a74 +Author: Field G. Van Zee +Date: Wed Nov 18 16:52:41 2020 -0600 + + CREDITS file update. + +commit 9bb23e6c2a44b77292a72093938ab1ee6e6cc26a +Author: Field G. Van Zee +Date: Mon Nov 16 15:55:45 2020 -0600 + + Added support for systemless build (no pthreads). + + Details: + - Added a configure option, --[enable|disable]-system, which determines + whether the modest operating system dependencies in BLIS are included. + The most notable example of this on Linux and BSD/OSX is the use of + POSIX threads to ensure thread safety for when application-level + threads call BLIS. When --disable-system is given, the bli_pthreads + implementation is dummied out entirely, allowing the calling code + within BLIS to remain unchanged. Why would anyone want to build BLIS + like this? The motivating example was submitted via #454 in which a + user wanted to build BLIS for a simulator such as gem5 where thread + safety may not be a concern (and where the operating system is largely + absent anyway). Thanks to Stepan Nassyr for suggesting this feature. + - Another, more minor side effect of the --disable-system option is that + the implementation of bli_clock() unconditionally returns 0.0 instead + of the time elapsed since some fixed point in the past. The reasoning + for this is that if the operating system is truly minimal, the system + function call upon which bli_clock() would normally be implemented + (e.g. clock_gettime()) may not be available. + - Refactored preprocess-guarded code in bli_pthread.c and bli_pthread.h + to remove redundancies. + - Removed old comments and commented #include of "bli_pthread_wrap.h" + from bli_system.h. + - Documented bli_clock() and bli_clock_min_diff() in BLISObjectAPI.md + and BLISTypedAPI.md, with a note that both are non-functional when + BLIS is configured with --disable-system. + +commit 88ad84143414644df4c56733b1cf91a36bfacaf8 +Author: Field G. Van Zee +Date: Sat Nov 14 09:39:48 2020 -0600 + + Squash-merge 'pr' into 'squash'. (#457) + + Merged contributions from AMD's AOCL BLIS (#448). + + Details: + - Added support for level-3 operation gemmt, which performs a gemm on + only the lower or upper triangle of a square matrix C. For now, only + the conventional/large code path will be supported (in vanilla BLIS). + This was accomplished by leveraging the existing variant logic for + herk. However, some of the infrastructure to support a gemmtsup is + included in this commit, including + - A bli_gemmtsup() front-end, similar to bli_gemmsup(). + - A bli_gemmtsup_ref() reference handler function. + - A bli_gemmtsup_int() variant chooser function (with variant calls + commented out). + - Added support for inducing complex domain gemmt via the 1m method. + - Added gemmt APIs to the BLAS and CBLAS compatiblity layers. + - Added gemmt test module to testsuite. + - Added standalone gemmt test driver to 'test' directory. + - Documented gemmt APIs in BLISObjectAPI.md and BLISTypedAPI.md. + - Added a C++ template header (blis.hh) containing a BLAS-inspired + wrapper to a set of polymorphic CBLAS-like function wrappers defined + in another header (cblas.hh). These two headers are installed if + running the 'install' target with INSTALL_HH is set to 'yes'. (Also + added a set of unit tests that exercise blis.hh, although they are + disabled for now because they aren't compatible with out-of-tree + builds.) These files now live in the 'vendor' top-level directory. + - Various updates to 'zen' and 'zen2' subconfigurations, particularly + within the context initialization functions. + - Added s and d copyv, setv, and swapv kernels to kernels/zen/1, and + various minor updates to dotv and scalv kernels. Also added various + sup kernels contributed by AMD to kernels/zen/3. However, these + kernels are (for now) not yet used, in part because they caused + AppVeyor clang failures, and also because I have not found time to + review and vet them. + - Output the python found during configure into the definition of PYTHON + in build/config.mk (via build/config.mk.in). + - Added early-return checks (A, B, or C with zero dimension; alpha = 0) + to bli_gemm_front.c. + - Implemented explicit beta = 0 handling in for the sgemm ukernel in + bli_gemm_armv7a_int_d4x4.c, which was previously missing. This latent + bug surfaced because the gemmt module verifies its computation using + gemm with its beta parameter set to zero, which, on a cortexa15 system + caused the gemm kernel code to unconditionally multiply the + uninitialized C data by beta. The C matrix likely contained + non-numeric values such as NaN, which then would have resulted in a + false failure. + - Fixed a bug whereby the implementation for bli_herk_determine_kc(), + in bli_l3_blocksize.c, was inadvertantly being defined in terms of + helper functions meant for trmm. This bug was probably harmless since + the trmm code should have also done the right thing for herk. + - Used cpp macros to neutralize the various AOCL_DTL_TRACE_ macros in + kernels/zen/3/bli_gemm_small.c since those macros are not used in + vanilla BLIS. + - Added cpp guard to definition of bli_mem_clear() in bli_mem.h to + accommodate C++'s stricter type checking. + - Added cpp guard to test/*.c drivers that facilitate compilation on + Windows systems. + - Various whitespace changes. + +commit 234b8b0cf48f1ee965bd7999b291fc7add3b9a54 +Author: Field G. Van Zee +Date: Thu Nov 12 19:11:16 2020 -0600 + + Increased dotxaxpyf testsuite thresholds. + + Details: + - Increased the test thresholds used by the dotxaxpyf testsuite module + by a factor of five in order to avoid residuals that unnecessarily + fall in the MARGINAL range. This commit should fix #455. Thanks to + @nagsingh for reporting this issue. + +commit ed612dd82c50063cfd23576a6b2465213d31b14b +Author: Field G. Van Zee +Date: Sat Nov 7 13:09:42 2020 -0600 + + Updated README.md with sgemmsup blurb. + + Details: + - Added an entry to the "What's New" section of the README.md to + announce the availability of sgemmsup. + +commit e14424f55b15d67e8d18384aea45a11b9b772e02 +Merge: 0cfe1aac eccdd75a +Author: Field G. Van Zee +Date: Sat Nov 7 13:02:50 2020 -0600 + + Merge branch 'dev' + +commit 0cfe1aac222008a78dff3ee03ef5183413936706 +Author: Field G. Van Zee +Date: Fri Oct 30 17:10:36 2020 -0500 + + Relocated operation index to ToC in API docs. + + Details: + - Moved the "Operation index" section of both the BLISObjectAPI.md and + BLISTypedAPI.md docs to appear immediately after the table of contents + of each document. This allows the reader to quickly jump to the + documentation for any operation without having to scroll through much + of the document (when rendered via a web browser). + - Fixed a mistake in the BLISObjectAPI.md for the setd operation, which + does *not* observe the diag property of its matrix argument. Thanks to + Jeff Diamond for reporting this. + +commit 2a0682f8e5998be536da313525292f0da6193147 +Author: Field G. Van Zee +Date: Sun Oct 18 18:04:03 2020 -0500 + + Implemented runtime subconfig selection (#451). + + Details: + - Implemented support for the user manually overriding the automatic + subconfiguration selection that happens at runtime. This override + can be requested by setting the BLIS_ARCH_TYPE environment variable. + The variable must be set to the arch_t id (as enumerated in + bli_type_defs.h) corresponding to the desired subconfiguration. If a + value outside this enumerated range is given, BLIS will abort with an + error message. If the value is in the valid range but corresponds to a + subconfiguration that was not activated at configure-time/compile-time, + BLIS will abort with a (different) error message. Thanks to decandia50 + for suggesting this feature via issue #451. + - Defined a new function bli_gks_lookup_id to return the address of an + internal data structure within the gks. If this address is NULL, then + it indicates that the subconfig corresponding to the arch_t id passed + into the function was not compiled into BLIS. This function is used + in the second of the two abort scenarios described above. + - Defined the enumerated error code BLIS_UNINITIALIZED_GKS_CNTX, which + is returned for the latter of the two abort scenarios mentioned above, + along with a corresponding error message and a function to perform + the error check. + - Added cpp macro branching to bli_env.c to support compilation of the + auto-detect.x executable during configure-time. This cpp branch is + similar to the cpp code already found in bli_arch.c and bli_cpuid.c. + - Cleaned up the auto_detect() function to facilitate easier maintenance + going forward. Also added a convenient debug switch that outputs the + compilation command for the auto-detect.x executable and exits. + +commit eccdd75a2d8a0c46e91e94036179c49aa5fa601c +Author: Field G. Van Zee +Date: Fri Oct 9 15:44:16 2020 -0500 + + Whitespace tweak in docs/PerformanceSmall.md. + +commit 7677e9ba60ac27496e3421c2acc7c239e3f860e9 +Merge: addcd46b a0849d39 +Author: Field G. Van Zee +Date: Fri Oct 9 15:41:25 2020 -0500 + + Merge branch 'dev' of github.com:flame/blis into dev + +commit addcd46b0559d401aa7d33d4c7e6f63f5313a8e0 +Author: Field G. Van Zee +Date: Fri Oct 9 15:41:09 2020 -0500 + + Added Epyc 7742 Zen2 ("Rome") sup perf results. + + Details: + - Added single-threaded and multithreaded sup performance results to + docs/PerformanceSmall.md for both sgemm and dgemm. These results were + gathered on an Epyc 7742 "Rome" server featuring AMD's Zen2 + microarchitecture. Special thanks to Jeff Diamond for facilitating + access to the system via the Oracle Cloud. + - Updates to octave scripts in test/sup/octave for use with Octave 5.2 + and for use with subplot_tight(). + - Minor updates to octave scripts in test/3/octave. + - Renamed files containing the previous Zen performance results for + consistency with the new results. + - Decreased line thickness slightly in large/conventional Zen2 graphs. + I'm done tweaking those this time. Really. + - Added missing line regarding eigen header installation for each + microarchitecture section. + +commit a0849d390d04067b82af937cda8191b049b98915 +Author: Field G. Van Zee +Date: Fri Oct 9 20:22:17 2020 +0000 + + Register l3 sup kernels in zen2 subconfig. + + Details: + - Registered full suite of sgemm and dgemm sup millikernels, blocksizes, + and crossover thresholds in bli_cntx_init_zen2.c. + - Minor updates to test/sup/runme.sh for running on Zen2 Epyc 7742 + system. + +commit d98368c32d5fbfaab8966ee331d9bcb5c4fe7a59 +Author: Field G. Van Zee +Date: Thu Oct 8 19:05:51 2020 -0500 + + Another tweak to line thickness of Zen2 graphs. + +commit 1855dfbdaafa37892b36c97fd317fd5d8da76676 +Author: Field G. Van Zee +Date: Thu Oct 8 19:01:00 2020 -0500 + + Tweaked line thickness in Zen2 graphs once more. + + Details: + - Decreased (relative to previous commit) line thickness in recent Zen2 + graphs. + +commit 0991611e7ed82889c53a5c3f1ef1d49552c50d61 +Author: Field G. Van Zee +Date: Thu Oct 8 18:54:49 2020 -0500 + + Increased line thickness in recent Zen2 graphs. + + Details: + - Increased the width of the lines in the graphs introduced in 74ec6b8. + +commit 8273cbacd7799e9af59e5320d66055f2f5d9cb31 +Author: Field G. Van Zee +Date: Wed Oct 7 14:51:33 2020 -0500 + + README.md, docs/FAQ.md updates. + + Details: + - Added a frequently asked question to docs/FAQ.md regarding the + difference between upstream (vanilla) BLIS and AMD BLIS. + - Updated the name of ICES in the README.md to reflect the Oden + rebranding. + +commit a178a822ad3d5021489a0e61f909d8550ae12a8f +Author: Field G. Van Zee +Date: Wed Sep 30 16:00:52 2020 -0500 + + Added Zen2 links to docs/Performance.md Contents. + +commit 74ec6b8f457cabe37d2382aaab35ba04fc737948 +Author: Field G. Van Zee +Date: Wed Sep 30 15:54:18 2020 -0500 + + Added Epyc 7742 Zen2 ("Rome") performance results. + + Details: + - Added single-threaded and multithreaded performance results to + docs/Performance.md. These results were gathered on an Epyc 7742 + "Rome" server with AMD's Zen2 microarchitecture. Special thanks + to Jeff Diamond for facilitating access to the system via the + Oracle Cloud. + - Renamed files containing the previous Zen performance results for + consistency with the new results. + +commit bc4a213a2c3dcf8bbfcbb3a1ef3e9fc9e3226c34 +Author: Field G. Van Zee +Date: Wed Sep 30 15:28:20 2020 -0500 + + Updated matlab (now octave) plot code in test/3. + + Details: + - Renamed test/3/matlab to test/3/octave. + - Within test/3, updated and tuned plot_l3_perf.m and plot_panel_4x5.m + files for use with octave (which is free and doesn't crash on me + mid-way through my use of subplot). + - Updated runthese.m scratchpad for zen2 invocations. + - Added Nikolay S.'s subplot_tight() function, along with its license. + +commit c77ddc418187e1884fa6bcfe570eee295b9cb8bc +Author: Field G. Van Zee +Date: Wed Sep 30 20:15:43 2020 +0000 + + Added optional numactl usage to test/3/runme.sh. + +commit 2d8ec164e7ae4f0c461c27309dc1f5d1966eb003 +Author: Nicholai Tukanov +Date: Tue Sep 29 16:52:18 2020 -0500 + + Add POWER10 support to BLIS (#450) + +commit 4fd8d9fec2052257bf2a5c6e0d48ae619ff6c3e4 +Author: Field G. Van Zee +Date: Mon Sep 28 23:39:05 2020 +0000 + + Tweaked zen2 subconfig's MC cache blocksizes. + + Details: + - Updated the MC cache blocksizes registered by the 'zen2' subconfig. + - Minor updates to test/3/Makefile and test/3/runme.sh. + +commit 5efcdeffd58af621476d179afc0c19c0f912baa8 +Author: Field G. Van Zee +Date: Fri Sep 25 14:25:24 2020 -0500 + + More minor README.md updates. + +commit 9e940f8aad6f065ea1689e791b9a4e1fb7900c40 +Author: Field G. Van Zee +Date: Fri Sep 25 13:53:35 2020 -0500 + + Added 1m SISC bibtex to README.md. + + Details: + - Added final citation info to 1m bibtex in README.md file. + - Updated draft 1m paper link. + - Changed some http to https. + +commit e293cae2d1b9067261f613f25eaa0e871356b317 +Author: Field G. Van Zee +Date: Tue Sep 15 16:09:11 2020 -0500 + + Implemented sgemmsup assembly kernels. + + Details: + - Created a set of single-precision real millikernels and microkernels + comparable to the dgemmsup kernels that already exist within BLIS. + - Added prototypes for all kernels within bli_kernels_haswell.h. + - Registered entry-point millikernels in bli_cntx_init_haswell.c and + bli_cntx_init_zen.c. + - Added sgemmsup support to the Makefile, runme.sh script, and source + file in test/sup. This included edits that allow for separate "small" + dimensions for single- and double-precision as well as for single- + vs. multithreaded execution. + +commit 2765c6f37c11cb7f71cd4b81c64cea6130636c68 +Author: Field G. Van Zee +Date: Sat Sep 12 17:48:15 2020 -0500 + + Type saga continues; fixed sgemm ukernel signature. + + Details: + - Changed double* pointers in sgemm function signature to float*. At + this point I've lost track of whether this was my fault or another + dormant bug like the one described in ece9f6a, but at this point I + no longer care. It's one of those days (aka I didn't ask for this). + +commit 0779559509e0a1af077530d09ed151dac54f32ee +Author: Field G. Van Zee +Date: Sat Sep 12 17:37:21 2020 -0500 + + Fixed missing restrict in knl sgemm prototype. + + Details: + - Added a missing 'restrict' qualifier in the sgemm ukernel prototype + for knl. (Not sure how that code was ever compiling before now.) + +commit ece9f6a3ef1b26b53ecf968cd069df7a85b139fb +Author: Field G. Van Zee +Date: Sat Sep 12 17:22:42 2020 -0500 + + Fixed dormant type bugs in bli_kernels_knl.h. + + Details: + - Fixed dormant type mismatches in the use of the prototype-generating + macros in bli_kernels_knl.h. Specifically, some float prototypes + were incorrectly using double as their ctype. This didn't actually + matter until the type changes in 645d771, as previously those types + were not used since packm was prototyped with void* pointers. + +commit 8ebb3b60e1c4c045ddb48e02de6e246cecde24a4 +Author: Field G. Van Zee +Date: Sat Sep 12 17:00:47 2020 -0500 + + Fixed accidental breakage in 645d771. + + Details: + - In trying to clean up kappa_cast variables in the reference packm + kernels, which I initally believed to be redundant given the other + void* -> ctype* changes in 645d771, I accidentally ended up violating + restrict semantics for 1e/1r packing and possibly other packm kernels. + (Normally, my pre-commit testsuite run would have caught this, but I + was unknowingly using an edited input.operations file in which I'd + disabled most tests as part of unrelated work.) This commit reverts + the kappa_cast changes in 645d771. + +commit 645d771a14ae89aa7131d6f8f4f4a8090329d05e +Author: Field G. Van Zee +Date: Sat Sep 12 15:31:56 2020 -0500 + + Minor packm kernel type cleanup (void* -> ctype*). + + Details: + - Changed all void* function arguments in reference packm kernels to + those of the native type (ctype*). These pointers no longer need to + be void* and are better represented by their native types anyway. + (See below for details.) Updated knl packm kernels accordingly. + - In the definition of the PACKM_KER_PROT prototype macro template in + frame/1m/bli_l1m_ker_prot.h, changed the pointer types for kappa, a, + and p from void* to ctype*. They were originally void* because these + function signatures had to share the same type so they could all be + stored in a single array of that shared type, from which they were + queried and called by packm_cxk(). This is no longer how the function + pointers are stored, and so it no longer makes sense to force the + caller of packm kernels to use void*, only so that the implementor + of the packm kernels can typecast back to the native datatype within + the kernel definition. This change has no effect internally within + BLIS because currently all packm kernels are called after querying + the function addresses from the context and then typecasting to the + appropriate function pointer type, which is based upon type-specific + function pointers like float* and double*. + - Removed a comment in frame/1m/bli_l1m_ft_ker.h that was outdated and + misleading due to changes to the handling of packm kernels since + moving them into the context. + +commit 54bf6c35542a297e25bc8efec6067a6df80536f4 +Author: Field G. Van Zee +Date: Thu Sep 10 15:42:01 2020 -0500 + + Minor README.md update. + + Details: + - Added a new entry to the "What people are saying about BLIS" section. + +commit e50b4d40462714ae33df284655a2faf7fa35f37c +Author: Field G. Van Zee +Date: Wed Sep 9 14:12:53 2020 -0500 + + Minor update to README.md (SIAM Best Paper Prize). + +commit a8efb72074691e2610372108becd88b4b392299e +Merge: b0c4da17 97e87f2c +Author: Devin Matthews +Date: Mon Sep 7 16:18:19 2020 -0500 + + Merge pull request #434 from flame/intel-zdot + + Add an option to change the complex return type. + +commit 97e87f2c9f3878a05e1b7c6ec237ee88d9a72a42 +Author: Field G. Van Zee +Date: Mon Sep 7 15:56:42 2020 -0500 + + Whitespace/comment updates to #434 PR. + +commit b0c4da1732b6c6a9ff66f70c36e4722e0f9645ae +Merge: 810e90ee b1b5870d +Author: Devin Matthews +Date: Mon Sep 7 15:47:54 2020 -0500 + + Merge pull request #436 from flame/s390x + + Add checks so that s390x is detected as 64-bit. + +commit 810e90ee806510c57504f0cf8eeaf608d38bd9dd +Author: Field G. Van Zee +Date: Tue Sep 1 16:11:40 2020 -0500 + + Minor README.md update. + + Details: + - Added HPE to list of funders. + - Changed http to https in funders' website links. + +commit 7d411282196e036991c26e52cb5e5f85769c8059 +Author: Devin Matthews +Date: Thu Aug 13 17:50:58 2020 -0500 + + Use -O2 for all framework code. (#435) + + It seems that -O3 might be causing intermittent problems with the f2c'ed packed and banded code. -O3 is retained for kernel code. Fixes #341 and fixes #342. + +commit 9c5b485d356367b0a1288761cd623f52036e7344 +Author: Dave Love +Date: Fri Aug 7 20:11:18 2020 +0000 + + Don't override -mcpu with -march on ARM (#353) + + * Use -mcpu for ARM + See the GCC doc about -march, -mtune, and -mpu and maybe + https://community.arm.com/developer/tools-software/tools/b/tools-software-ides-blog/posts/compiler-flags-across-architectures-march-mtune-and-mcpu + + * Fix typo in flags + + * Fix typo in cortexa9 flags + + * Modify cortexa53 compilation flags to fix failing BLAS check (#341) + +commit c253d14a72a746b670b3ffbb6e81bcafc73d1133 +Author: Devin Matthews +Date: Fri Aug 7 09:39:04 2020 -0500 + + Also handle Intel-style complex return in CBLAS interface. + +commit 5d653a11a0cc71305d0995507b1733995856f475 +Author: Devin Matthews +Date: Thu Aug 6 17:58:26 2020 -0500 + + Update Multithreading.md + + Addresses the issue raised in #426. + +commit b1b5870dd3f9b1c78cf5f58a53514d73f001fc4c +Author: Devin Matthews +Date: Thu Aug 6 17:34:20 2020 -0500 + + Add checks so that s390x is detected as 64-bit. + +commit 882dcb11bfc9ea50aa2f9044621833efd90d42be +Author: Field G. Van Zee +Date: Thu Aug 6 17:28:14 2020 -0500 + + Mention example code at top of documentation docs. + + Details: + - Steer the reader towards the example code section of each + documentation doc (object and typed). + - Trivial update to examples/oapi/README, examples/tapi/README. + +commit f4894512e5bf56ff83701c07dd02972e300741a5 +Author: Field G. Van Zee +Date: Thu Aug 6 17:20:00 2020 -0500 + + Very minor updates to previous commit. + +commit adedb893ae8dfacd1dc54035979e15c44d589dbb +Author: Field G. Van Zee +Date: Thu Aug 6 17:14:01 2020 -0500 + + Documented mutator functions in BLISObjectAPI.md. + + Details: + - Added documentation for commonly-used object mutator functions in + BLISObjectAPI.md. Previously, only accessor functions were documented. + Thanks to Jeff Diamond for pointing out this omission. + - Explicitly set the 'diag' property of objects in oapi example modules + (08level2.c and 09level3.c). + +commit 5b5278ff494888509543a79c09ea82089f6c95d9 +Author: Devin Matthews +Date: Thu Aug 6 14:19:37 2020 -0500 + + Use #ifdef instead of #if as macro may be undefined. + +commit 7fdc0fc893d0c6727b725ea842053b65be2c20ba +Author: Devin Matthews +Date: Thu Aug 6 14:03:55 2020 -0500 + + Add an option to change the complex return type. + + ifort apparently does not return complex numbers in registers as in C/C++ (or gfortran), but instead creates a "hidden" first parameter for the return value. The option --complex-return=gnu|intel has been added, as well as a guess based on a provided FC if not specified (otherwise default to gnu). This option affects the signatures of cdotc, cdotu, zdotc, and zdotu, and a single library cannot be used with both GNU and Intel Fortran compilers. Fixes #433. + +commit 6e522e5823b762d4be09b6acdca30faafba56758 +Author: Field G. Van Zee +Date: Thu Jul 30 19:31:37 2020 -0500 + + Mention disabling of sup in docs/Sandboxes.md. + + Details: + - Added language to remind the reader to disable sup if the intended + behavior is for the sandbox implementation to handle all problem + sizes, even the smaller ones that would normally be handled by the + sup code path. + +commit 00e14cb6d849e963a2e1ac35e7dbbe186af00a58 +Author: Field G. Van Zee +Date: Wed Jul 29 14:24:34 2020 -0500 + + Replaced use of bool_t type with C99 bool. + + Details: + - Textually replaced nearly all non-comment instances of bool_t with the + C99 bool type. A few remaining instances, such as those in the files + bli_herk_x_ker_var2.c, bli_trmm_xx_ker_var2.c, and + bli_trsm_xx_ker_var2.c, were promoted to dim_t since they were being + used not for boolean purposes but to index into an array. + - This commit constitutes the third phase of a transition toward using + C99's bool instead of bool_t, which was raised in issue #420. The first + phase, which cleaned up various typecasts in preparation for using + bool as the basis for bool_t (instead of gint_t), was implemented by + commit a69a4d7. The second phase, which redefined the bool_t typedef + in terms of bool (from gint_t), was implemented by commit 2c554c2. + +commit 2c554c2fce885f965a425e727a0314d3ba66c06d +Author: Field G. Van Zee +Date: Fri Jul 24 15:57:19 2020 -0500 + + Redefined bool_t typedef in terms of C99 bool. + + Details: + - Changed the typedef that defines bool_t from: + + typedef gint_t bool_t; + + where gint_t is a signed integer that forms the basis of most other + integers in BLIS, to: + + typedef bool bool_t; + + - Changed BLIS's TRUE and FALSE macro definitions from being in terms of + integer literals: + + #define TRUE 1 + #define FALSE 0 + + to being in terms of C99 boolean constants: + + #define TRUE true + #define FALSE false + + which are provided by stdbool.h. + - This commit constitutes the second phase of a transition toward using + C99's bool instead of bool_t, which will address issue #420. The first + phase, which cleaned up various typecasts in preparation for using + bool as the basis for bool_t (instead of gint_t), was implemented by + commit a69a4d7. + +commit e01dd125581cec87f61e15590922de0dc938ec42 +Author: Field G. Van Zee +Date: Fri Jul 24 15:41:46 2020 -0500 + + Fail-safe updates to Makefiles in 'test' dir. + + Details: + - Updated Makefiles in test, test/3, and test/sup so that running any of + the usual targets without having first built BLIS results in a helpful + error message. For example, if BLIS is not yet configured, make will + output: + + Makefile:327: *** Cannot proceed: config.mk not detected! Run + configure first. Stop. + + Similarly, if BLIS is configured but not yet built, make will output: + + Makefile:340: *** Cannot proceed: BLIS library not yet built! Run + make first. Stop. + + In previous commits, these actions would result in a rather cryptic + make error such as: + + make: *** No rule to make target 'test_sgemm_2400_asm_blis_st.x', + needed by 'blis-nat-st'. Stop. + +commit b4f47f7540062da3463e2cb91083c12fdda0d30a +Author: Devin Matthews +Date: Fri Jul 24 13:56:13 2020 -0500 + + Add BLIS_EXPORT_BLIS to bli_abort. (#429) + + Fixes #428. + +commit a69a4d7e2f4607c919db30b14535234ce169c789 +Author: Field G. Van Zee +Date: Wed Jul 22 16:13:09 2020 -0500 + + Cleaned up bool_t usage and various typecasts. + + Details: + - Fixed various typecasts in + + frame/base/bli_cntx.h + frame/base/bli_mbool.h + frame/base/bli_rntm.h + frame/include/bli_misc_macro_defs.h + frame/include/bli_obj_macro_defs.h + frame/include/bli_param_macro_defs.h + + that were missing or being done improperly/incompletely. For example, + many return values were being typecast as + (bool_t)x && y + rather than + (bool_t)(x && y) + Thankfully, none of these deficiencies had manifested as actual bugs + at the time of this commit. + - Changed the return type of bli_env_get_var() from dim_t to gint_t. + This reflects the fact that bli_env_get_var() needs to be able to + return a signed integer, and even though dim_t is currently defined + as a signed integer, it does not intuitively appear to necessarily be + signed by inspection (i.e., an integer named "dim_t" for matrix + "dimension"). Also, updated use of bli_env_get_var() within + bli_pack.c to reflect the changed return type. + - Redefined type of thrcomm_t.barrier_sense field from bool_t to gint_t + and added comments to the bli_thrcomm_*.h files that will explain a + planned replacement of bool_t with C99's bool type. + - Note: These changes are being made to facilitate the substitution of + 'bool' for 'bool_t', which will eliminate the namespace conflict with + arm_sve.h as reported in issue #420. This commit implements the first + phase of that transition. Thanks to RuQing Xu for reporting this + issue. + - CREDITS file update. + +commit a6437a5c11d364c6c88af527294d29734d7cc7d6 +Author: Field G. Van Zee +Date: Mon Jul 20 19:21:07 2020 -0500 + + Replaced broken ref99 sandbox w/ simpler version. + + Details: + - The 'ref99' sandbox was broken by multiple refactorings and internal + API changes over the last two years. Rather than try to fix it, I've + replaced it with a much simpler version based on var2 of gemmsup. + Why not fix the previous implementation? It occurred to me that the + old implementation was trying to be a lightly simplified duplication + of what exists in the framework. Duplication aside, this sandbox + would have worked fine if it had been completely independent of the + framework code. The problem was that it was only partially + independent, with many function calls calling a function in BLIS + rather than a duplicated/simplified version within the sandbox. (And + the reason I didn't make it fully independent to begin with was that + it seemed unnecessarily duplicative at the time.) Maintaining two + versions of the same implementation is problematic for obvious + reasons, especially when it wasn't even done properly to begin with. + This explains the reimplementation in this commit. The only catch is + that the newer implementation is single-threaded only and does not + perform any packing on either input matrix (A or B). Basically, it's + only meant to be a simple placeholder that shows how you could plug + in your own implementation. Thanks to Francisco Igual for reporting + this brokenness. + - Updated the three reference gemmsup kernels (defined in + ref_kernels/3/bli_gemmsup_ref.c) so that they properly handle + conjugation of conja and/or conjb. The general storage kernel, which + is currently identical to the column-storage kernel, is used in the + new ref99 sandbox to provide basic support for all datatypes + (including scomplex and dcomplex). + - Minor updates to docs/Sandboxes.md, including adding the threading + and packing limitations to the Caveats section. + - Fixed a comment typo in bli_l3_sup_var1n2m.c (upon which the new + sandbox implementation is based). + +commit bca040be9da542dd9c75d91890fa7731841d733d +Merge: 2605eb4d 171ecc1d +Author: Devin Matthews +Date: Mon Jul 20 09:27:30 2020 -0500 + + Merge pull request #425 from gmargari/patch-1 + + Update Multithreading.md + +commit 171ecc1dc6f055ea39da30e508f711b49a734359 +Author: Giorgos Margaritis +Date: Mon Jul 20 12:24:06 2020 +0300 + + Update Multithreading.md + +commit 2605eb4d99d3813c37a624c011aa2459324a6d89 +Author: Field G. Van Zee +Date: Wed Jul 15 15:25:19 2020 -0500 + + Added missing rv_d?x6 edge cases to sup kernel. + + Details: + - Added support to bli_gemmsup_rv_haswell_asm_d6x8n.c for handling + various n = 6 edge cases with a single sup kernel call. Previously, + only n = {4,2,1} were handled explicitly as single kernel calls; + that is, cases where n = 6 were previously being executed via two + kernel calls (n = 4 and n = 2). + - Added commented debug line to testsuite's test_libblis.c. + +commit 72f6ed0637dfcb021de04ac7d214d5c87e55d799 +Author: Field G. Van Zee +Date: Fri Jul 3 17:55:54 2020 -0500 + + Declare/define static functions via BLIS_INLINE. + + Details: + - Updated all static function definitions to use the cpp macro + BLIS_INLINE instead of the static keyword. This allows blis.h to + use a different keyword (inline) to define these functions when + compiling with C++, which might otherwise trigger "defined but + not used" warning messages. Thanks to Giorgos Margaritis for + reporting this issue and Devin Matthews for suggesting the fix. + - Updated the following files, which are used by configure's + hardware auto-detection facility, to unconditionally #define + BLIS_INLINE to the static keyword (since we know BLIS will be + compiled with C, not C++): + build/detect/config/config_detect.c + frame/base/bli_arch.c + frame/base/bli_cpuid.c + - CREDITS file update. + +commit 5fc701ac5f94c6300febbb2f24e731aa34f0f34a +Author: Field G. Van Zee +Date: Wed Jul 1 15:48:58 2020 -0500 + + Added -fomit-frame-pointer option to CKOPTFLAGS. + + Details: + - Added the -fomit-frame-pointer compiler option to the CKOPTFLAGS + variable in the following make_defs.mk files: + config/haswell/make_defs.mk + config/skx/make_defs.mk + as well as comments that mention why the compiler option is needed. + This option is needed to prevent the compiler from using the rbp + frame register (in the very early portion of kernel code, typically + where k_iter and k_left are defined and computed), which, as of + 1c719c9, is used explicitly by the gemmsup millikernels. Thanks to + Devin Matthews for identifying this missing option and to Jeff + Diamond for reporting the original bug in #417. + - The file + config/zen/amd_config.mk + which feeds into the make_defs.mk for both zen and zen2 subconfigs, + was also touched, but only to add a commented-out compiler option + (and the aforementioned explanatory comment) since that file already + uses -fomit-frame-pointer in COPTFLAGS, which forms the basis of + CKOPTFLAGS. + +commit 6af59b705782dada47e45df6634b479fe781d4fe +Author: Field G. Van Zee +Date: Wed Jul 1 14:54:23 2020 -0500 + + Fixed disabled edge case optimization in gemmsup. + + Details: + - Fixed an inadvertently disabled edge case optimization in the two + gemmsup variants in bli_l3_sup_var1n2m.c. Background: These edge case + optimizations allow the last millikernel operation in the jr loop to + be executed with inflated an register blocksize if it is the last + (or only) iteration. For example, if mr=6 and nr=8 and the gemmsup + problem is m=8, n=100, k=100. (In this case, the panel-block variant + (var1n) is executed, which places the jr loop in the m dimension.) + In principle, this problem could be executed as two millikernels: one + with dimensions 6x100x100, and one as 2x100x100. However, with the + support for inflated blocksizes in the kernel, the entire 8x100x100 + problem can be passed to the millikernel function, which will then + execute it more favorably as two 4x100x100 millikernel sub-calls. + Now, this optimization is disabled under certain circumstances, such + as when multithreading. Previously, the is_mt predicate was being set + incorrectly such that it was non-zero even when running + single-threaded. + - Upon fixing the is_mt issue above, another bit of code needed to be + moved so that the result of the optimization could have an impact on + the assignment of loop bounds ranges to threads. + +commit b37634540fab0f9b8d4751b8356ee2e17c9e3b00 +Author: Field G. Van Zee +Date: Thu Jun 25 16:05:12 2020 -0500 + + Support ldims, packing in sup/test drivers. + + Details: + - Updated the test/sup source file (test_gemm.c) and Makefile to support + building matrices with small or large leading dimensions, and updated + runme.sh to support executing both kinds of test drivers. + - Updated runme.sh to allow for executing sup drivers with unpacked (the + default) or packed matrices (via setting BLIS_PACK_A, BLIS_PACK_B + environment variables), and for capturing output to files that encode + both the leading dimension (small or large) and packing status into + the filenames. + - Consolidated octave scripts in test/sup/octave_st, test/sup/octave_mt + into test/sup/octave and updated the octave code in that consolidated + directory to read the new output filename format (encoding ldim and + packing). Also added comments and streamlined code, particularly in + plot_panel_trxsh.m. Tested the octave scripts with octave 5.2.0. + - Moved old octave_st, octave_mt directories to test/sup/old. + +commit ceb9b95a96cc3844ecb43d9af48ab289584e76b6 +Author: Field G. Van Zee +Date: Thu Jun 18 17:15:25 2020 -0500 + + Fixed incorrect link to shiftd in BLISTypedAPI.md. + + Details: + - Previously, the entry for shiftd in the Operation index section of + BLISTypedAPI.md was incorrectly linking to the shiftd operation entry + in BLISObjectAPI.md. This has been fixed. Thanks to Jeff Diamond for + helping find this incorrect link. + +commit b3c42016818797f79e55b32c8b7d090f9d0aa0ea +Author: Field G. Van Zee +Date: Thu Jun 18 14:00:56 2020 -0500 + + CREDITS file update. + +commit 31af73c11abae03248d959da0f81eacea015b57a +Author: Isuru Fernando +Date: Thu Jun 18 13:35:54 2020 -0500 + + Expand windows instructions (#414) + + * Expand windows instructions + + * Windows: both static and shared don't work at the same time + +commit b5b604e106076028279e6d94dc0e51b8ad48e802 +Author: Field G. Van Zee +Date: Wed Jun 17 16:42:24 2020 -0500 + + Ensure random objects' 1-norms are non-zero. + + Details: + - Fixed an innocuous bug that manifested when running the testsuite on + extremely small matrices with randomization via the "powers of 2 in + narrow precision range" option enabled. When the randomization + function emits a perfect 0.0 to fill a 1x1 matrix, the testsuite will + then compute 0.0/0.0 during the normalization process, which leads to + NaN residuals. The solution entails smarter implementaions of randv, + randnv, randm, and randnm, each of which will compute the 1-norm of + the vector or matrix in question. If the object has a 1-norm of 0.0, + the object is re-randomized until the 1-norm is not 0.0. Thanks to + Kiran Varaganti for reporting this issue (#413). + - Updated the implementation of randm_unb_var1() so that it loops over + a call to the randv_unb_var1() implementation directly rather than + calling it indirectly via randv(). This was done to avoid the overhead + of multiple calls to norm1v() when randomizing the rows/columns of a + matrix. + - Updated comments. + +commit 35e38fb693e7cbf2f3d7e0505a63b2c05d3f158d +Author: Isuru Fernando +Date: Tue Jun 16 10:59:41 2020 -0500 + + FIx typo in FAQ + +commit 1c719c91a3ef0be29a918097652beef35647d4b2 +Author: Field G. Van Zee +Date: Thu Jun 4 17:21:08 2020 -0500 + + Bugfixes, cleanup of sup dgemm ukernels. + + Details: + - Fixed a few not-really-bugs: + - Previously, the d6x8m kernels were still prefetching the next upanel + of A using MR*rs_a instead of ps_a (same for prefetching of next + upanel of B in d6x8n kernels using NR*cs_b instead of ps_b). Given + that the upanels might be packed, using ps_a or ps_b is the correct + way to compute the prefetch address. + - Fixed an obscure bug in the rd_d6x8m kernel that, by dumb luck, + executed as intended even though it was based on a faulty pointer + management. Basically, in the rd_d6x8m kernel, the pointer for B + (stored in rdx) was loaded only once, outside of the jj loop, and in + the second iteration its new position was calculated by incrementing + rdx by the *absolute* offset (four columns), which happened to be the + same as the relative offset (also four columns) that was needed. It + worked only because that loop only executed twice. A similar issue + was fixed in the rd_d6x8n kernels. + - Various cleanups and additions, including: + - Factored out the loading of rs_c into rdi in rd_d6x8[mn] kernels so + that it is loaded only once outside of the loops rather than + multiple times inside the loops. + - Changed outer loop in rd kernels so that the jump/comparison and + loop bounds more closely mimic what you'd see in higher-level source + code. That is, something like: + for( i = 0; i < 6; i+=3 ) + rather than something like: + for( i = 0; i <= 3; i+=3 ) + - Switched row-based IO to use byte offsets instead of byte column + strides (e.g. via rsi register), which were known to be 8 anyway + since otherwise that conditional branch wouldn't have executed. + - Cleaned up and homogenized prefetching a bit. + - Updated the comments that show the before and after of the + in-register transpositions. + - Added comments to column-based IO cases to indicate which columns + are being accessed/updated. + - Added rbp register to clobber lists. + - Removed some dead (commented out) code. + - Fixed some copy-paste typos in comments in the rv_6x8n kernels. + - Cleaned up whitespace (including leading ws -> tabs). + - Moved edge case (non-milli) kernels to their own directory, d6x8, + and split them into separate files based on the "NR" value of the + kernels (Mx8, Mx4, Mx2, etc.). + - Moved config-specific reference Mx1 kernels into their own file + (e.g. bli_gemmsup_r_haswell_ref_dMx1.c) inside the d6x8 directory. + - Added rd_dMx1 assembly kernels, which seems marginally faster than + the corresponding reference kernels. + - Updated comments in ref_kernels/bli_cntx_ref.c and changed to using + the row-oriented reference kernels for all storage combos. + +commit 943a21def0bedc1732c0a2453afe7c90d7f62e95 +Author: Isuru Fernando +Date: Thu May 21 14:09:21 2020 -0500 + + Add build instructions for Windows (#404) + +commit fbef422f0d968df10e598668b427af230cfe07e8 +Author: Field G. Van Zee +Date: Thu May 21 10:30:41 2020 -0500 + + Separate OS X and Windows into separate FAQs. + + Details: + - Separated the unified Mac OS X / Windows frequently asked question + into two separate questions, one for each OS. + +commit 28be1a4265ea67e3f177c391aba3dbbcf840bd52 +Author: Guodong Xu +Date: Thu May 21 02:22:22 2020 +0800 + + avoid loading twice in armv8a gemm kernel (#403) + + This bug happens at a corner case, when k_iter == 0 and we jump to + CONSIDERKLEFT. + + In current design, first row/col. of a and b are loaded twice. + + The fix is to rearrange a and b (first row/col.) loading instructions. + + Signed-off-by: Guodong Xu + +commit d51245e58b0beff2717156b980007c90337150d8 +Author: Field G. Van Zee +Date: Fri May 8 18:00:54 2020 -0500 + + Add support for Intel oneAPI in configure. + + Details: + - Properly select cc_vendor based on the output of invoking CC with the + --version option, including cases where CC is the variant of clang + that is included with Intel oneAPI. (However, we continue to treat + the compiler as clang for other purposes, not icc.) Thanks to Ajay + Panyala and Devin Matthews for reporting on this issue via #402. + +commit 787adad73bd5eb65c12c39d732723a1ac0448748 +Author: Field G. Van Zee +Date: Fri May 8 16:18:20 2020 -0500 + + Defined netlib equivalent of xerbla_array(). + + Details: + - Added a function definition for xerbla_array_(), which largely mirrors + its netlib implementation. Thanks to Isuru Fernando for suggesting the + addition of this function. + +commit c53b5153bee585685bf95ce22e058a7af72ecef0 +Author: Field G. Van Zee +Date: Tue May 5 12:39:12 2020 -0500 + + Documented Perl prerequisite for build system. + + Details: + - Added Perl to list of prerequisites for building BLIS. This is in part + (and perhaps completely?) due to some substitution commands used at + the end of configure that include '\n' characters that are not + properly interpreted by the version of sed included on some versions + of OS X. This new documentation addresses issue #398. + +commit f032d5d4a6ed34c8c3e5ba1ed0b14d1956d0097c +Author: Guodong Xu +Date: Thu Apr 30 01:08:46 2020 +0800 + + New kernel set for Arm SVE using assembly (#396) + + Here adds two kernels for Arm SVE vector extensions. + 1. a gemm kernel for double at sizes 8x8. + 2. a packm kernel for double at dimension 8xk. + + To achive best performance, variable length agonostic programming + is not used. Vector length (VL) of 256 bits is mandated in both kernels. + Kernels to support other VLs can be added later. + + "SVE is a vector extension for AArch64 execution mode for the A64 + instruction set of the Armv8 architecture. Unlike other SIMD architectures, + SVE does not define the size of the vector registers, but constrains into + a range of possible values, from a minimum of 128 bits up to a maximum of + 2048 in 128-bit wide units. Therefore, any CPU vendor can implement the + extension by choosing the vector register size that better suits the + workloads the CPU is targeting. Instructions are provided specifically + to query an implementation for its register size, to guarantee that + the applications can run on different implementations of the ISA without + the need to recompile the code." [1] + + [1] https://developer.arm.com/solutions/hpc/resources/hpc-white-papers/arm-scalable-vector-extensions-and-application-to-machine-learning + + Signed-off-by: Guodong Xu + +commit 4d87eb24e8e1f5a21e04586f6df4f427bae0091b +Author: Yingbo Ma +Date: Mon Apr 27 17:02:47 2020 -0400 + + Update KernelsHowTo.md (#395) + +commit 477ce91c5281df2bbfaddc4d86312fb8c8f879e2 +Author: Field G. Van Zee +Date: Wed Apr 22 14:26:49 2020 -0500 + + Moved #include "cpuid.h" to bli_cpuid.c. + + Details: + - Relocated the #include "cpuid.h" directive from bli_cpuid.h to + bli_cpuid.c. This was done because cpuid.h (which is pulled into + the post-build blis.h developer header) doesn't protect its + definitions with a preprocessor guard of the form: + + #ifndef FOOBAR_H + #define FOOBAR_H + // header contents. + #endif + + and as a result, applications (previously) could not #include both + blis.h and cpuid.h (since the former was already including the + latter). Thanks to Bhaskar Nallani for raising this issue via #393 + and to Devin Matthews for suggesting this fix. + - CREDITS file update. + +commit 8bde63ffd7474a97c3a3b0b0dc1eae45be0ab889 +Author: Field G. Van Zee +Date: Sat Apr 18 12:50:12 2020 -0500 + + Adding missing conjy to her2/syr2 in typed API doc. + + Details: + - Fixed a missing argument (conjy) in the function signatures of + bli_?her2() and bli_?syr2() in docs/BLISTypedAPI.md. Thanks to Robert + van de Geijn for reporting this omission. + +commit 976902406b610afdbacb2d80a7a2b4b43ff30321 +Author: Field G. Van Zee +Date: Fri Apr 17 15:11:10 2020 -0500 + + Disable packing by default in expert rntm_t init. + + Details: + - Changed the behavior of bli_rntm_init() as well as the static + initializer, BLIS_RNTM_INITIALIZER, so that user-initialized rntm_t + objects by default specify the disabling of packing for A and B. + Packing of A/B was already disabled by default when calling non-expert + APIs (and enabled only when the user set environment variables + BLIS_PACK_A or BLIS_PACK_B). With this commit, the default behavior of + using user-initialized rntm_t objects with expert APIs comes into line + with the default behavior of non-expert APIs--that is, they now both + lead to the avoidance of packing in the sup code path. (Note: The + conventional code path is unaffected by the environment variables + BLIS_PACK_A/BLIS_PACK_B and/or the disabling of packing in a rntm_t + object when calling an expert API.) This addresses issue #392. Thanks + to Kiran Varaganti for bringing this inconsistency to our attention. + - The above change was accomplished by changing the the definitions of + static functions bli_rntm_clear_pack_a() and bli_rntm_clear_pack_b() + in bli_rntm.h, which are both for internal use only. + +commit 5f2aee7c5fa5d562acaf8fbde3df0e2a04e1dd1b +Author: Field G. Van Zee +Date: Tue Apr 7 14:55:15 2020 -0500 + + README.md update to promote supmt dgemm. + + Details: + - Updated the sup entry in the "What's New" section of the README.md + file to promote the multithreaded dgemm sup feature introduced in + c0558fd. + +commit f5923cd9ff5fbd91190277dea8e52027174a1d57 +Author: Field G. Van Zee +Date: Tue Apr 7 14:41:45 2020 -0500 + + CHANGELOG update (0.7.0) + +commit 68b88aca6692c75a9f686187e6c4a4e196ae60a9 (tag: 0.7.0) +Author: Field G. Van Zee +Date: Tue Apr 7 14:41:44 2020 -0500 + + Version file update (0.7.0) + +commit b04de636c1702e4cb8e7ad82bab3cf43d2dbdfc6 +Author: Field G. Van Zee +Date: Tue Apr 7 14:37:43 2020 -0500 + + ReleaseNotes.md update in advance of next version. + + Details: + - Updated docs/ReleaseNotes.md in preparation for next version. + +commit 2cb604ba472049ad498df72d4a2dc47a161d4c3c +Author: Field G. Van Zee +Date: Mon Apr 6 16:42:14 2020 -0500 + + Rename more bli_thread_obarrier(), _obroadcast(). + + Details: + - Renamed instances of bli_thread_obarrier() and bli_thread_obroadcast() + that were made in the supmt-specific code commited to the 'amd' + branch, which has now been merged with 'master'. Prior to the merge, + 'master' received commit c01d249, which applied these renamings to + the existing, non-sup codebase. + +commit efb12bc895de451067649d5dceb059b7827a025f +Author: Field G. Van Zee +Date: Mon Apr 6 15:01:53 2020 -0500 + + Minor updates/elaborations to RELEASING file. + +commit 2e3b3782cfb7a2fd0d1a325844983639756def7d +Merge: 9f3a8d4d da0c086f +Author: Field G. Van Zee +Date: Mon Apr 6 14:55:35 2020 -0500 + + Merge branch 'master' into amd + +commit da0c086f4643772e111318f95a712831b0f981a8 +Author: Satish Balay +Date: Tue Mar 31 17:09:41 2020 -0500 + + OSX: specify the full path to the location of libblis.dylib (#390) + + * OSX: specify the full path to the location of libblis.dylib so that it can be found at runtime + + Before this change: + + Appication gives runtime error [when linked with blis] + dyld: Library not loaded: libblis.3.dylib + + balay@kpro lib % otool -L libblis.dylib + libblis.dylib: + libblis.3.dylib (compatibility version 0.0.0, current version 0.0.0) + /usr/lib/libSystem.B.dylib (compatibility version 1.0.0, current version 1281.0.0) + + After this change: + balay@kpro lib % otool -L libblis.dylib + libblis.dylib: + /Users/balay/petsc/arch-darwin-c-debug/lib/libblis.3.dylib (compatibility version 0.0.0, current version 0.0.0) + /usr/lib/libSystem.B.dylib (compatibility version 1.0.0, current version 1281.0.0) + + * INSTALL_LIBDIR -> libdir as INSTALL_LIBDIR has DESTDIR + + Co-Authored-By: Jed Brown + + * CREDITS file update. + + Co-authored-by: Jed Brown + Co-authored-by: Field G. Van Zee + +commit 2bca03ea9d87c0da829031a5332545d05e352211 +Author: Field G. Van Zee +Date: Sat Mar 28 22:10:00 2020 +0000 + + Updates, tweaks to runme.sh in test/1m4m. + + Details: + - Made several updates to test/1m4m/runme.sh, including: + - Added missing handling for 1m and 4m1a implementations when setting + the BLIS_??_NT environment variables. + - Added support for using numactl to run the test executables. + - Several other cleanups. + +commit c40a33190b94af5d5c201be63366594859b1233f +Author: Field G. Van Zee +Date: Thu Mar 26 16:55:00 2020 -0500 + + Warn user when auto-detection returns 'generic'. + + Details: + - Added logic to configure that causes the script to output a warning + to the user if/when "./configure auto" is run and the underlying + hardware feature detection code is unable to identify the hardware. + In these cases, the auto-detect code will return 'generic', which + is likely not what the user expected, and a flag will be set so that + a message is printed at the end of the configure output. (Thankfully, + we don't expect this scenario to play out very often.) Thanks to + Devin Matthews for suggesting this fix #384. + +commit 492a736fab5b9c882996ca024b64646877f22a89 +Author: Devin Matthews +Date: Tue Mar 24 17:28:47 2020 -0500 + + Fix vectorized version of bli_amaxv (#382) + + * Fix vectorized version of bli_amaxv + + To match Netlib, i?amax should return: + - the lowest index among equal values + - the first NaN if one is encountered + + * Fix typos. + + * And another one... + + * Update ref. amaxv kernel too. + + * Re-enabled optimized amaxv kernels. + + Details: + - Re-enabled the optimized, intrinsics-based amaxv kernels in the 'zen' + kernel set for use in haswell, zen, zen2, knl, and skx subconfigs. + These two kernels (for s and d datatypes) were temporarily disabled in + e186d71 as part of issue #380. However, the key missing semantic + properties that prompted the disabling of these kernels--returning the + index of the *first* rather than of the last element with largest + absolute value, and returning the index of the first NaN if one is + encountered--were added as part of #382 thanks to Devin Matthews. + Thus, now that the kernels are working as expected once more, this + commit causes these kernels to once again be registered for the + affected subconfigs, which effectively reverts all code changes + included in e186d71. + - Whitespace/formatting updates to new macros in bli_amaxv_zen_int.c. + + Co-authored-by: Field G. Van Zee + +commit e186d7141a51f2d7196c580e24e7b7db8f209db9 +Author: Field G. Van Zee +Date: Sat Mar 21 18:40:36 2020 -0500 + + Disabled optimized amaxv kernels. + + Details: + - Disabled use of optimized amaxv kernels, which use vector intrinsics + for both 's' and 'd' datatypes. We disable these kernels because the + current implementations fail to observe a semantic property of the + BLAS i?amax_() subroutine, which is to return the index of the + *first* element containing the maximum absolute value (that is, the + first element if there exist two or more elements that contain the + same value). With the optimized kernels disabled, the affected + subconfigurations (haswell, zen, zen2, knl, and skx) will use the + default reference implementations. Thanks to Mat Cross for reporting + this issue via #380. + - CREDITS file update. + +commit 9f3a8d4d851725436b617297231a417aa9ce8c6a +Author: Field G. Van Zee +Date: Sat Mar 14 17:48:43 2020 -0500 + + Added missing return to bli_thread_partition_2x2(). + + Details: + - Added a missing return statement to the body of an early case handling + branch in bli_thread_partition_2x2(). This bug only affected cases + where n_threads < 4, and even then, the code meant to handle cases + where n_threads >= 4 executes and does the right thing, albeit using + more CPU cycles than needed. Nonetheless, thanks to Kiran Varaganti + for reporting this bug via issue #377. + - Whitespace changes to bli_thread.c (spaces -> tabs). + +commit 8c3d9b9eeb6f816ec8c32a944f632a5ad3637593 +Merge: 71249fe8 0f9e0399 +Author: Field G. Van Zee +Date: Tue Mar 10 14:03:33 2020 -0500 + + Merge branch 'amd' of github.com:flame/blis into amd + +commit 71249fe8ddaa772616698f1e3814d40e012909ea +Author: Field G. Van Zee +Date: Tue Mar 10 13:55:29 2020 -0500 + + Merged test/sup, test/supmt into test/sup. + + Details: + - Updated the Makefile, test_gemm.c, and runme.sh in test/sup to be able + to compile and run both single-threaded and multithreaded experiments. + This should help with maintenance going forward. + - Created a test/sup/octave_st directory of scripts (based on the + previous test/sup/octave scripts) as well as a test/sup/octave_mt + directory (based on the previous test/supmt/octave scripts). The + octave scripts are slightly different and not easily mergeable, and + thus for now I'll maintain them separately. + - Preserved the previous test/sup directory as test/sup/old/supst and + the previous test/supmt directory as test/sup/old/supmt. + +commit 0f9e0399e16e96da2620faf2c0c3c21274bb2ebd +Author: Field G. Van Zee +Date: Thu Mar 5 17:03:21 2020 -0600 + + Updated sup performance graphs; added mt results. + + Details: + - Reran all existing single-threaded performance experiments comparing + BLIS sup to other implementations (including the conventional code + path within BLIS), using the latest versions (where appropriate). + - Added multithreaded results for the three existing hardware types + showcased in docs/PerformanceSmall.md: Kaby Lake, Haswell, and Epyc + (Zen1). + - Various minor updates to the text in docs/PerformanceSmall.md. + - Updates to the octave scripts in test/sup/octave, test/supmt/octave. + +commit 90db88e5729732628c1f3acc96eeefab49f2da41 +Author: Field G. Van Zee +Date: Mon Mar 2 15:06:48 2020 -0600 + + Updated sup[mt] Makefiles for variable dim ranges. + + Details: + - Updated test/sup/Makefile and test/supmt/Makefile to allow specifying + different problem size ranges for the drivers where one, two, or three + matrix dimensions is large. This will facilitate the generation of + more meaningful graphs, particularly when two dimensions are tiny. + +commit 31f11a06ea9501724feec0d2fc5e4644d7dd34fc +Author: Field G. Van Zee +Date: Thu Feb 27 14:33:20 2020 -0600 + + Updates to octave scripts in test/sup[mt]/octave. + + Details: + - Optimized scripts in test/sup/octave and test/supmt/octave for use + with octave 5.2.0 on Ubuntu 18.04. + - Fixed stray 'end' keywords in gen_opsupnames.m and plot_l3sup_perf.m, + which were not only unnecessary but also causing issues with versions + 5.x. + +commit c01d249d7c546fe2e3cee3fe071cd4c4c88b9115 +Author: Field G. Van Zee +Date: Tue Feb 25 14:50:53 2020 -0600 + + Renamed bli_thread_obarrier(), _obroadcast(). + + Details: + - Renamed two bli_thread_*() APIs: + bli_thread_obarrier() -> bli_thread_barrier() + bli_thread_obroadcast() -> bli_thread_broadcast() + The 'o' was a leftover from when thrcomm_t objects tracked both + "inner" and "outer" communicators. They have long since been + simplified to only support the latter, and thus the 'o' is + superfluous. + +commit f6e6bf73e695226c8b23fe7900da0e0ef37030c1 +Author: Field G. Van Zee +Date: Mon Feb 24 17:52:23 2020 -0600 + + List Gentoo under supported external packages. + + Details: + - Add mention of Gentoo Linux under the list of external packages in + the README.md file. Thanks to M. Zhou for maintaining this package. + +commit 9e5f7296ccf9b3f7b7041fe1df20b927cd0e914b +Author: Field G. Van Zee +Date: Tue Feb 18 15:16:03 2020 -0600 + + Skip building thrinfo_t tree when mt is disabled. + + Details: + - Return early from bli_thrinfo_sup_grow() if the thrinfo_t object + address is equal to either &BLIS_GEMM_SINGLE_THREADED or + &BLIS_PACKM_SINGLE_THREADED. + - Added preprocessor logic to bli_l3_sup_thread_decorator() in + bli_l3_sup_decor_single.c that (by default) disables code that + creates and frees the thrinfo_t tree and instead passes + &BLIS_GEMM_SINGLE_THREADED as the thrinfo_t pointer into the + sup implementation. + - The net effect of the above changes is that a small amount of + thrinfo_t overhead is avoided when running small/skinny dgemm + problems when BLIS is compiled with multithreading disabled. + +commit 90081e6a64b5ccea9211bdef193c2d332c68492f +Author: Field G. Van Zee +Date: Mon Feb 17 14:57:25 2020 -0600 + + Fixed bug(s) in mt sup when single-threaded. + + Details: + - Fixed a syntax bug in bli_l3_sup_decor_single.c as a result of + changing function interface for the thread entry point function + (of type l3supint_t). + - Unfortunately, fixing the interface was not enough, as it caused + a memory leak in the sba at bli_finalize() time. It turns out that, + due to the new multithreading-capable variant code useing thrinfo_t + objects--specifically, their calling of bli_thrinfo_grow()--we + have to pass in a real thrinfo_t object rather than the global + objects &BLIS_PACKM_SINGLE_THREADED or &BLIS_GEMM_SINGLE_THREADED. + Thus, I inserted the appropriate logic from the OpenMP and pthreads + versions so that single-threaded execution would work as intended + with the newly upgraded variants. + +commit c0558fde4511557c8f08867b035ee57dd2669dc6 +Author: Field G. Van Zee +Date: Mon Feb 17 14:08:08 2020 -0600 + + Support multithreading within the sup framework. + + Details: + - Added multithreading support to the sup framework (via either OpenMP + or pthreads). Both variants 1n and 2m now have the appropriate + threading infrastructure, including data partitioning logic, to + parallelize computation. This support handles all four combinations + of packing on matrices A and B (neither, A only, B only, or both). + This implementation tries to be a little smarter when automatic + threading is requested (e.g. via BLIS_NUM_THREADS) in that it will + recalculate the factorization in units of micropanels (rather than + using the raw dimensions) in bli_l3_sup_int.c, when the final + problem shape is known and after threads have already been spawned. + - Implemented bli_?packm_sup_var2(), which packs to conventional row- + or column-stored matrices. (This is used for the rrc and crc storage + cases.) Previously, copym was used, but that would no longer suffice + because it could not be parallelized. + - Minor reorganization of packing-related sup functions. Specifically, + bli_packm_sup_init_mem_[ab]() are called from within packm_sup_[ab]() + instead of from the variant functions. This has the effect of making + the variant functions more readable. + - Added additional bli_thrinfo_set_*() static functions to bli_thrinfo.h + and inserted usage of these functions within bli_thrinfo_init(), which + previously was accessing thrinfo_t fields via the -> operator. + - Renamed bli_partition_2x2() to bli_thread_partition_2x2(). + - Added an auto_factor field to the rntm_t struct in order to track + whether automatic thread factorization was originally requested. + - Added new test drivers in test/supmt that perform multithreaded sup + tests, as well as appropriate octave/matlab scripts to plot the + resulting output files. + - Added additional language to docs/Multithreading.md to make it clear + that specifying any BLIS_*_NT variable, even if it is set to 1, will + be considered manual specification for the purposes of determining + whether to auto-factorize via BLIS_NUM_THREADS. + - Minor comment updates. + +commit d7a7679182d72a7eaecef4cd9b9a103ee0a7b42b +Author: Field G. Van Zee +Date: Fri Feb 7 17:37:03 2020 -0600 + + Fixed int-to-packbuf_t conversion error (C++ only). + + Details: + - Fixed an error that manifests only when using C++ (specifically, + modern versions of g++) to compile drivers in 'test' (and likely most + other application code that #includes blis.h. Thanks to Ajay Panyala + for reporting this issue (#374). + +commit d626112b8d5302f9585fb37a8e37849747a2a317 +Author: Field G. Van Zee +Date: Wed Jan 15 13:27:02 2020 -0600 + + Removed sorting on LDFLAGS in common.mk (#373). + + Details: + - Removed a line of code in common.mk that passed LDFLAGS through the + sort function. The purpose was not to sort the contents, but rather + to remove duplicates. However, there is valid syntax in a string of + linker flags that, when sorted, yields different/broken behavior. + So I've removed the line in common.mk that sorts LDFLAGS. Also, for + future use, I've added a new function, rm-dupls, that removes + duplicates without sorting. (This function was based on code from a + stackoverflow thread that is linked to in the comments for that + code.) Thanks to Isuru Fernando for reporting this issue (#373). + +commit e67deb22aaeab5ed6794364520190936748ef272 +Author: Field G. Van Zee +Date: Tue Jan 14 16:01:34 2020 -0600 + + CHANGELOG update (0.6.1) + +commit 10949f528c5ffc5c3a2cad47fe16a802afb021be (tag: 0.6.1) +Author: Field G. Van Zee +Date: Tue Jan 14 16:01:33 2020 -0600 + + Version file update (0.6.1) + +commit 5db8e710a2baff121cba9c63b61ca254a2ec097a +Author: Field G. Van Zee +Date: Tue Jan 14 15:59:59 2020 -0600 + + ReleaseNotes.md update in advance of next version. + + Details: + - Updated ReleaseNotes.md in preparation for next version. + +commit cde4d9d7a26eb51dcc5a59943361dfb8fda45dea +Author: Field G. Van Zee +Date: Tue Jan 14 15:19:25 2020 -0600 + + Removed 'attic/windows' (to prevent confusion). + + Details: + - Finally removed 'attic/windows' and its contents. This directory once + contained "proto" Windows support for BLIS, but we've since moved on + to (thanks to Isuru Fernando) providing Windows DLL support via + AppVeyor's build artifacts. Furthermore, since 'windows' was the only + subdirectory within 'attic', the directory path would show up in + GitHub's listing at https://github.com/flame/blis, which probably led + to someone being confused about how BLIS provides Windows support. I + assume (but don't know for sure) that nobody is using these files, so + this is admittedly a case of shoot first and ask questions later. + +commit 7d3407d4681c6449f4bbb8ec681983700ab968f3 +Author: Field G. Van Zee +Date: Tue Jan 14 15:17:53 2020 -0600 + + CREDITS file update. + +commit f391b3e2e7d11a37300d4c8d3f6a584022a599f5 +Author: Dave Love +Date: Mon Jan 6 20:15:48 2020 +0000 + + Fix parsing in vpu_count on workstation SKX (#351) + + * Fix parsing in vpu_count on workstation SKX + + * Document Skylake-X as Haswell for single FMA + + * Update vpu_count for Skylake and Cascade Lake models + + * Support printing the configuration selected, controlled by the environment + + Intended particularly for diagnosing mis-selection of SKX through + unknown, or incorrect, number of VPUs. + + * Move bli_log outside the cpp condition, and use it where intended + + * Add Fixme comment (Skylake D) + + * Mostly superficial edits to commits towards #351. + + Details: + - Moved architecture/sub-config logging-related code from bli_cpuid.c + to bli_arch.c, tweaked names, and added more set/get layering. + - Tweaked log messages output from bli_cpuid_is_skx() in bli_cpuid.c. + - Content, whitespace changes to new bullet in HardwareSupport.md that + relates to single-VPU Skylake-Xs. + + * Fix comment typos + + Co-authored-by: Field G. Van Zee + +commit 5ca1a3cfc1c1cc4dd9da6a67aa072ed90f07e867 +Author: Field G. Van Zee +Date: Mon Jan 6 12:29:12 2020 -0600 + + Fixed 'configure' breakage introduced in 6433831. + + Details: + - Added a missing 'fi' (endif) keyword to a conditional block added in + the configure script in commit 6433831. + +commit e7431b4a834ef4f165c143f288585ce8e2272a23 +Author: Field G. Van Zee +Date: Mon Jan 6 12:01:41 2020 -0600 + + Updated 1m draft article link in README.md. + +commit 6433831cc3988ad205637ebdebcd6d8f7cfcf148 +Author: Jeff Hammond +Date: Fri Jan 3 17:52:49 2020 -0800 + + blacklist ICC 18 for knl/skx due to test failures + + Signed-off-by: Jeff Hammond + +commit af3589f1f98781e3a94a8f9cea8d5ea6f155f7d2 +Author: Jeff Hammond +Date: Fri Jan 3 13:23:24 2020 -0800 + + blacklist Intel 19+ + + Signed-off-by: Jeff Hammond + +commit 60de939debafb233e57fd4e804ef21b6de198caf +Author: Jeff Hammond +Date: Wed Jan 1 21:30:38 2020 -0800 + + fix link to docs + + the comment contains an incorrect link, which is trivially fixed here. + + @fgvanzee I hope you don't mind that I committed directly to master but this cannot break anything. + +commit 52711073789b6b84eb99bb0d6883f457ed3fcf80 +Author: Field G. Van Zee +Date: Mon Dec 16 16:30:26 2019 -0600 + + Fixed bugs in cblas_sdsdot(), sdsdot_(). + + Details: + - Fixed a bug in sdsdot_sub() that redundantly added the "alpha" scalar, + named 'sb'. This value was already being added by the underlying + sdsdot_() function. Thus, we no longer add 'sb' within sdsdot_sub(). + Thanks to Simon Lukas Märtens for reporting this bug via #367. + - Fixed a second bug in order of typecasting intermediate products in + sdsdot_(). Previously, the "alpha" scalar was being added after the + "outer" typecast to float. However, the operation is supposed to first + add the dot product to the (promoted) scalar and THEN downcast the sum + to float. Thanks to Devin Matthews for catching this bug. + +commit fe2560a4b1d8ef8d0a446df6002b1e7decc826e9 +Author: Field G. Van Zee +Date: Fri Dec 6 17:12:44 2019 -0600 + + Annoted missing thread-related symbols for export. + + Details: + - Added BLIS_EXPORT_BLIS annotation to function prototypes for + + bli_thrcomm_bcast() + bli_thrcomm_barrier() + bli_thread_range_sub() + + so that these functions are exported to shared libraries by default. + This (hopefully) fixes issue #366. Thanks to Kyungmin Lee for + reporting this bug. + - CREDITS file update. + +commit 2853825234001af8f175ad47cef5d6ff9b7a5982 +Merge: efa61a6c 61b1f0b0 +Author: Field G. Van Zee +Date: Fri Dec 6 16:06:46 2019 -0600 + + Merge branch 'master' into amd + +commit 61b1f0b0602faa978d9912fe58c6c952a33af0ac +Author: Nicholai Tukanov +Date: Wed Dec 4 14:18:47 2019 -0600 + + Add prototypes for POWER9 reference kernels (#365) + + Updates and fixes to power9 subconfig. + + Details: + - Register s,c,z reference gemm and trsm ukernels that assume elements + of B have been broadcast. + - Added prototypes for level-3 ukernels that assume elements of B have + been broadcast. Also added prototype for an spackm function that + employs a duplication/broadcast factor of 4. + - Register virtual gemmtrsm ukernels that work with broadcasting of B. + - Disable right-side hemm, symm, trmm, and trmm3 in bli_family_power9.h. + - Thanks to Nicholai Tukanov for providing these updates. + +commit efa61a6c8b1cfa48781fc2e4799ff32e1b7f8f77 +Author: Field G. Van Zee +Date: Fri Nov 29 16:17:04 2019 -0600 + + Added missing bli_l3_sup_thread_decorator() symbol. + + Details: + - Defined dummy versions of bli_l3_sup_thread_decorator() for Openmp + and pthreads so that those builds don't fail when performing shared + library linking (especially for Windows DLLs via AppVeyor). For now, + these dummy implementations of bli_l3_sup_thread_decorator() are + merely carbon-copies of the implementation provided for single- + threaded execution (ie: the one found in bli_l3_sup_decor_single.c). + Thus, an OpenMP or pthreads build will be able to use the gemmsup + code (including the new selective packing functionality), as it did + before 39fa7136, even though it will not actually employ any + multithreaded parallelism. + +commit 39fa7136f4a4e55ccd9796fb79ad5f121b872ad9 +Author: Field G. Van Zee +Date: Fri Nov 29 15:27:07 2019 -0600 + + Added support for selective packing to gemmsup. + + Details: + - Implemented optional packing for A or B (or both) within the sup + framework (which currently only supports gemm). The request for + packing either matrix A or matrix B can be made via setting + environment variables BLIS_PACK_A or BLIS_PACK_B (to any + non-zero value; if set, zero means "disable packing"). It can also + be made globally at runtime via bli_pack_set_pack_a() and + bli_pack_set_pack_b() or with individual rntm_t objects via + bli_rntm_set_pack_a() and bli_rntm_set_pack_b() if using the expert + interface of either the BLIS typed or object APIs. (If using the + BLAS API, environment variables are the only way to communicate the + packing request.) + - One caveat (for now) with the current implementation of selective + packing is that any blocksize extension registered in the _cntx_init + function (such as is currently used by haswell and zen subconfigs) + will be ignored if the affected matrix is packed. The reason is + simply that I didn't get around to implementing the necessary logic + to pack a larger edge-case micropanel, though this is entirely + possible and should be done in the future. + - Spun off the variant-choosing portion of bli_gemmsup_ref() into + bli_gemmsup_int(), in bli_l3_sup_int.c. + - Added new files, bli_l3_sup_packm_a.c, bli_l3_sup_packm_b.c, along + with corresponding headers, in which higher-level packm-related + functions are defined for use within the sup framework. The actual + packm variant code resides in bli_l3_sup_packm_var.c. + - Pass the following new parameters into var1n and var2m: packa, packb + bool_t's, pointer to a rntm_t, pointer to a cntl_t (which is for now + always NULL), and pointer to a thrinfo_t* (which for nowis the address + of the global single-threaded packm thread control node). + - Added panel strides ps_a and ps_b to the auxinfo_t structure so that + the millikernel can query the panel stride of the packed matrix and + step through it accordingly. If the matrix isn't packed, the panel + stride of interest for the given millikernel will be set to the + appropriate value so that the mkernel may step through the unpacked + matrix as it normally would. + - Modified the rv_6x8m and rv_6x8n millikernels to read the appropriate + panel strides (ps_a and ps_b, respectively) instead of computing them + on the fly. + - Spun off the environment variable getting and setting functions into + a new file, bli_env.c (with a corresponding prototype header). These + functions are now used by the threading infrastructure (e.g. + BLIS_NUM_THREADS, BLIS_JC_NT, etc.) as well as the selective packing + infrastructure (e.g. BLIS_PACK_A, BLIS_PACK_B). + - Added a static initializer for mem_t objects, BLIS_MEM_INITIALIZER. + - Added a static initializer for pblk_t objects, BLIS_PBLK_INITIALIZER, + for use within the definition of BLIS_MEM_INITIALIZER. + - Moved the global_rntm object to bli_rntm.c and extern it where needed. + This means that the function bli_thread_init_rntm() was renamed to + bli_rntm_init_from_global() and relocated accordingly. + - Added a new bli_pack.c function, which serves as the home for + functions that manage the pack_a and pack_b fields of the global + rntm_t, including from environment variables, just as we have + functions to manage the threading fields of the global rntm_t in + bli_thread.c. + - Reorganized naming for files in frame/thread, which mostly involved + spinning off the bli_l3_thread_decorator() functions into their own + files. This change makes more sense when considering the further + addition of bli_l3_sup_thread_decorator() functions (for now limited + only to the single-threaded form found in the _single.c file). + - Explicitly initialize the reference sup handlers in both + bli_cntx_init_haswell.c and bli_cntx_init_zen.c so that it's more + obvious how to customize to a different handler, if desired. + - Removed various snippets of disabled code. + - Various comment updates. + +commit bbb21fd0a9be8c5644bec37c75f9396eeeb69e48 +Author: Field G. Van Zee +Date: Thu Nov 21 18:15:16 2019 -0600 + + Tweaked SIAM/SC Best Prize language in README.md. + +commit 043366f92d5f5f651d5e3371ac3adb36baf4adce +Author: Field G. Van Zee +Date: Thu Nov 21 18:13:51 2019 -0600 + + Fixed typo in previous commit (SIAM/SC prize). + +commit 05a4d583e65a46ff2a1100ab4433975d905d91f9 +Author: Field G. Van Zee +Date: Thu Nov 21 18:12:24 2019 -0600 + + Added SIAM/SC prize to "What's New" in README.md. + +commit 881b05ecd40c7bc0422d3479a02a28b1cb48383f +Author: Field G. Van Zee +Date: Thu Nov 21 16:34:27 2019 -0600 + + Fixed blastest failure for 'generic' subconfig. + + Details: + - Fixed a subtle and complicated bug that only manifested via the BLAS + test drivers in the generic subconfiguration, and possibly any other + subconfiguration that did not register complex-domain gemm ukernels, + or registered ONLY real-domain ukernels as row-preferential. This is + a long story, but it boils down to an exception to the "transpose the + operation to bring storage of C into agreement with ukernel pref" + optimization in bli_hemm_front.c and bli_symm_front.c sabotaging the + proper functioning of the 1m method, but only when the imaginary + component of beta is zero. See the comments in issue #342 for more + details. Thanks to Dave Love for identifying the commit in which this + bug was introduced, and other feedback related to this bug. + +commit 0c7165fb01cdebbc31ec00124d446161b289942f +Author: Field G. Van Zee +Date: Thu Nov 14 16:48:14 2019 -0600 + + Fixed obscure bug in bli_acquire_mpart_[mn]dim(). + + Details: + - Fixed a bug in bli_acquire_mpart_mdim(), bli_acquire_mpart_ndim(), + and bli_acquire_mpart_mndim() that allowed the use of a blocksize b + that is too large given the current row/column index (i.e., the i/j + argument) and the size of the dimension being partitioned (i.e., the + m/n argument). This bug only affected backwards partitioning/motion + through the dimension and was the result of a misplaced conditional + check-and-redirect to the backwards code path. It should be noted + that this bug was discovered not because it manifested the way it + could (thanks to the callers in BLIS making sure to always pass in + the "correct" blocksize b), but could have manifested if the + functions were used by 3rd party callers. Thanks to Minh Quan Ho for + reporting the bug via issue #363. + +commit fb8bef9982171ee0f60bc39e41a33c4d31fd59a9 +Author: Field G. Van Zee +Date: Thu Nov 14 13:05:28 2019 -0600 + + Fixed copy-paste bug in bli_spackm_6xk_bb4_ref(). + + Details: + - Fixed a copy-paste bug in the new bli_spackm_6xk_bb4_ref() that + manifested as failures in single-precision real level-3 operations. + Also replaced the duplication factor constants with a const-qualifed + varialbe, dfac, so that this won't happen again. + - Changed NC for single-precision real from 4080 to 8160 so that the + packed matrix B will have the same byte footprint in both single + and double real. + +commit 8f399c89403d5824ba767df1426706cf2d19d0a7 +Author: Field G. Van Zee +Date: Tue Nov 12 15:32:57 2019 -0600 + + Tweaked/added notes to docs/Multithreading.md. + + Details: + - Added language to docs/Multithreading.md cautioning the reader about + the nuances of setting multithreading parameters via the manual and + automatic ways simultaneously, and also about how these parameters + behave when multithreading is disabled at configure-time. These + changes are an attempt to address the issues that arose in issue #362. + Thanks to Jérémie du Boisberranger for his feedback on this topic. + - CREDITS file update. + +commit bdc7ee3394500d8e5b626af6ff37c048398bb27e +Author: Field G. Van Zee +Date: Mon Nov 11 15:47:17 2019 -0600 + + Various fixes to support packing duplication in B. + + Details: + - Added cpp macros to trmm and trmm3 front-ends to optionally force + those operations to be cast so the structured matrix is on the left. + symm and hemm already had such macros, but these too were renamed so + that the macros were individual to the operation. We now have four + such macros: + #define BLIS_DISABLE_HEMM_RIGHT + #define BLIS_DISABLE_SYMM_RIGHT + #define BLIS_DISABLE_TRMM_RIGHT + #define BLIS_DISABLE_TRMM3_RIGHT + Also, updated the comments in the symm and hemm front-ends related to + the first two macro guards, and added corresponding comments to the + trmm and trmm3 front-ends for the latter two guards. (They all + functionally do the same thing, just for their specific operations.) + Thanks to Jeff Hammond for reporting the bugs that led me to this + change (via #359). + - Updated config/old/haswellbb subconfiguration (used to debug issues + related to duplicating B during packing) to register: a packing + kernel for single-precision real; gemmbb ukernels for s, c, and z; + trsmbb ukernels for s, c, and z; gemmtrsmbb virtual ukrnels for s, c + and z; and to use non-default cache and register blocksizes for s, c, + and z datatypes. Also declared prototypes for all of the gemmbb, + trsmbb, and gemmtrsmbb ukernel functions within the + bli_cntx_init_haswellbb() function. This should, once applied to the + power9 configuration, fix the remaining issues in #359. + - Defined bli_spackm_6xk_bb4_ref(), which packs single reals with a + duplication factor of 4. This function is defined in the same file as + bli_dpackm_6xk_bb2_ref() (bli_packm_cxk_bb_ref.c). + +commit 0eb79ca8503bd7b237994335b9687457227d3290 +Author: Field G. Van Zee +Date: Fri Nov 8 14:48:48 2019 -0600 + + Avoid unused variable warning in lread.c (#356). + + Details: + - Replaced the line + + f = f; + + with + + ( void )f; + + for the unused variable 'f' in blastest/f2c/lread.c. (Hopefully) + addresses issue #356, but since we don't use xlc who knows. Thanks + to Jeff Hammond for reporting this. + +commit f377bb448512f0b578263387eed7eaf8f2b72bb7 +Author: Jérôme Duval +Date: Thu Nov 7 23:39:29 2019 +0100 + + Add Haiku to the known OS list (#361) + +commit e29b1f9706b6d9ed798b7f6325f275df4e6be973 +Author: Field G. Van Zee +Date: Tue Nov 5 17:15:19 2019 -0600 + + Fixed failing testsuite gemmtrsm_ukr for power9. + + Details: + - Added code that fixes false failures in the gemmtrsm_ukr module of the + testsuite. The tests were failing because the computation (bli_gemv()) + that performs the numerical check was not able to properly travserse + the matrix operands bx1 and b11 that are views into the micropanel of + B, which has duplicated/broadcast elements under the power9 subconfig. + (For example, a micropanel of B with duplication factor of 2 needs to + use a column stride of 2; previously, the column stride was being + interpreted as 1.) + - Defined separate bli_obj_set_row_stride() and bli_obj_set_col_stride() + static functions in bli_obj_macro_defs.h. (Previously, only the + function bli_obj_set_strides() was defined. Amazing to think that we + got this far without these former functions.) + - Updated/expounded upon comments. + +commit 49177a6b9afcccca5b39a21c6fd8e243525e1505 +Author: Field G. Van Zee +Date: Mon Nov 4 18:09:37 2019 -0600 + + Fixed latent testsuite ukr module bugs for power9. + + Details: + - Fixed a latent bug in the testsuite ukernel modules (gemm, trsm, and + gemmtrsm) that only manifested once we began running with parameters + that mimic those of power9. The problem was rooted in the way those + modules were creating objects (and thus allocating memory) for the + micropanel operands to the microkernel being tested. Since power9 + duplicates/broadcasts elements of B in memory, we needed an easy way + of asking for more than one storage element per logical element in + the matrix. I incorrectly expressed this as: + + bli_obj_create( datatype, k, n, ldbp, 1, &bp ); + + The problem here is that bli_obj_create() is exceedingly efficient + at calculating the size it passes to malloc() and doesn't allocate a + full leading dimension's worth of elements for the last column (or + row, in this example). This would normally not bother anyone since + you're not supposed to access that memory anyway. But here, my + attempted "hack" for getting extra elements was insufficient, and + needed to be changed to: + + bli_obj_create( datatype, k, ldbp, ldbp, 1, &bp ); + + That is, the extra elements needed to be baked into the dimensions of + the matrix object in order to have the intended effect on the number + of elements actually allocated. Thanks to Jeff Hammond for reporting + this bug. + - Fixed a typically harmless memory leak in the aforementioned test + modules (the objects for the packed micropanels were not being freed). + - Updated/expanded a common comment across all three ukr test modules. + +commit c84391314d4f1b3f73d868f72105324e649f2a72 +Author: Field G. Van Zee +Date: Mon Nov 4 13:57:12 2019 -0600 + + Reverted minor temp/wspace changes from b426f9e. + + Details: + - Added missing license header to bli_pwr9_asm_macros_12x6.h. + - Reverted temporary changes to various files in 'test' and 'testsuite' + directories. + - Moved testsuite/jobscripts into testsuite/old. + - Minor whitespace/comment changes across various files. + +commit 4870260f6b8c06d2cc01b7147d7433ddee213f7f +Author: Jeff Hammond +Date: Mon Nov 4 11:55:47 2019 -0800 + + blacklist GCC 5 and older for POWER9 (#360) + +commit b426f9e04e5499c6f9c752e49c33800bfaadda4c +Author: Nicholai Tukanov +Date: Fri Nov 1 17:57:03 2019 -0500 + + POWER9 DGEMM (#355) + + Implemented and registered power9 dgemm ukernel. + + Details: + - Implemented 12x6 dgemm microkernel for power9. This microkernel + assumes that elements of B have been duplicated/broadcast during the + packing step. The microkernel uses a column orientation for its + microtile vector registers and thus implements column storage and + general stride IO cases. (A row storage IO case via in-register + transposition may be added at a future date.) It should be noted that + we recommend using this microkernel with gcc and *not* xlc, as issues + with the latter cropped up during development, including but not + limited to slightly incompatible vector register mnemonics in the GNU + extended inline assembly clobber list. + +commit 58102aeaa282dc79554ed045e1b17a6eda292e15 +Merge: 52059506 b9bc222b +Author: Field G. Van Zee +Date: Mon Oct 28 17:58:31 2019 -0500 + + Merge branch 'amd' + +commit 52059506b2d5fd4c3738165195abeb356a134bd4 +Author: Field G. Van Zee +Date: Wed Oct 23 15:26:42 2019 -0500 + + Added "How to Download BLIS" section to README.md. + + Details: + - Added a new section to the README.md, just prior to the "Getting + Started" section, titled "How to Download BLIS". This section details + the user's options for obtaining BLIS and lays out four common ways + of downloading the library. Thanks to Jeff Diamond for his feedback + on this topic. + +commit e6f0a96cc59aef728470f6850947ba856148c38a +Author: Field G. Van Zee +Date: Mon Oct 14 17:05:39 2019 -0500 + + Updated README.md to ack Facebook as funder. + +commit b9bc222bfc3db4f9ae5d7b3321346eed70c2c3fb +Author: Field G. Van Zee +Date: Mon Oct 14 16:38:15 2019 -0500 + + Call bli_syrk_small() before error checking. + + Details: + - In bli_syrk_front(), moved the conditional call to bli_syrk_check() + (if error checking is enabled) and the conditional scaling of C by + beta (if alpha is zero) so that they occur after, instead of before, + the call to bli_syrk_small(). This sequencing now matches that of + bli_gemm_small() in bli_gemm_front() and bli_trsm_small() in + bli_trsm_front(). + +commit f0959a81dbcf30d8a1076d0a6348a9835079d31a +Author: Field G. Van Zee +Date: Mon Oct 14 15:46:28 2019 -0500 + + When manual config is blacklisted, output error. + + Details: + - Fixed and adjusted the logic in configure so that a more informative + error message is output when a user runs './configure ... ' and + is present in the configuration blacklist. Previously, this + particular set of conditions would result in the message: + + 'user-specified configuration '' is NOT registered! + + That is, the error message mis-identified the targeted configuration + as the empty string, and (more importantly) mis-identifies the + problem. Thanks to Tze Meng Low for reporting this issue. + - Fixed a nearby error messages somewhat unrelated to the issue above. + Specifically, the wrong string was being printed when the error + message was identifying an auto-detected configuration that did not + appear to be registered. + +commit 6218ac95a525eefa8921baf8d0d7057dfacebe9c +Merge: 0016d541 a617301f +Author: Field G. Van Zee +Date: Fri Oct 11 11:53:51 2019 -0500 + + Merge branch 'master' into amd + +commit 0016d541e6b0da617b1fae6612d2b314901b7a75 +Author: Field G. Van Zee +Date: Fri Oct 11 11:09:44 2019 -0500 + + Changed -march=znver2 to =znver1 for clang on zen2. + + Details: + - In config/zen2/make_defs.mk, changed the -march= flag so that + -march=znver1 is used instead of -march=znver2 when CC_VENDOR is + clang. (The gcc branch attempts to differentiate between various + versions, but the equivalent version cutoffs for clang are not + yet known by us, so we have to use a single flag for all versions + of clang. Hopefully -march=znver1 is new enough. If not, we'll + fall back to -march=bdver4 -mno-fma4 -mno-tbm -mno-xop -mno-lwp.) + This issue was discovered thanks to AppVeyor. + +commit e94a0530e5ac4c78a18f09105f40003be2b517f7 +Author: Field G. Van Zee +Date: Fri Oct 11 10:48:27 2019 -0500 + + Corrected zen NC that was non-multiple of NR. + + Details: + - Updated an incorrectly set cache blocksize NC for single real within + config/zen/bli_cntx_init_zen.c that was non a multiple of the + corresponding value of NR. This issue, which was caught by Travis CI, + was introduced in 29b0e1e. + +commit a2ffac752076bf55eb8c1fe2c5da8d9104f1f85b +Merge: 1cfe8e25 29b0e1ef +Author: Field G. Van Zee +Date: Fri Oct 11 10:31:18 2019 -0500 + + Merge branch 'amd-master' into amd + +commit 29b0e1ef4e8b84ce76888d73c090009b361f1306 +Merge: 1cfe8e25 fdce1a56 +Author: Field G. Van Zee +Date: Fri Oct 11 10:24:24 2019 -0500 + + Code review + tweaks to AMD's AOCL 2.0 PR (#349). + + Details: + - NOTE: This is a merge commit of 'master' of git://github.com/amd/blis + into 'amd-master' of flame/blis. + - Fixed a bug in the downstream value of BLIS_NUM_ARCHS, which was + inadvertantly not incremented when the Zen2 subconfiguration was + added. + - In bli_gemm_front(), added a missing conditional constraint around the + call to bli_gemm_small() that ensures that the computation precision + of C matches the storage precision of C. + - In bli_syrk_front(), reorganized and relocated the notrans/trans logic + that existed around the call to bli_syrk_small() into bli_syrk_small() + to minimize the calling code footprint and also to bring that code + into stylistic harmony with similar code in bli_gemm_front() and + bli_trsm_front(). Also, replaced direct accessing of obj_t fields with + proper accessor static functions (e.g. 'a->dim[0]' becomes + 'bli_obj_length( a )'). + - Added #ifdef BLIS_ENABLE_SMALL_MATRIX guard around prototypes for + bli_gemm_small(), bli_syrk_small(), and bli_trsm_small(). This is + strictly speaking unnecessary, but it serves as a useful visual cue to + those who may be reading the files. + - Removed cpp macro-protected small matrix debugging code from + bli_trsm_front.c. + - Added a GCC_OT_9_1_0 variable to build/config.mk.in to facilitate gcc + version check for availability of -march=znver2, and added appropriate + support to configure script. + - Cleanups to compiler flags common to recent AMD microarchitectures in + config/zen/amd_config.mk, including: removal of -march=znver1 et al. + from CKVECFLAGS (since the -march flag is added within make_defs.mk); + setting CRVECFLAGS similarly to CKVECFLAGS. + - Cleanups to config/zen/bli_cntx_init_zen.c. + - Cleanups, added comments to config/zen/make_defs.mk. + - Cleanups to config/zen2/make_defs.mk, including making use of newly- + added GCC_OT_9_1_0 and existing GCC_OT_6_1_0 to choose the correct + set of compiler flags based on the version of gcc being used. + - Reverted downstream changes to test/test_gemm.c. + - Various whitespace/comment changes. + +commit a617301f9365ac720ff286514105d1b78951368b +Author: Field G. Van Zee +Date: Tue Oct 8 17:14:05 2019 -0500 + + Updates to docs/CodingConventions.md. + +commit 171f10069199f0cd280f18aac184546bd877c4fe +Merge: 702486b1 05d58edf +Author: Field G. Van Zee +Date: Fri Oct 4 11:18:23 2019 -0500 + + Merge remote-tracking branch 'loveshack/emacs' + +commit 702486b12560b5c696ba06de9a73fc0d5107ca44 +Author: Field G. Van Zee +Date: Wed Oct 2 16:35:41 2019 -0500 + + Removed stray FAQ section introduced in 1907000. + +commit 1907000ad6ea396970c010f07ae42980b7b14fa0 +Author: Field G. Van Zee +Date: Wed Oct 2 16:31:54 2019 -0500 + + Updated to FAQ (AMD-related questions). + + Details: + - Added a couple potential frequently-asked questions/answers releated + to AMD's fork of BLIS. + - Updated existing answers to other questions. + +commit 834f30a0dad808931c9d80bd5831b636ed0e1098 +Author: Field G. Van Zee +Date: Wed Oct 2 12:45:56 2019 -0500 + + Mention mixeddt paper in docs/MixedDatatypes.md. + +commit 05d58edfe0ea9279971d74f17a5f7a69c4672ed5 +Author: Dave Love +Date: Wed Oct 2 10:33:44 2019 +0100 + + Note .dir-locals.el in docs + +commit 531110c339f199a4d165d707c988d89ab4f5bfe8 +Author: Dave Love +Date: Wed Oct 2 10:16:22 2019 +0100 + + Modify Emacs config + Confine it to cc-mode and add comment-start/end. + +commit 4bab365cab98202259c70feba6ec87408cba28d8 +Author: Dave Love +Date: Tue Oct 1 19:22:47 2019 +0000 + + Add .dir-locals.el for Emacs (#348) + + A minimal version that could probably do with extending, but at least + gets the indentation roughly right. + +commit 4ec8dad66b3d37b0a2b47d19b7144bb62d332622 +Author: Dave Love +Date: Thu Sep 26 16:27:53 2019 +0100 + + Add .dir-locals.el for Emacs + + A minimal version that could probably do with extending, but at least + gets the indentation roughly right. + +commit bc16ec7d1e2a30ce4a751255b70c9cbe87409e4f +Author: Field G. Van Zee +Date: Mon Sep 23 15:37:33 2019 -0500 + + Set execute bits of shared library at install-time. + + Details: + - Modified the 0644 octal code used during installation of shared + libraries to 0755 (for Linux/OSX only). Thanks to Adam J. Stewart + for reporting this issue via #343. + - CREDITS file update. + +commit c60db26aee9e7b4e5d0b031b0881e58d23666b53 +Author: Field G. Van Zee +Date: Tue Sep 17 18:04:17 2019 -0500 + + Fixed bad loop counter in bli_[cz]scal2bbs_mxn(). + + Details: + - Fixed a typo in the loop counter for the 'd' (duplication) dimension + in the complex macros of frame/include/level0/bb/bli_scal2bbs_mxn.h. + They shouldn't be used by anyone yet, but thankfully clang via + AppVeyor spit out warnings that alerted me to the issue. + +commit c766c81d628f0451d8255bf5e4b8be0a4ef91978 +Author: Field G. Van Zee +Date: Tue Sep 17 18:00:29 2019 -0500 + + Added missing schema arg to knl packm kernels. + + Details: + - Added the pack_t schema argument to the knl packm kernel functions. + This change was intended for inclusion in 31c8657. (Thank you SDE + + Travis CI.) + +commit 31c8657f1d6d8f6efd8a73fd1995e995fc56748b +Author: Field G. Van Zee +Date: Tue Sep 17 17:42:10 2019 -0500 + + Added support for pre-broadcast when packing B. + + Details: + - Added support for being able to duplicate (broadcast) elements in + memory when packing matrix B (ie: the left-hand operand) in level-3 + operations. This turns out advantageous for some architectures that + can afford the cost of the extra bandwidth and somehow benefit from + the pre-broadcast elements (and thus being able to avoid using + broadcast-style load instructions on micro-rows of B in the gemm + microkernel). + - Support optionally disabling right-side hemm and symm. If this occurs, + hemm_r is implemented in terms of hemm_l (and symm_r in terms of + symm_l). This is needed when broadcasting during packing because the + alternative--supporting the broadcast of B while also allowing matrix + B to be Hermitian/symmetric--would be an absolute mess. + - Support alignment factors for packed blocks of A, B, and C separately + (as well as for general-purpose buffers). In addition, we support + byte offsets from those alignment values (which is different from + aligning by align+offset bytes to begin with). The default alignment + values are BLIS_PAGE_SIZE in all four cases, with the offset values + defaulting to zero. + - Pass pack_t schema into bli_?packm_cxk() so that it can be then passed + into the packm kernel, where it will be needed by packm kernels that + perform broadcasts of B, since the idea is that we *only* want to + broadcast when packing micropanels of B and not A. + - Added definition for variadic bli_cntx_set_l3_vir_ukrs(), which can be + used to set custom virtual level-3 microkernels in the cntx_t, which + would typically be done in the bli_cntx_init_*() function defined in + the subconfiguration of interest. + - Added a "broadcast B" kernel function for use with NP/NR = 12/6, + defined in in ref_kernels/1m/bli_packm_cxk_bb_ref.c. + - Added a gemm, gemmtrsm, and trsm "broadcast B" reference kernels + defined in ref_kernels/3/bb. (These kernels have been tested with + double real with NP/NR = 12/6.) + - Added #ifndef ... #endif guards around several macro constants defined + in frame/include/bli_kernel_macro_defs.h. + - Defined a few "broadcast B" static functions in + frame/include/level0/bb for use by "broadcast B"-style packm reference + kernels. For now, only the real domain kernels are tested and fully + defined. + - Output the alignment and offset values for packed blocks of A and B + in the testsuite's "BLIS configuration info" section. + - Comment updates to various files. + - Bumped so_version to 3.0.0. + +commit fd9bf497cd4ff73ccdfc030ba037b3cb2f1c2fad +Author: Field G. Van Zee +Date: Tue Sep 17 15:45:24 2019 -0500 + + CREDITS file update. + +commit 6c8f2d1486ce31ad3c2083e5c2035acfd4409a43 +Author: ShmuelLevine +Date: Tue Sep 17 16:43:46 2019 -0400 + + Fix description for function bli_*pxby2v (#340) + + Fix typo in BLISTypedAPI.md for bli_?axpy2v() description. + +commit b5679c1520f8ae7637b3cc2313133461f62398dc +Author: Field G. Van Zee +Date: Tue Sep 17 14:00:37 2019 -0500 + + Inserted Multithreading links into BuildSystem.md. + + Details: + - Inserted brief disclaimers about default disabled multithreading + and default single-threadedness to BuildSystem.md along with links to + the Multithreading.md document. Thanks to Jeff Diamond for suggesting + these additions. + - Trivial reword of sentence regarding automatically-detected + architectures. + +commit f4f5170f8482c94132832eb3033bc8796da5420b +Author: Isuru Fernando +Date: Wed Sep 11 07:34:48 2019 -0500 + + Update README.md (#338) + +commit 1cfe8e2562e5e50769468382626ce36b734741c1 +Author: Field G. Van Zee +Date: Thu Sep 5 16:08:30 2019 -0500 + + Reimplemented bli_cpuid_query() for ARM. + + Details: + - Rewrote bli_cpuid_query() for ARM architectures to use stdio-based + functions such as fopen() and fgets() instead of popen(). The new code + does more or less the same thing as before--searches /proc/cpuinfo for + various strings, which are then parsed in order to determine the + model, part number, and features. Thanks to Dave Love for suggesting + this change in issue #335. + +commit 7c7819145740e96929466a248d6375d40e397e19 +Author: Devin Matthews +Date: Fri Aug 30 16:52:09 2019 -0500 + + Always use sqsumv to compute normfv. (#334) + + * Always use sqsumv to compute normfv on MacOS. + + * Unconditionally disable the "dot trick" in normfv. + + * Added explanatory comment to normfv definition. + + Details: + - Added a comment above the unconditional disabling of the dotv-based + implementation to normfv. Thanks to Roman Yurchak, Devin Matthews, + and Isuru Fernando in helping with this improvement. + - CREDITS file update. + +commit 80e6c10b72d50863b4b64d79f784df7befedfcd1 +Author: Field G. Van Zee +Date: Thu Aug 29 12:12:08 2019 -0500 + + Added reproduction section to Performance docs. + + Details: + - Added section titled "Reproduction" to both Performance.md and + PerformanceSmall.md that briefly nudges the motivated reader in the + right direction if he/she wishes to run the same performance + benchmarks used to produce the graphs shown in those documents. + Thanks to Dave Love for making this suggestion. + +commit 14cb426414856024b9ae0f84ac21efcc1d329467 +Author: Field G. Van Zee +Date: Wed Aug 28 17:04:33 2019 -0500 + + Updated OpenBLAS, Eigen sup results. + + Details: + - Updated the results shown in docs/PerformanceSmall.md for OpenBLAS and + Eigen. + +commit b02e0aae8ce2705e91023b98ed416cd05430a78e +Author: Field G. Van Zee +Date: Tue Aug 27 14:37:46 2019 -0500 + + Updated test drivers to iterate backwards. + + Details: + - Updated test driver source in test, test/3, test/1m4m, and + test/mixeddt to iterate through the problem space backwards. This + can help avoid certain situations where the CPU frequency does not + immediately throttle up to its maximum. Thanks to Robert van de + Geijn for recommending this fix (originally made to test/sup drivers + in 57e422a). + - Applied off-by-one matlab output bugfix from b6017e5 to test drivers + in test, test/3, test/1m4m, and test/mixeddt directories. + +commit b6017e53f4b26c99b14cdaa408351f11322b1e80 +Author: Field G. Van Zee +Date: Tue Aug 27 14:18:14 2019 -0500 + + Bugfix of output text + tweaks to test/sup driver. + + Details: + - Fixed an off-by-one bug in the output of matlab row indices in + test/sup/test_gemm.c that only manifested when the problem size + increment was equal to 1. + - Disabled the building of rrc, rcr, rcc, crr, crc, and ccr storage + combinations for blissup drivers in test/sup. This helps make the + building of drivers complete sooner. + - Trivial changes to test/sup/runme.sh. + +commit 138d403b6bb15e687a3fe26d3d967b8ccd1ed97b +Author: Devin Matthews +Date: Mon Aug 26 18:11:27 2019 -0500 + + Use -funsafe-math-optimizations and -ffp-contract=fast for all reference kernels when using gcc or clang. (#331) + +commit d5a05a15a7fcc38fb2519031dcc62de8ea4a530c +Author: Field G. Van Zee +Date: Mon Aug 26 16:54:31 2019 -0500 + + Cropped whitespace from new sup graphs. + + Details: + - Previously forgot crop whitespace from the new .png graphs + added/updated in docs/graphs/sup. + +commit a6c80171a353db709e43f9e6e7a3da87ce4d17ed +Author: Field G. Van Zee +Date: Mon Aug 26 16:51:31 2019 -0500 + + Fixed contents links in docs/PerformanceSmall.md. + + Details: + - Corrected links in contents section of docs/PerformanceSmall.md, + which were erroneously directing readers to the corresponding + sections of docs/Performance.md. + +commit 40781774df56a912144ef19cc191ed626a89f0de +Author: Field G. Van Zee +Date: Mon Aug 26 16:47:37 2019 -0500 + + Updated sup performance graphs with libxsmm. + + Details: + - Added libxsmm to column-stored sup graphs presented in + docs/PerformanceSmall.md. + - Updated sup results for BLASFEO. + - Added sup results for Lonestar5 (Haswell). + - Addresses issue #326. + +commit bfddf671328e7e372ac7228f72ff2d9d8e03ae18 +Author: figual +Date: Mon Aug 26 12:01:33 2019 +0200 + + Fixed context registration for Cortex A53 (#329). + +commit 4a0a6e89c568246d14de4cc30e3ff35aac23d774 +Author: Field G. Van Zee +Date: Sat Aug 24 15:25:16 2019 -0500 + + Changed test/sup alpha to 1; test libxsmm+netlib. + + Details: + - Changed the value of alpha to 1.0 in test/sup/test_gemm.c. This is + needed because libxsmm currently only optimizes gemm operations where + alpha is unit (and beta is unit or zero). + - Adjusted the test/sup/Makefile to test libxsmm with netlib BLAS as its + fallback library. This is the library that will be called the + problem dimensions are deemed too large, or any other criteria for + optimization are not met. (This was done not because it is realistic, + but rather so that it would be very clear when libxsmm ceased handling + gemm calls internally when the data are graphed.) + +commit 7aa52b57832176c5c13a48e30a282e09ecdabf73 +Author: Field G. Van Zee +Date: Fri Aug 23 16:12:50 2019 -0500 + + Use libxsmm API in test/sup; add missing -ldl. + + Details: + - Switch the driver source in test/sup so that libxsmm_?gemm() is called + instead of ?gemm_() when compiling for / linking against libxsmm. + libxsmm's documentation isn't clear on whether it is even *trying* to + provide BLAS API compatibility, and I got tired of trying to figure it + out. + - Added missing -ldl in LDFLAGS when linking against libxsmm. + +commit 57e422aa168bee7416965265c93fcd4934cd7041 +Author: Field G. Van Zee +Date: Fri Aug 23 14:17:52 2019 -0500 + + Added libxsmm support to test/sup drivers. + + Details: + - Modified test/sup/Makefile to build drivers that test the performance + of skinny/small problems via libxsmm. + - Modified test/sup/runme.sh to run aforementioned drivers. + - Modified test/sup/test_gemm.c so that problem sizes are tested in + reverse order (from largest to smallest). This can help avoid certain + situations where the CPU frequency does not immediately throttle up + to its maximum. Thanks to Robert van de Geijn for recommending this + fix. + +commit 661681fe33978acce370255815c76348f83632bc +Merge: 2f387e32 ef0a1a0f +Author: Field G. Van Zee +Date: Thu Aug 22 14:29:50 2019 -0500 + + Merge branch 'master' of github.com:flame/blis + +commit 2f387e32ef5f9a17bafb5076dc9f66c38b52b32d +Author: Field G. Van Zee +Date: Thu Aug 22 14:27:30 2019 -0500 + + Added Eigen -march=native hack to perf docs. + + Details: + - Spell out the hack given to me by Sameer Agarwal in order to get Eigen + to build with -march=native (which is critically important for Eigen) + in docs/Performance.md and docs/PerformanceSmall.md. + +commit ef0a1a0faf683fe205f85308a54a77ffd68a9a6c +Author: Devin Matthews +Date: Wed Aug 21 17:40:24 2019 -0500 + + Update do_sde.sh (#330) + + * Update do_sde.sh + + Automatically accept SDE license and download directly from Intel + + * Update .travis.yml + + [ci skip] + + * Update .travis.yml + + Enable SDE testing for PRs. + +commit 0cd383d53a8c4a6871892a0395591ef5630d4ac0 +Author: Field G. Van Zee +Date: Wed Aug 21 13:39:05 2019 -0500 + + Corrected variable type and comment update. + + Details: + - Forgot to save all changes from bli_gemmtrsm4m1_ref.c before commit + in 8122f59. Fixed type mismatch and referenced github issue in + comment. + +commit 8122f59745db780987da6aa1e851e9e76aa985e0 +Author: Field G. Van Zee +Date: Wed Aug 21 13:22:12 2019 -0500 + + Pacify 'restrict' warning in gemmtrsm4m1 ref ukr. + + Details: + - Previously, some versions of gcc would complain that the same + pointer, one_r, is being passed in for both alpha and beta in the + fourth call to the real gemm ukernel in bli_gemmtrsm4m1_ref.c. This + is understandable since the compiler knows that the real gemm ukernel + qualifies all of its floating-point arguments (including alpha and + beta) with restrict. A small hack has been inserted into the file + that defines a new variable to store the value 1.0, which is now used + in lieu of one_r for beta in the fourth call to the real gemm ukernel, + which should pacify the compiler now. Thanks to Dave Love for + reporting this issue (#328) and for Devin Matthews for offering his + 'restrict' expertise. + +commit e8c6281f139bdfc9bd68c3b36e5e89059b0ead2e +Author: Field G. Van Zee +Date: Wed Aug 21 12:38:53 2019 -0500 + + Add -march support for specific gcc version ranges. + + Details: + - Added logic to configure that checks the version of the compiler + against known version ranges that could cause problems later in the + build process. For example, versions of gcc older than 4.9.0 use + different -march labels than version 4.9.0 or later + ('-march=corei7-avx' vs '-march=sandybridge', respectively). + Similarly, before 6.1, compilation on Zen was possible, but you + need to start with -march=bdver4 and then disable instruction sets + that were discarded during the transition from Excavator to Zen. So + now, configure substitutes 'yes'/'no' values into anchors in + config.mk.in, which sets various make variables (e.g. GCC_OT_4_9_0), + which can be accessed and branched upon by the various + configurations' make_defs.mk files when setting their compiler flags. + - Updated config/haswell/make_defs.mk to branch on GCC_OT_4_9_0. + - Updated config/sandybridge/make_defs.mk to branch on GCC_OT_4_9_0. + - Updated config/zen/make_defs.mk to branch on GCC_OT_6_1_0. + +commit e6ac4ebcb6e6a372820e7f509c0af3342966b84a +Author: Field G. Van Zee +Date: Tue Aug 20 13:49:47 2019 -0500 + + Added page size, source location to perf docs. + + Details: + - Added the page size, as returned via 'getconf -a | grep PAGE_SIZE', + and the location of the performance drivers to docs/Performance.md + (test/3) and docs/PerformanceSmall.md (test/sup). Thanks to Dave + Love for suggesting these additions in #325. + +commit fdce1a5648d69034fab39943100289323011c36f +Author: Meghana +Date: Wed Jul 24 15:04:41 2019 +0530 + + changed gcc version check condition from 'ifeq' to 'if greater or equal' + + Change-Id: Ie4c461867829bcc113210791bbefb9517e52c226 + +commit c9486e0c4f82cd9f58f5ceb71c0df039e9970a20 +Author: Meghana +Date: Wed Jul 24 09:45:17 2019 +0530 + + code to detect version of gcc and set flags accordingly for zen2 + + Change-Id: I29b0311d0000dee1a2533ee29941acf53f9e9f34 + +commit 54afe3dfe6828a1aff65baabbf14c98d92e50692 +Author: Field G. Van Zee +Date: Tue Jul 23 16:54:28 2019 -0500 + + Added "Education and Learning" ToC entry to README. + +commit 9f53b1ce7ac702e84e71801fe96986f6aa16040e +Author: Field G. Van Zee +Date: Tue Jul 23 16:50:35 2019 -0500 + + Added "Education and Learning" section to README. + + Details: + - Added a short section after the Intro of the README.md file titled + "Education and Learning" that directs interested readers to the + "LAFF-On Programming for High-Performance" massive open online course + (MOOC) hosted via edX. + +commit deda4ca8a094ee18d7c7c45e040e8ef180f33a48 +Author: Field G. Van Zee +Date: Mon Jul 22 13:59:05 2019 -0500 + + Added test/1m4m driver directory. + + Details: + - Added a new standalone test driver directory named '1m4m' that can + build and run performance experiments for BLIS 1m, 4m1a, assembly, + OpenBLAS, and the vendor library (MKL). This new driver directory + was used to regenerate performance results for the 1m paper. + - Added alternate (commented-out) cache blocksizes to + config/haswell/bli_cntx_init_haswell.c. These blocksizes tend to + work well on an a 12-core Intel Xeon E5-2650 v3. + +commit dcc0ce12fde4c6dca2b4764a1922a2ab19725867 +Author: Meghana +Date: Mon Jul 22 17:12:01 2019 +0530 + + Added a global Makefile for AMD architectures in config/zen folder + This Makefile(amd_config.mk) has all the flags that are common to EPYC series + + Change-Id: Ic02c60a8293ccdd37f0f292e631acd198e6895de + +commit af17bca26a8bd3dcbee8ca81c18d7b25de09c483 +Author: Field G. Van Zee +Date: Fri Jul 19 14:46:23 2019 -0500 + + Updated haswell MC cache blocksizes. + + Details: + - Updated the default MC cache blocksizes used by the haswell subconfig + for both row-preferential (the default) and column-preferential + microkernels. + +commit b5e9bce4dde5bf014dd9771ae741048e1f6c7748 +Author: Field G. Van Zee +Date: Fri Jul 19 14:42:37 2019 -0500 + + Updated -march flags for sandybridge, haswell. + + Details: + - Updated the '-march=corei7-avx' flag in the sandybridge subconfig + to '-march=sandybridge' and the '-march=core-avx2' flag in the + haswell subconfig to '-march=haswell'. The older flags were used + by older versions of gcc and should have been updated to the newer + forms a long time ago. (The older flags were clearly working, even + though they are no longer documented in the gcc man page.) + +commit c22b9dba5859a9fc94c8431eccc9e4eb9be02be1 +Author: Field G. Van Zee +Date: Tue Jul 16 13:14:47 2019 -0500 + + More updates to comments in testsuite modules. + + Details: + - Updated most comments in testsuite modules that describe how the + correctness test is performed so that it is clear whether the vector + (normfv) or matrix (normfm) form of Frobenius norm is used. + +commit c4cc6fa702f444a05963db01db51bc7d6669e979 +Author: Field G. Van Zee +Date: Tue Jul 16 13:00:35 2019 -0500 + + New cntx_t blksz "set" functions + misc tweaks. + + Details: + - Defined two new static functions in bli_cntx.h: + bli_cntx_set_blksz_def_dt() + bli_cntx_set_blksz_max_dt() + which developers may find convenient when experimenting with different + values of cache blocksizes. + - Updated one- and two-socket multithreaded problem size range and + increment values in test/3/Makefile. + - Changed default to column storage in test/3/test_gemm.c. + - Fixed typo in comment in testsuite/src/test_subm.c. + +commit b84cee29f42855dc1f263e42b83b1a46ac8def87 +Merge: 1f80858a c7dd6e6c +Author: Meghana Vankadari +Date: Mon Jul 8 02:03:07 2019 -0400 + + Merge "Added compiler flags for vanilla clang" into amd-staging-rome2.0 + +commit 1f80858abf5ca220b2998fbe6f9b06c32d3864c3 +Author: kdevraje +Date: Fri Jul 5 16:05:11 2019 +0530 + + This checkin solves the dgemm performance issue jira ticket CPUPL 458, as #else was missed during integration, it was always following else path to get the block sizes + + Change-Id: I0084b5856c2513ab1066c08c15b5086db6532717 + +commit c7dd6e6cd2f910cbefcdc1e04a5adeb919a23de0 +Author: Meghana +Date: Thu Jul 4 09:32:51 2019 +0530 + + Added compiler flags for vanilla clang + + Change-Id: I13c00b4c0d65bbda4c929848fd48b0ab611952ab + +commit 2acd49b76457635625a01e31c2abc8902b23cf51 +Author: Meghana +Date: Mon Jul 1 15:42:38 2019 +0530 + + fix for test failures using AOCC 2.0 + + Change-Id: If44eaccc64bbe96bbbe1d32279b1b5773aba08d1 + +commit ceee2f973ebe115beca55ca77f9e3ce36b14c28a +Author: Field G. Van Zee +Date: Mon Jun 24 17:47:40 2019 -0500 + + Fixed thrinfo_t printing bug for small problems. + + Details: + - Fixed a bug in bli_l3_thrinfo_print_gemm_paths() and + bli_l3_thrinfo_print_trsm_paths(), defined in bli_l3_thrinfo.c, + whereby subnodes of the thrinfo_t tree are "dereferenced" near the + beginning of the functions, which may lead to segfaults in certain + situations where the thread tree was not fully formed because the + matrix problem was too small for the level of parallelism specified. + (That is, too small because some problems were assigned no work due + to the smallest units in the m and n dimensions being defined by the + register blocksizes mr and nr.) The fix requires several nested levels + of if statements, and this is one of those few instances where use of + goto statements results in (mostly) prettier code, especially in the + case of _gemm_paths(). And while it wasn't necessary, I ported this + goto usage to the loop body that prints the thrinfo_t work_id and + comm_id values for each thread. Thanks to Nicholai Tukanov for helping + to find this bug. + +commit cac127182dd88ed0394ad81e6b91b897198e168a +Merge: 565fa385 3a45ecb1 +Author: kdevraje +Date: Mon Jun 24 13:01:27 2019 +0530 + + Merge branch 'amd-staging-rome2.0' of ssh://git.amd.com:29418/cpulibraries/er/blis + with public repo commit id 565fa3853b381051ac92cff764625909d105644d. + + Change-Id: I68b9824b110cf14df248217a24a6191b3df79d42 + +commit c152109e9a3b1cd74760e8a3215a676d25c18d2e +Author: Field G. Van Zee +Date: Wed Jun 19 13:23:24 2019 -0500 + + Updated BLASFEO results in PerformanceSmall.md. + + Details: + - Updated the BLASFEO performance graphs shown in PerformanceSmall.md + using a new commit of BLASFEO (2c9f312); updated PerformanceSmall.md + accordingly. + - Updated test/sup/octave/plot_l3sup_perf.m so that the .m files + containing the mpnpkp results do not need to be preprocessed in order + to plot half the problem size range (ie: up to 400 instead of the + 800 range of the other shape cases). + - Trivial updates to runme.m. + +commit 4d19c98110691d33ecef09d7e1b97bd1ccf4c420 +Author: Field G. Van Zee +Date: Sat Jun 8 11:02:03 2019 -0500 + + Trivial change to MixedDatatypes.md link text. + +commit 24965beabe83e19acf62008366097a7f198d4841 +Author: Field G. Van Zee +Date: Sat Jun 8 11:00:22 2019 -0500 + + Fixed typo in README.md's MixedDatatypes.md link. + +commit 50dc5d95760f41c5117c46f754245edc642b2179 +Author: Field G. Van Zee +Date: Fri Jun 7 13:10:16 2019 -0500 + + Adjust -fopenmp-simd for icc's preferred syntax. + + Details: + - Use -qopenmp-simd instead of -fopenmp-simd when compiling with Intel + icc. Recall that this option is used for SIMD auto-vectorization in + reference kernels only. Support for the -f option has been completely + deprecated and removed in newer versions of icc in favor of -q. Thanks + to Victor Eijkhout for reporting this issue and suggesting the fix. + +commit ad937db9507786874c801b41a4992aef42d924a1 +Author: Field G. Van Zee +Date: Fri Jun 7 11:34:08 2019 -0500 + + Added missing #include "bli_family_thunderx2.h". + + Details: + - Added a cpp-conditional directive block to bli_arch_config.h that + #includes "bli_family_thunderx2.h". The code has been missing since + adf5c17f. However, this never manifested as an error because the file + is virtually empty and not needed for thunderx2 (or most subconfigs). + Thanks to Jeff Diamond for helping to spot this. + +commit ce671917b2bc24895289247feef46f6fdd5020e7 +Author: Field G. Van Zee +Date: Thu Jun 6 14:17:21 2019 -0500 + + Fixed formatting/typo in docs/PerformanceSmall.md. + +commit 86c33a4eb284e2cf3282a1809be377785cdb3703 +Author: Field G. Van Zee +Date: Wed Jun 5 11:43:55 2019 -0500 + + Tweaked language in README.md related to sup/AMD. + +commit cbaa22e1ca368d36a8510f2b4ecd6f1523d1e1f3 +Author: Field G. Van Zee +Date: Tue Jun 4 16:06:58 2019 -0500 + + Added BLASFEO results to docs/PerformanceSmall.md. + + Details: + - Updated the graphs linked in PerformanceSmall.md with BLASFEO results, + and added documenting language accordingly. + - Updated scripts in test/sup/octave to plot BLASFEO data. + - Minor tweak to language re: how OpenBLAS was configured for + docs/Performance.md. + +commit 763fa39c3088c0e2c0155675a3ca868a58bffb30 +Author: Field G. Van Zee +Date: Tue Jun 4 14:46:45 2019 -0500 + + Minor tweaks to test/sup. + + Details: + - Changed starting problem and increment from 16 to 4. + - Added 'lll' (square problems) to list of problem size shapes to + compile and run with. + - Define BLASFEO location and added BLASFEO-related definitions. + +commit 5e1e696003c9151b1879b910a1957b7bdd7b0deb +Author: Field G. Van Zee +Date: Mon Jun 3 18:37:20 2019 -0500 + + CHANGELOG update (0.6.0) + +commit 18c876b989fd0dcaa27becd14e4f16bdac7e89b3 (tag: 0.6.0) +Author: Field G. Van Zee +Date: Mon Jun 3 18:37:19 2019 -0500 + + Version file update (0.6.0) + +commit 0f1b3bf49eb593ca7bb08b68a7209f7cd550f912 +Author: Field G. Van Zee +Date: Mon Jun 3 18:35:19 2019 -0500 + + ReleaseNotes.md update in advance of next version. + + Details: + - Updated ReleaseNotes.md in preparation for next version. + - CREDITS file update. + +commit 27da2e8400d900855da0d834b5417d7e83f21de1 +Author: Field G. Van Zee +Date: Mon Jun 3 17:14:56 2019 -0500 + + Minor edits to docs/PerformanceSmall.md. + + Details: + - Added performance analysis to "Comments" section of both Kaby Lake and + Epyc sections. + - Added emphasis to certain passages. + +commit 09ba05c6f87efbaadf085497dc137845f16ee9c5 +Author: Field G. Van Zee +Date: Mon Jun 3 16:53:19 2019 -0500 + + Added sup performance graphs/document to 'docs'. + + Details: + - Added a new markdown document, docs/PerformanceSmall.md, which + publishes new performance graphs for Kaby Lake and Epyc showcasing + the new BLIS sup (small/skinny/unpacked) framework logic and kernels. + For now, only single-threaded dgemm performance is shown. + - Reorganized graphs in docs/graphs into docs/graphs/large, with new + graphs being placed in docs/graphs/sup. + - Updates to scripts in test/sup/octave, mostly to allow decent output + in both GNU octave and Matlab. + - Updated README.md to mention and refer to the new PerformanceSmall.md + document. + +commit 6bf449cc6941734748034de0e9af22b75f1d6ba1 +Merge: abd8a9fa a4e8801d +Author: Field G. Van Zee +Date: Fri May 31 17:42:40 2019 -0500 + + Merge branch 'amd' + +commit a4e8801d08d81fa42ebea6a05a990de8dcedc803 +Author: Field G. Van Zee +Date: Fri May 31 17:30:51 2019 -0500 + + Increased MT sup threshold for double to 201. + + Details: + - Fine-tuned the double-precision real MT threshold (which controls + whether the sup implementation kicks for smaller m dimension values) + from 180 to 201 for haswell and 180 to 256 for zen. + - Updated octave scripts in test/sup/octave to include a seventh column + to display performance for m = n = k. + +commit 3a45ecb15456249c30ccccd60e42152f355615c1 +Merge: 3f867c96 b69fb0b7 +Author: Kiran Devrajegowda +Date: Fri May 31 06:47:02 2019 -0400 + + Merge "Added back BLIS_ENABLE_ZEN_BLOCK_SIZES macro to zen configuration, this is same as release 1.3. This was added before to improve DGEMM Multithreaded scalability on Naples for when number of threads is greater than 16. By mistake this got deleted in many changes done for 2.0 release, now we are adding this change back., in bli_gemm_front.c - code cleanup" into amd-staging-rome2.0 + +commit b69fb0b74a4756168de270fc9b18f7cf7aa57f17 +Author: Kiran Varaganti +Date: Fri May 31 15:14:22 2019 +0530 + + Added back BLIS_ENABLE_ZEN_BLOCK_SIZES macro to zen configuration, this is same as release 1.3. This was added before to improve DGEMM Multithreaded scalability on Naples for when number of threads is greater than 16. By mistake this got deleted in many changes done for 2.0 release, now we are adding this change back., in bli_gemm_front.c - code cleanup + + Change-Id: I9f5d8225254676a99c6f2b09a0825e545206d0fc + +commit 3f867c96caea3bbbbeeff1995d90f6cf8c9895fb +Author: kdevraje +Date: Fri May 31 12:22:44 2019 +0530 + + When running HPL with pure MPI without DGEMM Threading (Single Threaded BLIS ), making this macro 1 gives best performance.wq + + Change-Id: I24fd0bf99216f315e49f1c74c44c3feaffd7078d + +commit abd8a9fa7df4569aa2711964c19888b8e248901f (origin/pfhp) +Author: Field G. Van Zee +Date: Tue May 28 12:49:44 2019 -0500 + + Inadvertantly hidden xerbla_() in blastest (#313). + + Details: + - Attempted a fix to issue #313, which reports that when building only + a shared library (ie: static library build is disabled), running the + BLAS test drivers can fail because those drivers provide their own + local version of xerbla_() as a clever (albeit still rather hackish) + way of checking the error codes that result from the individual tests. + This local xerbla_() function is never found at link-time because the + BLAS test drivers' Makefile imports BLIS compilation flags via the + get-user-cflags-for() function, which currently conveys the + -fvisibility=hidden flag, which hides symbols unless they are + explicitly annotated for export. The -fvisibility=hidden flag was + only ever intended for use when building BLIS (not for applications), + and so the attempted solution here is to omit the symbol export + flag(s) from get-user-cflags-for() by storing the symbol export + flag(s) to a new BULID_SYMFLAGS variable instead of appending it + to the subconfigurations' CMISCFLAGS variable (which is returned by + every get-*-cflags-for() function). Thanks to M. Zhou for reporting + this issue and also to Isuru Fernando for suggesting the fix. + - Renamed BUILD_FLAGS to BUILD_CPPFLAGS to harmonize with the newly + created BUILD_SYMFLAGS. + - Fixed typo in entry for --export-shared flag in 'configure --help' + text. + +commit 13806ba3b01ca0dd341f4720fb930f97e46710b0 +Author: kdevraje +Date: Mon May 27 16:24:43 2019 +0530 + + This check in has changes w.r.t Copyright information, which is changed to (start year) - 2019 + + Change-Id: Ide3c8f7172210b8d3538d3c36e88634ab1ba9041 + +commit ee123f535872510f77100d3d55a43d4ca56047d5 +Author: Meghana +Date: Mon May 27 15:36:44 2019 +0530 + + Defined small matrix thresholds for TRSM for various cases for NAPLES and ROME + Updated copyright information for kernels/zen/bli_trsm_small.c file + Removed separate kernels for zen2 architecture + Instead added threshold conditions in zen kernels both for ROME and NAPLES + + Change-Id: Ifd715731741d649b6ad16b123a86dbd6665d97e5 + +commit 9d93a4caa21402d3a90aac45d7a1603736c9fd63 +Author: prangana +Date: Fri May 24 17:59:13 2019 +0530 + + update version 2.0 + +commit 755730608d923538273a90c48bfdf77571f86519 +Author: Field G. Van Zee +Date: Thu May 23 17:34:36 2019 -0500 + + Minor rewording of language around mt env. vars. + +commit ba31abe73c97c16c78fffc59a215761b8d9fd1f6 +Author: Field G. Van Zee +Date: Thu May 23 14:59:53 2019 -0500 + + Added BLIS theading info to Performance.md. + + Details: + - Documented the BLIS environment variables that were set + (e.g. BLIS_JC_NT, BLIS_IC_NT, BLIS_JR_NT) for each machine and + threading configuration in order to achieve the parallelism reported + on in docs/Performance.md. + +commit cb788ffc89cac03b44803620412a5e83450ca949 +Author: Field G. Van Zee +Date: Thu May 23 13:00:53 2019 -0500 + + Increased MT sup threshold for double to 180. + + Details: + - Increased the double-precision real MT threshold (which controls + whether the sup implementation kicks for smaller m dimension values) + from 80 to 180, and this change was made for both haswell and zen + subconfigurations. This is less about the m dimension in particular + and more about facilitating a smoother performance transition when + m = n = k. + +commit 057f5f3d211e7513f457ee6ca6c9555d00ad1e57 +Author: Field G. Van Zee +Date: Thu May 23 12:51:17 2019 -0500 + + Minor build system housekeeping. + + Details: + - Commented out redundant setting of LIBBLIS_LINK within all driver- + level Makefiles. This variable is already set within common.mk, and + so the only time it should be overridden is if the user wants to link + to a different copy of libblis. + - Very minor changes to build/gen-make-frags/gen-make-frag.sh. + - Whitespace and inconsequential quoting change to configure. + - Moved top-level 'windows' directory into a new 'attic' directory. + +commit e05171118c377f356f89c4daf8a0d5ddc5a4e4f7 +Author: Meghana +Date: Thu May 23 16:15:27 2019 +0530 + + Implemented TRSM for small matrices for cases where A is on the right + + Added separate kernels for zen and zen2 + + Change-Id: I6318ddc250cf82516c1aa4732718a35eae0c9134 + +commit 02920f5c480c42706b487e37b5ecc96c3555b851 +Author: kdevraje +Date: Thu May 23 15:29:59 2019 +0530 + + make checkblis fails for matrix dimension check at the begining hence reverting it + + Change-Id: Ibd2ee8c2d4914598b72003fbfc5845be9c9c1e87 + +commit 84215022f29fb3bfedd254d041635308d177e6c0 +Author: kdevraje +Date: Thu May 23 11:08:41 2019 +0530 + + Adding threshold condition to dgemm small matrix kernels, defining the constants in zen2 configuration + + Change-Id: I53a58b5d734925a6fcb8d8bea5a02ddb8971fcd5 + +commit a3554eb1dcc1b5b94d81c60761b2f01c3d827ffa +Merge: ea082f83 17b878b6 +Author: kdevraje +Date: Thu May 23 11:51:07 2019 +0530 + + Merge branch 'amd-staging-rome2.0' of ssh://git.amd.com:29418/cpulibraries/er/blis to configure zen2 + + Change-Id: I97e17bca9716b80b862925f97bb513c07b4b0cae + +commit ea082f839071dd9ec555062dc3851c31d12f00e4 +Author: kdevraje +Date: Thu May 23 10:38:29 2019 +0530 + + adding empty zen2 directory with .gitignore file + + Change-Id: Ifa37cf54b2578aa19ad335372b44bca17043fe4b + +commit b80bd5bcb2be8551a9a21fafc8e6c8b6336c99b5 +Author: Kiran Varaganti +Date: Tue May 21 15:11:47 2019 +0530 + + config/zen/bli_cntx_init_zen.c: removed BLIS_ENBLE_ZEN_BLOCK_SIZES macro. We have different configurations for both zen and zen2 + config/zen/bli_family_zen.h: deleted macro BLIS_ENBLE_ZEN_BLOCK_SIZES + config/zen/make_defs.mk: removed compiler flag -mno-avx256-split-unaligned-store + frame/base/bli_cpuid.c: ROME family is 17H but model # is from 0x30H. + test/test_gemm.c - commented out #define FILE_IN_OUT (some compilation error when BLIS is configured as amd64) + Now we can use single configuration has ./configure amd64 - this will work both for ROME & Naples + + Change-Id: I91b4fc35380f8a35b4f4c345da040c6b5910b4a2 + +commit a042db011df9a1c3e7c7ac546541f4746b176ea5 +Author: Kiran Varaganti +Date: Mon May 20 14:17:32 2019 +0530 + + Modified make_defs.mk for zen2 to get compiled by gcc version less than gcc9.0 + + Change-Id: I8fcac30538ee39534c296932639053b47b9a2d43 + +commit a23f92594cf3d530e5794307fe97afc877d853b7 +Author: Kiran Varaganti +Date: Mon May 20 10:48:06 2019 +0530 + + config_registry: New AMD zen2 architecture configuration added. + frame/base/bli_arch.c: #ifdef BLIS_FAMILY_ZEN2 id = BLIS_ARCH_ZEN2; #endif added. zen2 is added in config_name[BLIS_NUM_ARCHS] + frame/base/bli_cpuid.c : #ifdef BLIS_CONFIG_ZEN2 if ( bli_cpuid_is_zen2( family, model, features ) ) return BLIS_ARCH_ZEN2; #endif, defined new function bool bli_cpuid_is_zen2(...). + frame/base/bli_cpuid.h : declared bli_cpuid_is_zen2(..). + frame/base/bli_gks.c : #ifdef BLIS_CONFIG_ZEN2 bli_gks_register_cntx(BLIS_ARCH_ZEN2, bli_cntx_init_zen2, bli_cntx_init_zen2_ref, bli_cntx_init_zen2_ind); #endif + frame/include/bli_arch_config.h : #ifdef BLIS_CONFIG_ZEN2 CNTX_INIT_PROTS(zen2) #endif #ifdef BLIS_FAMILY_ZEN2 #include "bli_family_zen2.h" #endif + frame/include/bli_type_defs.h : added BLIS_ARCH_ZEN2 in arch_t enum. BLIS_NUM_ARCHS 20 + + Change-Id: I2a2d9b7266673e78a4f8543b1bfb5425b0aa7866 + +commit 17b878b66d917d50b6fe23721d8579e826cb3e8c +Author: kdevraje +Date: Wed May 22 14:02:53 2019 +0530 + + adding license same as in ut-austin-amd-branch + + Change-Id: I6790768d2bf5d42369d304ef93e34701f95fbaff + +commit df755848b8a271323e007c7a628c64af63deab00 +Merge: ca4b33c0 c72ae27a +Author: kdevraje +Date: Wed May 22 13:30:07 2019 +0530 + + Merge branch 'amd-staging-rome2.0' of ssh://git.amd.com:29418/cpulibraries/er/blis into rome2.0 + + Change-Id: Ie8aad1ab810f0f3c0b90ec67f9dd3dfb8dcc74cc + +commit c72ae27adee4726679ee004d02c972582b5285b4 +Author: Nisanth M P +Date: Mon Mar 19 12:49:26 2018 +0530 + + Re-enabling the small matrix gemm optimization for target zen + + Change-Id: I13872784586984634d728cd99a00f71c3f904395 + +commit ab0818af80f7f683080873f3fa24734b65267df2 +Author: sraut +Date: Wed Oct 3 15:30:33 2018 +0530 + + Review comments incorporated for small TRSM. + + Change-Id: Ia64b7b2c0375cc501c2cb0be8a1af93111808cd9 + +commit 32392cfc72af7f42da817a129748349fb1951346 +Author: Jeff Hammond +Date: Tue May 14 15:52:30 2019 -0400 + + add info about CXX in configure (#311) + +commit fa7e6b182b8365465ade178b0e4cd344ff6f6460 +Author: Field G. Van Zee +Date: Wed May 1 19:13:00 2019 -0500 + + Define _POSIX_C_SOURCE in bli_system.h. + + Details: + - Added + #ifndef _POSIX_C_SOURCE + #define _POSIX_C_SOURCE 200809L + #endif + to bli_system.h so that an application that uses BLIS (specifically, + an application that #includes blis.h) does not need to remember to + #define the macro itself (either on the command line or in the code + that includes blis.h) in order to activate things like the pthreads. + Thanks to Christos Psarras for reporting this issue and suggesting + this fix. + - Commented out #include in bli_system.h, since I don't + think this header is used/needed anymore. + - Comment update to function macro for bli_?normiv_unb_var1() in + frame/util/bli_util_unb_var1.c. + +commit 3df84f1b5d5e1146bb01bfc466ac20c60a9cc859 +Author: Field G. Van Zee +Date: Sat Apr 27 21:27:32 2019 -0500 + + Minor bugfixes in sup dgemm implementation. + + Details: + - Fixed an obscure but in the bli_dgemmsup_rv_haswell_asm_5x8n() kernel + that only affected the beta == 0, column-storage output case. Thanks + to the BLAS test drivers for catching this bug. + - Previously, bli_gemmsup_ref_var1n() and _var2m() were returning if + k = 0, when the correct action would be to scale by beta (and then + return). Thanks to the BLAS test drivers to catching this bug. + - Changed the sup threshold behavior such that the sup implementation + only kicks in if a matrix dimension is strictly less than (rather than + less than or equal to) the threshold in question. + - Initialize all thresholds to zero (instead of 10) by default in + ref_kernels/bli_cntx_ref.c. This, combined with the above change to + threshold testing means that calls to BLIS or BLAS with one or more + matrix dimensions of zero will no longer trigger the sup + implementation. + - Added disabled debugging output to frame/3/bli_l3_sup.c (for future + use, perhaps). + +commit ecbdd1c42dcebfecd729fe351e6bb0076aba7d81 +Author: Field G. Van Zee +Date: Sat Apr 27 19:38:11 2019 -0500 + + Ceased use of BLIS_ENABLE_SUP_MR/NR_EXT macros. + + Details: + - Removed already limited use of the BLIS_ENABLE_SUP_MR_EXT and + BLIS_ENABLE_SUP_NR_EXT macros in bli_gemmsup_ref_var1n() and + bli_gemmsup_ref_var2m(). Their purpose was merely to avoid a long + conditional that would determine whether to allow the last iteration + to be merged with the second-to-last iteration. Functionally, the + macros were not needed, and they ended up causing problems when + building configuration families such as intel64 and x86_64. + +commit aa8a6bec3036a41e1bff2034f8ef6766a704ec49 +Author: Field G. Van Zee +Date: Sat Apr 27 18:53:33 2019 -0500 + + Fixed typo in --disable-sup-handling macro guard. + + Details: + - Fixed an incorrectly-named macro guard that is intended to allow + disabling of the sup framework via the configure option + --disable-sup-handling. In this case, the preprocessor macro, + BLIS_DISABLE_SUP_HANDLING, was still named by its name from an older + uncommitted version of the code (BLIS_DISABLE_SM_HANDLING). + +commit b9c9f03502c78a63cfcc21654b06e9089e2a3822 +Author: Field G. Van Zee +Date: Sat Apr 27 18:44:50 2019 -0500 + + Implemented gemm on skinny/unpacked matrices. + + Details: + - Implemented a new sub-framework within BLIS to support the management + of code and kernels that specifically target matrix problems for which + at least one dimension is deemed to be small, which can result in long + and skinny matrix operands that are ill-suited for the conventional + level-3 implementations in BLIS. The new framework tackles the problem + in two ways. First the stripped-down algorithmic loops forgo the + packing that is famously performed in the classic code path. That is, + the computation is performed by a new family of kernels tailored + specifically for operating on the source matrices as-is (unpacked). + Second, these new kernels will typically (and in the case of haswell + and zen, do in fact) include separate assembly sub-kernels for + handling of edge cases, which helps smooth performance when performing + problems whose m and n dimension are not naturally multiples of the + register blocksizes. In a reference to the sub-framework's purpose of + supporting skinny/unpacked level-3 operations, the "sup" operation + suffix (e.g. gemmsup) is typically used to denote a separate namespace + for related code and kernels. NOTE: Since the sup framework does not + perform any packing, it targets row- and column-stored matrices A, B, + and C. For now, if any matrix has non-unit strides in both dimensions, + the problem is computed by the conventional implementation. + - Implemented the default sup handler as a front-end to two variants. + bli_gemmsup_ref_var2() provides a block-panel variant (in which the + 2nd loop around the microkernel iterates over n and the 1st loop + iterates over m), while bli_gemmsup_ref_var1() provides a panel-block + variant (2nd loop over m and 1st loop over n). However, these variants + are not used by default and provided for reference only. Instead, the + default sup handler calls _var2m() and _var1n(), which are similar + to _var2() and _var1(), respectively, except that they defer to the + sup kernel itself to iterate over the m and n dimension, respectively. + In other words, these variants rely not on microkernels, but on + so-called "millikernels" that iterate along m and k, or n and k. + The benefit of using millikernels is a reduction of function call + and related (local integer typecast) overhead as well as the ability + for the kernel to know which micropanel (A or B) will change during + the next iteration of the 1st loop, which allows it to focus its + prefetching on that micropanel. (In _var2m()'s millikernel, the upanel + of A changes while the same upanel of B is reused. In _var1n()'s, the + upanel of B changes while the upanel of A is reused.) + - Added a new configure option, --[en|dis]able-sup-handling, which is + enabled by default. However, the default thresholds at which the + default sup handler is activated are set to zero for each of the m, n, + and k dimensions, which effectively disables the implementation. (The + default sup handler only accepts the problem if at least one dimension + is smaller than or equal to its corresponding threshold. If all + dimensions are larger than their thresholds, the problem is rejected + by the sup front-end and control is passed back to the conventional + implementation, which proceeds normally.) + - Added support to the cntx_t structure to track new fields related to + the sup framework, most notably: + - sup thresholds: the thresholds at which the sup handler is called. + - sup handlers: the address of the function to call to implement + the level-3 skinny/unpacked matrix implementation. + - sup blocksizes: the register and cache blocksizes used by the sup + implementation (which may be the same or different from those used + by the conventional packm-based approach). + - sup kernels: the kernels that the handler will use in implementing + the sup functionality. + - sup kernel prefs: the IO preference of the sup kernels, which may + differ from the preferences of the conventional gemm microkernels' + IO preferences. + - Added a bool_t to the rntm_t structure that indicates whether sup + handling should be enabled/disabled. This allows per-call control + of whether the sup implementation is used, which is useful for test + drivers that wish to switch between the conventional and sup codes + without having to link to different copies of BLIS. The corresponding + accessor functions for this new bool_t are defined in bli_rntm.h. + - Implemented several row-preferential gemmsup kernels in a new + directory, kernels/haswell/3/sup. These kernels include two general + implementation types--'rd' and 'rv'--for the 6x8 base shape, with + two specialized millikernels that embed the 1st loop within the kernel + itself. + - Added ref_kernels/3/bli_gemmsup_ref.c, which provides reference + gemmsup microkernels. NOTE: These microkernels, unlike the current + crop of conventional (pack-based) microkernels, do not use constant + loop bounds. Additionally, their inner loop iterates over the k + dimension. + - Defined new typedef enums: + - stor3_t: captures the effective storage combination of the level-3 + problem. Valid values are BLIS_RRR, BLIS_RRC, BLIS_RCR, etc. A + special value of BLIS_XXX is used to denote an arbitrary combination + which, in practice, means that at least one of the operands is + stored according to general stride. + - threshid_t: captures each of the three dimension thresholds. + - Changed bli_adjust_strides() in bli_obj.c so that bli_obj_create() + can be passed "-1, -1" as a lazy request for row storage. (Note that + "0, 0" is still accepted as a lazy request for column storage.) + - Added support for various instructions to bli_x86_asm_macros.h, + including imul, vhaddps/pd, and other instructions related to integer + vectors. + - Disabled the older small matrix handling code inserted by AMD in + bli_gemm_front.c, since the sup framework introduced in this commit + is intended to provide a more generalized solution. + - Added test/sup directory, which contains standalone performance test + drivers, a Makefile, a runme.sh script, and an 'octave' directory + containing scripts compatible with GNU Octave. (They also may work + with matlab, but if not, they are probably close to working.) + - Reinterpret the storage combination string (sc_str) in the various + level-3 testsuite modules (e.g. src/test_gemm.c) so that the order + of each matrix storage char is "cab" rather than "abc". + - Comment updates in level-3 BLAS API wrappers in frame/compat. + +commit 0d549ceda822833bec192bbf80633599620c15d9 +Author: Isuru Fernando +Date: Sat Apr 27 22:56:02 2019 +0000 + + make unix friendly archives on appveyor (#310) + +commit ca4b33c001f9e959c43b95a9a23f9df5adec7adf +Author: Kiran Varaganti +Date: Wed Apr 24 15:02:39 2019 +0530 + + Added compiler option (-mno-avx256-split-unaligned-store) in the file config/zen/make_defs.mk to improve performance of intrinsic codes, this flag ensures compiler generates 256-bit stores for the equivalent intrinsics code. + + Change-Id: I8f8cd81a3604869df18d38bc42097a04f178d324 + +commit 945928c650051c04d6900c7f4e9e29cd0e5b299f +Merge: 663f6629 74e513eb +Author: Field G. Van Zee +Date: Wed Apr 17 15:58:56 2019 -0500 + + Merge branch 'amd' of github.com:flame/blis into amd + +commit 74e513eb6a6787a925d43cd1500277d54d86ab8f +Author: Field G. Van Zee +Date: Wed Apr 17 13:34:44 2019 -0500 + + Support row storage in Eigen gemm test/3 driver. + + Details: + - Added preprocessor branches to test/3/test_gemm.c to explicitly + support row-stored matrices. Column-stored matrices are also still + supported (and is the default for now). (This is mainly residual work + leftover from initial integration of Eigen into the test drivers, so + if we ever want to test Eigen with row-stored matrices, the code will + be ready to use, even if it is not yet integrated into the Makefile + in test/3.) + +commit b5d457fae9bd75c4ca67f7bc7214e527aa248127 +Author: Field G. Van Zee +Date: Tue Apr 16 12:50:01 2019 -0500 + + Applied forgotten variable rename from 89a70cc. + + Details: + - Somehow the variable name change (root_file_name -> root_inputname) + in flatten-headers.py mentioned in the commit log entry for 89a70cc + didn't make it into the actual commit. This commit applies that + change. + +commit 89a70cccf869333147eb2559cdfa5a23dc915824 +Author: Field G. Van Zee +Date: Thu Apr 11 18:33:08 2019 -0500 + + GNU-like handling of installation prefix et al. + + Details: + - Changed the default installation prefix from $HOME/lib to /usr/local. + - Modified the way configure internally handles the prefix, libdir, + includedir, and sharedir (and also added an --exec-prefix option). + The defaults to these variables are set as follows: + prefix: /usr/local + exec_prefix: ${prefix} + libdir: ${exec_prefix}/lib + includedir: ${prefix}/include + sharedir: ${prefix}/share + The key change, aside from the addition of exec_prefix and its use to + define the default to libdir, is that the variables are substituted + into config.mk with quoting that delays evaluation, meaning the + substituted values may contain unevaluated references to other + variables (namely, ${prefix} and ${exec_prefix}). This more closely + follows GNU conventions, including those used by GNU autoconf, and + also allows make to override any one of the variables *after* + configure has already been run (e.g. during 'make install'). + - Updates to build/config.mk.in pursuant to above changes. + - Updates to output of 'configure --help' pursuant to above changes. + - Updated docs/BuildSystem.md to reflect the new default installation + prefix, as well as mention EXECPREFIX and SHAREDIR. + - Changed the definitions of the UNINSTALL_OLD_* variables in the + top-level Makefile to use $(wildcard ...) instead of 'find'. This + was motivated by the new way of handling prefix and friends, which + leads to the 'find' command being run on /usr/local (by default), + which can take a while almost never yielding any benefit (since the + user will very rarely use the uninstall-old targets). + - Removed periods from the end of descriptive output statements (i.e., + non-verbose output) since those statements often end with file or + directory paths, which get confusing to read when puctuated by a + period. + - Trival change to 'make showconfig' output. + - Removed my name from 'configure --help'. (Many have contributed to it + over the years.) + - In configure script, changed the default state of threading_model + variable from 'no' to 'off' to match that of debug_type, where there + are similarly more than two valid states. ('no' is still accepted + if given via the --enable-debug= option, though it will be + standardized to 'off' prior to config.mk being written out.) + - Minor variable name change in flatten-headers.py that was intended for + 32812ff. + - CREDITS file update. + +commit 9d76688ad90014a11ddc0c2f27253d62806216b1 +Author: kdevraje +Date: Thu Apr 11 10:22:48 2019 +0530 + + Fix for single rank crash with HPL application. When computing offset of C buffer, as integer variables are used for a row and column index, the intermediate result value overflows and a negative value gets added to the buffer, when the negative value is too large it would index the buffer out of the range resulting in segmentation fault. Although the crash is a result of dgemm kernel, added similar code in sgemm kernel also. + + Change-Id: I171119b0ec0dfbd8e63f1fcd6609a94384aabd27 + +commit 32812ff5aba05d34c421fe1024a61f3e2d5e7052 +Author: Field G. Van Zee +Date: Tue Apr 9 12:20:19 2019 -0500 + + Minor bugfix to flatten-headers.py. + + Details: + - Fixed a minor bug in flatten-headers.py whereby the script, upon + encountering a #include directive for the root header file, would + erroneously recurse and inline the conents of that root header. + The script has been modified to avoid recursion into any headers + that share the same name as the root-level header that was passed + into the script. (Note: this bug didn't actually manifest in BLIS, + so it's merely a precaution for usage of flatten-headers.py in other + contexts.) + +commit bec90e0b6aeb3c9b19589c2b700fda2d66f6ccdf +Author: Field G. Van Zee +Date: Tue Apr 2 17:45:13 2019 -0500 + + Minor update to docs/HardwareSupport.md document. + + Details: + - Added more details and clarifying language to implications of 1m and + the recycling of microkernels between microarchitectures. + +commit 89cd650e7be01b59aefaa85885a3ea78970351e4 +Author: Field G. Van Zee +Date: Tue Apr 2 17:23:55 2019 -0500 + + Use void_fp for function pointers instead of void*. + + Change void*-typed function pointers to void_fp. + - Updated all instances of void* variables that store function pointers + to variables of a new type, void_fp. Originally, I wanted to define + the type of void_fp as "void (*void_fp)( void )"--that is, a pointer + to a function with no return value and no arguments. However, once + I did this, I realized that gcc complains with incompatible pointer + type (-Wincompatible-pointer-types) warnings every time any such a + pointer is being assigned to its final, type-accurate function + pointer type. That is, gcc will silently typecast a void* to + another defined function pointer type (e.g. dscalv_ker_ft) during + an assignment from the former to the latter, but the same statement + will trigger a warning when typecasting from a void_fp type. I suspect + an explicit typecast is needed in order to avoid the warning, which + I'm not willing to insert at this time. + - Added a typedef to bli_type_defs.h defining void_fp as void*, along + with a commented-out version of the aborted definition described + above. (Note that POSIX requires that void* and function pointers + be interchangeable; it is the C standard that does not provide this + guarantee.) + - Comment updates to various _oapi.c files. + +commit ffce3d632b284eb52474036096815ec38ca8dd5f +Author: Field G. Van Zee +Date: Tue Apr 2 14:40:50 2019 -0500 + + Renamed armv8a gemm kernel filename. + + Details: + - Renamed + kernels/armv8a/3/bli_gemm_armv8a_opt_4x4.c + to + kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c. + This follows the naming convention used by other kernel sets, most + notably haswell. + +commit 77867478af02144544b4e7b6df5d54d874f3f93b +Author: Isuru Fernando +Date: Tue Apr 2 13:33:11 2019 -0500 + + Use pthreads on MinGW and Cygwin (#307) + +commit 7bc75882f02ce3470a357950878492e87e688cec +Author: Field G. Van Zee +Date: Thu Mar 28 17:40:50 2019 -0500 + + Updated Eigen results in docs/graphs with 3.3.90. + + Details: + - Updated the level-3 performance graphs in docs/graphs with new Eigen + results, this time using a development version cloned from their git + mirror on March 27, 2019 (version 3.3.90). Performance is improved + over 3.3.7, though still noticeably short of BLIS/MKL in most cases. + - Very minor updates to docs/Performance.md and matlab scripts in + test/3/matlab. + +commit 20ea7a1217d3833db89a96158c42da2d6e968ed8 +Author: Field G. Van Zee +Date: Wed Mar 27 18:09:17 2019 -0500 + + Minor text updates (Eigen) to docs/Performance.md. + + Details: + - Added/updated a few more details, mostly regarding Eigen. + +commit bfb7e1bc6af468e4ff22f7e27151ea400dcd318a +Merge: 044df950 2c85e1dd +Author: Field G. Van Zee +Date: Wed Mar 27 17:58:19 2019 -0500 + + Merge branch 'dev' + +commit 2c85e1dd9d5d84da7228ea4ae6deec56a89b3a8f +Author: Field G. Van Zee +Date: Wed Mar 27 16:29:51 2019 -0500 + + Added Eigen results to performance graphs. + + Details: + - Updated the Haswell, SkylakeX, and Epyc performance graphs in + docs/graphs to report on Eigen implementations, where applicable. + Specifically, Eigen implements all level-3 operations sequentially, + however, of those operations it only provides multithreaded gemm. + Thus, mt results for symm/hemm, syrk/herk, trmm, and trsm are + omitted. Thanks to Sameer Agarwal for his help configuring and + using Eigen. + - Updated docs/Performance.md to note the new implementation tested. + - CREDITS file update. + +commit bfac7e385f8061f2e6591de208b0acf852f04580 +Author: Field G. Van Zee +Date: Wed Mar 27 16:04:48 2019 -0500 + + Added ability to plot with Eigen in test/3/matlab. + + Details: + - Updated matlab scripts in test/3/matlab to optionally plot/display + Eigen performance curves. Whether Eigen is plotted is determined by + a new boolean function parameter, with_eigen. + - Updated runme.m scratchpad to reflect the latest invocations of the + plot_panel_4x5() function (with Eigen plotting enabled). + +commit 67535317b9411c90de7fa4cb5b0fdb8f61fdcd79 +Author: Field G. Van Zee +Date: Wed Mar 27 13:32:18 2019 -0500 + + Fixed mislabeled eigen output from test/3 drivers. + + Details: + - Fixed the Makefile in test/3 so that it no longer incorrectly labels + the matlab output variables from Eigen-linked hemm, herk, trmm, and + trsm driver output as "vendor". (The gemm drivers were already + correctly outputing matlab variables containing the "eigen" label.) + +commit 044df9506f823643c0cdd53e81ad3c27a9f9d4ff +Author: Isuru Fernando +Date: Wed Mar 27 12:39:31 2019 -0500 + + Test with shared on windows (#306) + + Export macros can't support both shared and static at the same time. + When blis is built with both shared and static, headers assume that + shared is used at link time and dllimports the symbols with __imp_ + prefix. + + To use the headers with static libraries a user can give + -DBLIS_EXPORT= to import the symbol without the __imp_ prefix + +commit 5e6b160c8a85e5e23bab0f64958a8acf4918a4ed +Author: Field G. Van Zee +Date: Tue Mar 26 19:10:59 2019 -0500 + + Link to Eigen BLAS for non-gemm drivers in test/3. + + Details: + - Adjusted test/3/Makefile so that the test drivers are linked against + Eigen's BLAS library for hemm, herk, trmm, and trsm. We have to do + this since Eigen's headers don't define implementations to the + standard BLAS APIs. + - Simplified #included headers in hemm, herk, trmm, and trsm source + driver files, since nothing specific to Eigen is needed at + compile-time for those operations. + +commit e593221383aae19dfdc3f30539de80ed05cfec7f +Merge: 92fb9c87 c208b9dc +Author: Field G. Van Zee +Date: Tue Mar 26 15:51:45 2019 -0500 + + Merge branch 'master' into dev + +commit 92fb9c87bf88b9f9c401eeecd9aa9c3521bc2adb +Author: Field G. Van Zee +Date: Tue Mar 26 15:43:23 2019 -0500 + + Add more support for Eigen to drivers in test/3. + + Details: + - Use compile-time implementations of Eigen in test_gemm.c via new + EIGEN cpp macro, defined on command line. (Linking to Eigen's BLAS + library is not necessary.) However, as of Eigen 3.3.7, Eigen only + parallelizes the gemm operation and not hemm, herk, trmm, trsm, or + any other level-3 operation. + - Fixed a bug in trmm and trsm drivers whereby the wrong function + (bli_does_trans()) was being called to determine whether the object + for matrix A should be created for a left- or right-side case. This + was corrected by changing the function to bli_is_left(), as is done + in the hemm driver. + - Added support for running Eigen test drivers from runme.sh. + +commit c208b9dc46852c877197d53b6dd913a046b6ebb6 +Author: Isuru Fernando +Date: Mon Mar 25 13:03:44 2019 -0500 + + Fix clang version detection (#305) + + clang -dumpversion gives 4.2.1 for all clang versions as clang was + originally compatible with gcc 4.2.1 + + Apple clang version and clang version are two different things + and the real clang version cannot be deduced from apple clang version + programatically. Rely on wikipedia to map apple clang to clang version + + Also fixes assembly detection with clang + + clang 3.8 can't build knl as it doesn't recognize zmm0 + +commit 53842c7e7d530cb2d5609d6d124ae350fc345c32 +Author: Kiran Varaganti +Date: Fri Mar 22 13:57:14 2019 +0530 + + Removed printing alpha and beta values + + Change-Id: I49102db510311a30f6a936f9d843f35838f50d23 + +commit 6805db45e343d83d1adaf9157cf0b841653e9ede +Author: Kiran Varaganti +Date: Fri Mar 22 12:55:35 2019 +0530 + + Corrected setting alpha & beta values- alpha = -1 and beta = 1 - bli_setc(-1.0, 0, &alpha) should be used rather than bli_setc(0.0, -1.0, &alpha). This corrected now + + Change-Id: Ic1102dfd6b50ccf212386a1211c6f31e8d987ef9 + +commit feefcab4427a75b0b55af215486b85abcda314f7 +Author: Field G. Van Zee +Date: Thu Mar 21 18:11:20 2019 -0500 + + Allow disabling of BLAS prototypes at compile-time. + + Details: + - Modified bli_blas.h so that: + - By default, if the BLAS layer is enabled at configure-time, BLAS + prototypes are also enabled within blis.h; + - But if the user #defines BLIS_DISABLE_BLAS_DEFS prior to including + blis.h, BLAS prototypes are skipped over entirely so that, for + example, the application or some other header pulled in by the + application may prototype the BLAS functions without causing any + duplication. + - Updated docs/BuildSystem.md to document the feature above, and + related text. + +commit 20153cd4b594bc34f860c381ec18de3a6cc743c7 +Author: Kiran Varaganti +Date: Thu Mar 21 16:23:53 2019 +0530 + + Modified test_gemm.c file in test folder + A Macro 'FILE_IN_OUT" is defined to read input parameters from a csv file. + Format for input file: + Each line defines a gemm problem with following parameters: m k n cs_a cs_b cs_c + The operation always implemented is C = C - A*B and column-major format. + When macro is disabled - it reverts back to original implementation. + Usage: ./test_gemm_.x input.csv output.csv + GEMM is called through BLAS interface + For BLIS - the test application also prints either 'S' indicating small gemm routine or 'N' - conventional BLIS gemm + for MKL/OpenBLAS - ignore this character + + Change-Id: I0924ef2c1f7bdea48d4cdb230b888e2af2c86a36 + +commit 288843b06d91e1b4fade337959aef773090bd1c9 +Author: Field G. Van Zee +Date: Wed Mar 20 17:52:23 2019 -0500 + + Added Eigen support to test/3 Makefile, runme.sh. + + Details: + - Added targets to test/3/Makefile that link against a BLAS library + build by Eigen. It appears, however, that Eigen's BLAS library does + not support multithreading. (It may be that multithreading is only + available when using the native C++ APIs.) + - Updated runme.sh with a few Eigen-related tweaks. + - Minor tweaks to docs/Performance.md. + +commit 153e0be21d9ff413e370511b68d553dd02abada9 +Author: Field G. Van Zee +Date: Tue Mar 19 17:53:18 2019 -0500 + + More minor tweaks to docs/Performance.md. + + Details: + - Defined GFLOPS as billions of floating-point operations per second, + and reworded the sentence after about normalization. + +commit 05c4e42642cc0c8dbfa94a6c21e975ac30c0517a +Author: Field G. Van Zee +Date: Tue Mar 19 17:07:20 2019 -0500 + + CHANGELOG update (0.5.2) + +commit 9204cd0cb0cc27790b8b5a2deb0233acd9edeb9b (tag: 0.5.2) +Author: Field G. Van Zee +Date: Tue Mar 19 17:07:18 2019 -0500 + + Version file update (0.5.2) + +commit 64560cd9248ebf4c02c4a1eeef958e1ca434e510 +Author: Field G. Van Zee +Date: Tue Mar 19 17:04:20 2019 -0500 + + ReleaseNotes.md update in advance of next version. + + Details: + - Updated ReleaseNotes.md in preparation for next version. + +commit ab5ad557ea69479d487c9a3cb516f43fa1089863 +Author: Field G. Van Zee +Date: Tue Mar 19 16:50:41 2019 -0500 + + Very minor tweaks to Performance.md. + +commit 03c4a25e1aa8a6c21abbb789baa599ac419c3641 +Author: Field G. Van Zee +Date: Tue Mar 19 16:47:15 2019 -0500 + + Minor fixes to docs/Performance.md. + + Details: + - Fixed some incorrect labels associated with the pdf/png graphs, + apparently the result of copy-pasting. + +commit fe6dd8b132f39ecb8893d54cd8e75d4bbf6dab83 +Author: Field G. Van Zee +Date: Tue Mar 19 16:30:23 2019 -0500 + + Fixed broken section links in docs/Performance.md. + + Details: + - Fixed a few broken section links in the Contents section. + +commit 913cf97653f5f9a40aa89a5b79e2b0a8882dd509 +Author: Field G. Van Zee +Date: Tue Mar 19 16:15:24 2019 -0500 + + Added docs/Performance.md and docs/graphs subdir. + + Details: + - Added a new markdown document, docs/Performance.md, which reports + performance of a representative set of level-3 operations across a + variety of hardware architectures, comparing BLIS to OpenBLAS and a + vendor library (MKL on Intel/AMD, ARMPL on ARM). Performance graphs, + in pdf and png formats, reside in docs/graphs. + - Updated README.md to link to new Performance.md document. + - Minor updates to CREDITS, docs/Multithreading.md. + - Minor updates to matlab scripts in test/3/matlab. + +commit 9945ef24fd758396b698b19bb4e23e53b9d95725 +Author: Field G. Van Zee +Date: Tue Mar 19 15:28:44 2019 -0500 + + Adjusted cache blocksizes for zen subconfig. + + Details: + - Adjusted the zen sub-configuration's cache blocksizes for float, + scomplex, and dcomplex based on the existing values for double. + (The previous values were taken directly from the haswell subconfig, + which targets Intel Haswell/Broadwell/Skylake systems.) + +commit d202d008d51251609d08d3c278bb6f4ca9caf8e4 +Author: Field G. Van Zee +Date: Mon Mar 18 18:18:25 2019 -0500 + + Renamed --enable-export-all to --export-shared=[]. + + Details: + - Replaced the existing --enable-export-all / --disable-export-all + configure option with --export-shared=[public|all], with the 'public' + instance of the latter corresponding to --disable-export-all and the + 'all' instance corresponding to --enable-export-all. Nothing else + semantically about the option, or its default, has changed. + +commit ff78089870f714663026a7136e696603b5259560 +Author: Field G. Van Zee +Date: Mon Mar 18 13:22:55 2019 -0500 + + Updates to docs/Multithreading.md. + + Details: + - Made extra explicit the fact that: (a) multithreading in BLIS is + disabled by default; and (b) even with multithreading enabled, the + user must specify multithreading at runtime in order to observe + parallelism. Thanks to M. Zhou for suggesting these clarifications + in #292. + - Also made explicit that only the environment variable and global + runtime API methods are available when using the BLAS API. If the + user wishes to use the local runtime API (specify multithreading on + a per-call basis), one of the native BLIS APIs must be used. + +commit 3a929a3d0ba0353159a6d4cd188f01b7a390ccfc +Author: Kiran Varaganti +Date: Mon Mar 18 10:51:41 2019 +0530 + + Fixed code merging: bli_gemm_small.c - missed conditional checks for L!=0 && K!=0. Now they are added. This fix is done to pass blastest + + Change-Id: Idc9c9a04d2015a68a19553c437ecaf8f1584026c + +commit 663f662932c3f182fefc3c77daa1bf8c3394bb8b +Merge: 938c05ef 6bfe3812 +Author: Field G. Van Zee +Date: Sat Mar 16 16:17:12 2019 -0500 + + Merge branch 'amd' of github.com:flame/blis into amd + +commit 938c05ef8654e2fc013d39a57f51d91d40cc40fb +Merge: 4ed39c09 5a5f494e +Author: Field G. Van Zee +Date: Sat Mar 16 16:01:43 2019 -0500 + + Merge branch 'amd' of github.com:flame/blis into amd + +commit 6bfe3812e29b86c95b828822e4e5473b48891167 +Author: Field G. Van Zee +Date: Fri Mar 15 13:57:49 2019 -0500 + + Use -fvisibility=[...] with clang on Linux/BSD/OSX. + + Details: + - Modified common.mk to use the -fvisibility=[hidden|default] option + when compiling with clang on non-Windows platforms (Linux, BSD, OS X, + etc.). Thanks to Isuru Fernando for pointing out this option works + with clang on these OSes. + +commit 809395649c5bbf48778ede4c03c1df705dd49566 +Author: Field G. Van Zee +Date: Wed Mar 13 18:21:35 2019 -0500 + + Annotated additional symbols for export. + + Details: + - Added export annotations to additional function prototypes in order to + accommodate the testsuite. + - Disabled calling bli_amaxv_check() from within the testsuite's + test_amaxv.c. + +commit e095926c643fd9c9c2220ebecd749caae0f71d42 +Author: Field G. Van Zee +Date: Wed Mar 13 17:35:18 2019 -0500 + + Support shared lib export of only public symbols. + + Details: + - Introduced a new configure option, --enable-export-all, which will + cause all shared library symbols to be exported by default, or, + alternatively, --disable-export-all, which will cause all symbols to + be hidden by default, with only those symbols that are annotated for + visibility, via BLIS_EXPORT_BLIS (and BLIS_EXPORT_BLAS for BLAS + symbols), to be exported. The default for this configure option is + --disable-export-all. Thanks to Isuru Fernando for consulting on + this commit. + - Removed BLIS_EXPORT_BLIS annotations from frame/1m/bli_l1m_unb_var1.h, + which was intended for 5a5f494. + - Relocated BLIS_EXPORT-related cpp logic from bli_config.h.in to + frame/include/bli_config_macro_defs.h. + - Provided appropriate logic within common.mk to implement variable + symbol visibility for gcc, clang, and icc (to the extend that each of + these compilers allow). + - Relocated --help text associated with debug option (-d) to configure + slightly further down in the list. + +commit 5a5f494e428372c7c27ed1f14802e15a83221e87 +Author: Field G. Van Zee +Date: Tue Mar 12 18:45:09 2019 -0500 + + Removed export macros from all internal prototypes. + + Details: + - After merging PR #303, at Isuru's request, I removed the use of + BLIS_EXPORT_BLIS from all function prototypes *except* those that we + potentially wish to be exported in shared/dynamic libraries. In other + words, I removed the use of BLIS_EXPORT_BLIS from all prototypes of + functions that can be considered private or for internal use only. + This is likely the last big modification along the path towards + implementing the functionality spelled out in issue #248. Thanks + again to Isuru Fernando for his initial efforts of sprinkling the + export macros throughout BLIS, which made removing them where + necessary relatively painless. Also, I'd like to thank Tony Kelman, + Nathaniel Smith, Ian Henriksen, Marat Dukhan, and Matthew Brett for + participating in the initial discussion in issue #37 that was later + summarized and restated in issue #248. + - CREDITS file update. + +commit 3dc18920b6226026406f1d2a8b2c2b405a2649d5 +Merge: b938c16b 766769ee +Author: Field G. Van Zee +Date: Tue Mar 12 11:20:25 2019 -0500 + + Merge branch 'master' into dev + +commit 766769eeb944bd28641a6f72c49a734da20da755 +Author: Isuru Fernando +Date: Mon Mar 11 19:05:32 2019 -0500 + + Export functions without def file (#303) + + * Revert "restore bli_extern_defs exporting for now" + + This reverts commit 09fb07c350b2acee17645e8e9e1b8d829c73dca8. + + * Remove symbols not intended to be public + + * No need of def file anymore + + * Fix whitespace + + * No need of configure option + + * Remove export macro from definitions + + * Remove blas export macro from definitions + +commit 4ed39c0971c7917e2675cf5449f563b1f4751ccc +Merge: 540ec1b4 b938c16b +Author: Field G. Van Zee +Date: Fri Mar 8 11:56:58 2019 -0600 + + Merge branch 'amd' of github.com:flame/blis into amd + +commit b938c16b0c9e839335ac2c14944b82890143d02f +Author: Field G. Van Zee +Date: Thu Mar 7 16:40:39 2019 -0600 + + Renamed test/3m4m to test/3. + + Details: + - Renamed '3m4m' directory to '3', which captures the directory nicely + since it builds test drivers to test level-3 operations. + - These test drivers ceased to be used to test the 3m and 4m (or even + 1m) induced methods long ago, hence the name change. + +commit ab89a40582ec7acf802e59b0763bed099a02edd8 +Author: Field G. Van Zee +Date: Thu Mar 7 16:26:12 2019 -0600 + + More minor updates and edits to test/3m4m. + + Details: + - Further updates to matlab scripts, mostly for compatibility with + GNU Octave. + - More tweaks to runme.sh. + - Updates to runme.m that allow copy-paste into matlab interactive + session to generate graphs. + +commit f0e70dfbf3fee4c4e382c2c4e87c25454cbc79a1 +Author: Field G. Van Zee +Date: Thu Mar 7 01:04:05 2019 +0000 + + Very minor updates to test/3m4m for ul252. + + Details: + - Very minor updates to the newly revamped test/3m4m drivers when used + on a Xeon Platinum (SkylakeX). + +commit 7fe44748383071f1cbbc77d904f4ae5538e13065 +Author: Kiran Varaganti +Date: Wed Mar 6 16:23:31 2019 +0530 + + Disabled BLIS_ENABLE_ZEN_BLOCK_SIZES in bli_family_zen.h for ROME tuning + + Change-Id: Iec47fcf51f4d4396afef1ce3958e58cf02c59a57 + +commit 9f1dbe572b1fd5e7dd30d5649bdf59259ad770d5 +Author: Field G. Van Zee +Date: Tue Mar 5 17:47:55 2019 -0600 + + Overhauled test/3m4m Makefile and scripts. + + Details: + - Rewrote much of Makefile to generate executables for single- and dual- + socket multithreading as well as single-threaded. Each of the three + can also use a different problem size range/increment, as is often + appropriate when doubling/halving the number of threads. + - Rewrote runme.sh script to flexibly execute as many threading + parameter scenarios as is given in the input parameter string + (currently set within the script itself). The string also encodes + the maximum problem size for each threading scenario, which is used + to identify the executable to run. Also improved the "progress" output + of the script to reduce redundant info and improve readability in + terminals that are not especially wide. + - Minor updates to test_*.c source files. + - Updated matlab scripts according to changes made to the Makefile, + test drivers, and runme.sh script, and renamed 'plot_all.m' to + 'runme.m'. + +commit f5ed95ecd7d5eb4a63e1333ad5cc6765fc8df9fe +Author: Kiran Varaganti +Date: Tue Mar 5 15:01:57 2019 +0530 + + Merged BLIS Release 1.3 + Modified config/zen/make_defs.mk, now CKVECFLAGS := -mavx2 -mfpmath=sse -mfma -march=znver1 + + Change-Id: Ia0942d285a21447cd0c470de1bc021fe63e80d81 + +commit 3bdab823fa93342895bf45d812439324a37db77c +Merge: 70f12f20 e2a02ebd +Author: Field G. Van Zee +Date: Thu Feb 28 14:07:24 2019 -0600 + + Merge branch 'master' into dev + +commit e2a02ebd005503c63138d48a2b7d18978ee29205 +Author: Field G. Van Zee +Date: Thu Feb 28 13:58:59 2019 -0600 + + Updates (from ls5) to test/3m4m/runme.sh. + + Details: + - Lonestar5-specific updates to runme.sh. + +commit f0dcc8944fa379d53770f5cae5d670140918f00c +Author: Isuru Fernando +Date: Wed Feb 27 17:27:23 2019 -0600 + + Add symbol export macro for all functions (#302) + + * initial export of blis functions + + * Regenerate def file for master + + * restore bli_extern_defs exporting for now + +commit 540ec1b479712d5e1da637a718927249c15d867f +Author: Field G. Van Zee +Date: Sun Feb 24 19:09:10 2019 -0600 + + Updated level-3 BLAS to call object API directly. + + Details: + - Updated the BLAS compatibility layer for level-3 operations so that + the corresponding BLIS object API is called directly rather than first + calling the typed BLIS API. The previous code based on the typed BLIS + API calls is still available in a deactivated cpp macro branch, which + may be re-activated by #defining BLIS_BLAS3_CALLS_TAPI. (This does not + yet correspond to a configure option. If it seems like people might + want to toggle this behavior more regularly, a configure option can be + added in the future.) + - Updated the BLIS typed API to statically "pre-initialize" objects via + new initializor macros. Initialization is then finished via calls to + static functions bli_obj_init_finish_1x1() and bli_obj_init_finish(), + which are similar to the previously-called functions, + bli_obj_create_1x1_with_attached_buffer() and + bli_obj_create_with_attached_buffer(), respectively. (The BLAS + compatibility layer updates mentioned above employ this new technique + as well.) + - Transformed certain routines in bli_param_map.c--specifically, the + ones that convert netlib-style parameters to BLIS equivalents--into + static functions, now in bli_param_map.h. (The remaining three classes + of conversation routines were left unchanged.) + - Added the aforementioned pre-initializor macros to bli_type_defs.h. + - Relocated bli_obj_init_const() and bli_obj_init_constdata() from + bli_obj_macro_defs.h to bli_type_defs.h. + - Added a few macros to bli_param_macro_defs.h for testing domains for + real/complexness and precisions for single/double-ness. + +commit 8e023bc914e9b4ac1f13614feb360b105fbe44d2 +Author: Field G. Van Zee +Date: Fri Feb 22 16:55:30 2019 -0600 + + Updates to 3m4m/matlab scripts. + + Details: + - Minor updates to matlab graph-generating scripts. + - Added a plot_all.m script that is more of a scratchpad for copying and + pasting function invocations into matlab to generate plots that are + presently of interest to us. + +commit b06244d98cc468346eb1a8eb931bc05f35ff280c +Merge: e938ff08 4c7e6680 +Author: praveeng +Date: Thu Feb 21 12:56:15 2019 +0530 + + Merge branch 'ut-austin-amd' of ssh://git.amd.com:29418/cpulibraries/er/blis into ut-austin-amd + +commit e938ff08cea3d108c84524eb129d9e89d701ea90 +Author: praveeng +Date: Thu Feb 21 12:44:38 2019 +0530 + + deleted test.txt + + Change-Id: I3871f5fe76e548bc29ec2733745b29964e829dd3 + +commit ed13ad465dcba350ad3d5e16c9cc7542e33f3760 +Author: mkv +Date: Thu Feb 21 01:04:16 2019 -0500 + + added test file for initial commit + +commit 4c7e6680832b497468cf50c2399e3ac4de0e3450 +Author: praveeng +Date: Thu Feb 21 12:44:38 2019 +0530 + + deleted test.txt + + Change-Id: I3871f5fe76e548bc29ec2733745b29964e829dd3 + +commit 95e070581c54ed2edc211874faec56055ea298c8 +Author: mkv +Date: Thu Feb 21 01:04:16 2019 -0500 + + added test file for initial commit + +commit 70f12f209bc1901b5205902503707134cf2991a0 +Author: Field G. Van Zee +Date: Wed Feb 20 16:10:10 2019 -0600 + + Changed unsafe-loop to unsafe-math optimizations. + + Details: + - Changed -funsafe-loop-optimizations (re-)introduced in 7690855 for + make_defs.mk files' CRVECFLAGS to -funsafe-math-optimizations (to + account for a miscommunication in issue #300). Thanks to Dave Love + for this suggestion and Jeff Hammond for his feedback on the topic. + +commit 7690855c5106a56e5b341a350f8db1c78caacd89 +Author: Field G. Van Zee +Date: Mon Feb 18 19:16:01 2019 -0600 + + Restored -funsafe-loop-optimizations to subconfigs. + + Details: + - Restored use of -funsafe-loop-optimizations in the definitions of + CRVECFLAGS (when using gcc), but only for sub-configurations (and + not configuration families such as amd64, intel64, and x86_64). + This more or less reverts 5190d05 and 6cf1550. + +commit 44994d1490897b08cde52a615a2e37ddae8b2061 +Author: Field G. Van Zee +Date: Mon Feb 18 18:35:30 2019 -0600 + + Disable TBM, XOP, LWP instructions in AMD configs. + + Details: + - Added -mno-tbm -mno-xop -mno-lwp to CKVECFLAGS in bulldozer, + piledriver, steamroller, and excavator configurations to explicitly + disable AMD's bulldozer-era TBM, XOP, and LWP instruction sets in an + attempt to fix the invalid instruction error that has plagued Travis + CI builds since 6a014a3. Thanks to Devin Matthews for pointing out + that the offending instruction was part of TBM (issue #300). + - Restored -O3 to piledriver configuration's COPTFLAGS. + +commit 1e5b530744c1906140d47f43c5cad235eaa619cf +Author: Field G. Van Zee +Date: Mon Feb 18 18:04:38 2019 -0600 + + Reverted piledriver COPTFLAGS from -O3 to -O2. + + Details: + - Debugging continues; changing COPTFLAGS for piledriver subconfig from + -O3 to -O2, its original value prior to 6a014a3. + +commit 6cf155049168652c512aefdd16d74e7ff39b98df +Author: Field G. Van Zee +Date: Mon Feb 18 17:29:51 2019 -0600 + + Removed -funsafe-loop-optimizations from all configs. + + Details: + - Error persists. Removed -funsafe-loop-optimizations from all remaining + sub-configurations. + +commit 5190d05a27c5fa4c7942e20094f76eb9a9785c3e +Author: Field G. Van Zee +Date: Mon Feb 18 17:07:35 2019 -0600 + + Removed -funsafe-loop-optimizations from piledriver. + + Details: + - Error persists; continuing debugging from bf0fb78c by removing + -funsafe-loop-optimizations from piledriver configuration. + +commit bf0fb78c5e575372060d22f5ceeb5b332e8978ec +Author: Field G. Van Zee +Date: Mon Feb 18 16:51:38 2019 -0600 + + Removed -funsafe-loop-optimizations from families. + + Details: + - Removed -funsafe-loop-optimizations from the configuration families + affected by 6a014a3, specifically: intel64, amd64, and x86_64. + This is part of an attempt to debug why the sde, as executed by + Travis CI, is crashing via the following error: + + TID 0 SDE-ERROR: Executed instruction not valid for specified chip + (ICELAKE): 0x9172a5: bextr_xop rax, rcx, 0x103 + +commit 6a014a3377a2e829dbc294b814ca257a2bfcb763 +Author: Field G. Van Zee +Date: Mon Feb 18 14:52:29 2019 -0600 + + Standardized optimization flags in make_defs.mk. + + Details: + - Per Dave Love's recommendation in issue #300, this commit defines + COPTFLAGS := -03 + and + CRVECFLAGS := $(CKVECFLAGS) -funsafe-loop-optimizations + in the make_defs.mk for all Intel- and AMD-based configurations. + +commit 565fa3853b381051ac92cff764625909d105644d +Author: Field G. Van Zee +Date: Mon Feb 18 11:43:58 2019 -0600 + + Redirect trsm pc, ir parallelism to ic, jr loops. + + Details: + - trsm parallelization was temporarily simplifed in 075143d to entirely + ignore any parallelism specified via the pc or ir loops. Now, any + parallelism specified to the pc loop will be redirected to the ic + loop, and any parallelism specified to the ir loop will be redirected + to the jr loop. (Note that because of inter-iteration dependencies, + trsm cannot parallelize the ir loop. Parallelism via the pc loop is + at least somewhat feasible in theory, but it would require tracking + dependencies between blocks--something for which BLIS currently lacks + the necessary supporting infrastructure.) + +commit a023c643f25222593f4c98c2166212561d030621 +Author: Field G. Van Zee +Date: Thu Feb 14 20:18:55 2019 -0600 + + Regenerated symbols in build/libblis-symbols.def. + + Details: + - Reran ./build/regen-symbols.sh after running + 'configure --enable-cblas auto' + +commit 075143dfd92194647da9022c1a58511b20fc11f3 +Author: Field G. Van Zee +Date: Thu Feb 14 18:52:45 2019 -0600 + + Added support for IC loop parallelism to trsm. + + Details: + - Parallelism within the IC loop (3rd loop around the microkernel) is + now supported within the trsm operation. This is done via a new branch + on each of the control and thread trees, which guide execution of a + new trsm-only subproblem from within bli_trsm_blk_var1(). This trsm + subproblem corresponds to the macrokernel computation on only the + block of A that contains the diagonal (labeled as A11 in algorithms + with FLAME-like partitioning), and the corresponding row panel of C. + During the trsm subproblem, all threads within the JC communicator + participate and parallelize along the JR loop, including any + parallelism that was specified for the IC loop. (IR loop parallelism + is not supported for trsm due to inter-iteration dependencies.) After + this trsm subproblem is complete, a barrier synchronizes all + participating threads and then they proceed to apply the prescribed + BLIS_IC_NT (or equivalent) ways of parallelism (and any BLIS_JR_NT + parallelism specified within) to the remaining gemm subproblem (the + rank-k update that is performed using the newly updated row-panel of + B). Thus, trsm now supports JC, IC, and JR loop parallelism. + - Modified bli_trsm_l_cntl_create() to create the new "prenode" branch + of the trsm_l cntl_t tree. The trsm_r tree was left unchanged, for + now, since it is not currently used. (All trsm problems are cast in + terms of left-side trsm.) + - Updated bli_cntl_free_w_thrinfo() to be able to free the newly shaped + trsm cntl_t trees. Fixed a potentially latent bug whereby a cntl_t + subnode is only recursed upon if there existed a corresponding + thrinfo_t node, which may not always exist (for problems too small + to employ full parallelization due to the minimum granularity imposed + by micropanels). + - Updated other functions in frame/base/bli_cntl.c, such as + bli_cntl_copy() and bli_cntl_mark_family(), to recurse on sub-prenodes + if they exist. + - Updated bli_thrinfo_free() to recurse into sub-nodes and prenodes + when they exist, and added support for growing a prenode branch to + bli_thrinfo_grow() via a corresponding set of help functions named + with the _prenode() suffix. + - Added a bszid_t field thrinfo_t nodes. This field comes in handy when + debugging the allocation/release of thrinfo_t nodes, as it helps trace + the "identity" of each nodes as it is created/destroyed. + - Renamed + bli_l3_thrinfo_print_paths() -> bli_l3_thrinfo_print_gemm_paths() + and created a separate bli_l3_thrinfo_print_trsm_paths() function to + print out the newly reconfigured thrinfo_t trees for the trsm + operation. + - Trival changes to bli_gemm_blk_var?.c and bli_trsm_blk_var?.c + regarding variable declarations. + - Removed subpart_t enum values BLIS_SUBPART1T, BLIS_SUBPART1B, + BLIS_SUBPART1L, BLIS_SUBPART1R. Then added support for two new labels + (semantically speaking): BLIS_SUBPART1A and BLIS_SUBPART1B, which + represent the subpartition ahead of and behind, respectively, + BLIS_SUBPART1. Updated check functions in bli_check.c accordingly. + - Shuffled layering/APIs for bli_acquire_mpart_[mn]dim() and + bli_acquire_mpart_t2b/b2t(), _l2r/r2l(). + - Deprecated old functions in frame/3/bli_l3_thrinfo.c. + +commit 78bc0bc8b6b528c79b11f81ea19250a1db7450ed +Author: Nicholai Tukanov +Date: Thu Feb 14 13:29:02 2019 -0600 + + Power9 sub-configuration (#298) + + Formally registered power9 sub-configuration. + + Details: + - Added and registered power9 sub-configuration into the build system. + Thanks to Nicholai Tukanov and Devangi Parikh for these contributions. + - Note: The sub-configuration does not yet have a corresponding + architecture-specific kernel set registered, and so for now the + sub-config is using the generic kernel set. + +commit 6b832731261f9e7ad003a9ea4682e9ca973ef844 +Author: Field G. Van Zee +Date: Tue Feb 12 16:01:28 2019 -0600 + + Generalized ref kernels' pragma omp simd usage. + + Details: + - Replaced direct usage of _Pragma( "omp simd" ) in reference kernels + with PRAGMA_SIMD, which is defined as a function of the compiler being + used in a new bli_pragma_macro_defs.h file. That definition is cleared + when BLIS detects that the -fopenmp-simd command line option is + unsupported. Thanks to Devin Matthews and Jeff Hammond for suggestions + that guided this commit. + - Updated configure and bli_config.h.in so that the appropriate anchor + is substituted in (when the corresponding pragma omp simd support is + present). + +commit b1f5ce8622b682b79f956fed83f04a60daa8e0fc +Author: Field G. Van Zee +Date: Tue Feb 5 17:38:50 2019 -0600 + + Minor updates to scripts in test/mixeddt/matlab. + +commit 38203ecd15b1fa50897d733daeac6850d254e581 +Author: Devangi N. Parikh +Date: Mon Feb 4 15:28:28 2019 -0500 + + Added thunderx2 system in the mixeddt test scripts + + Details: + - Added thunderx2 (tx2) as a system in the runme.sh in test/mixeddt + +commit dfc91843ea52297bf636147793029a0c1345be04 +Author: Devangi N. Parikh +Date: Mon Feb 4 15:23:40 2019 -0500 + + Fixed gcc flags for thunderx2 subconfiguration + + Details: + - Fixed -march flag. Thunderx2 is an armv8.1a architecture not armv8a. + +commit c665eb9b888ec7e41bd0a28c4c8ac4094d0a01b5 +Author: Field G. Van Zee +Date: Mon Jan 28 16:22:23 2019 -0600 + + Minor updates to docs, Makefiles. + + Details: + - Changed all occurrances of + micro-kernel -> microkernel + macro-kernel -> macrokernel + micro-panel -> micropanel + in all markdown documents in 'docs' directory. This change is being + made since we've reached the point in adoption and acceptance of + BLIS's insights where words such as "microkernel" are no longer new, + and therefore now merit being unhyphenated. + - Updated "Implementation Notes" sections of KernelsHowTo.md, which + still contained references to nonexistent cpp macros such as + BLIS_DEFAULT_MR_? and BLIS_PACKDIM_MR_?. + - Added 'run-fast' and 'check-fast' targets to testsuite/Makefile. + - Minor updates to Testsuite.md, including suggesting use of + 'make check' and 'make check-fast' when running from the local + testsuite directory. + - Added a comment to top-level Makefile explaining the purpose behind + the TESTSUITE_WRAPPER variable, which at first glance appears to serve + no purpose. + +commit 1aa280d0520ed5eaea3b119b4e92b789ecad78a4 +Author: M. Zhou <5723047+cdluminate@users.noreply.github.com> +Date: Sun Jan 27 21:40:48 2019 +0000 + + Amend OS detection for kFreeBSD. (#295) + +commit fffc23bb35d117a433886eb52ee684ff5cf6997f +Author: Field G. Van Zee +Date: Fri Jan 25 13:35:31 2019 -0600 + + CREDITS file update. + +commit 26c5cf495ce22521af5a36a1012491213d5a4551 +Author: Field G. Van Zee +Date: Thu Jan 24 18:49:31 2019 -0600 + + Fixed bug in skx subconfig related to bdd46f9. + + Details: + - Fixed code in the skx subconfiguration that became a bug after + committing bdd46f9. Specifically, the bli_cntx_init_skx() function + was overwriting default blocksizes for the scomplex and dcomplex + microkernels despite the fact that only single and double real + microkernels were being registered. This was not a problem prior to + bdd46f9 since all microkernels used dynamically-queried (at runtime) + register blocksizes for loop bounds. However, post-bdd46f9, this + became a bug because the reference ukernels for scomplex and dcomplex + were written with their register blocksizes hard-coded as constant + loop bounds, which conflicted the the erroneous scomplex and dcomplex + values that bli_cntx_init_skx() was setting in the context. The + lesson here is that going forward, all subconfigurations must not set + any blocksizes for datatypes corresponding to default/reference + microkernels. (Note that a blocksize is left unchanged by the + bli_cntx_set_blkszs() function if it was set to -1.) + +commit 180f8e42e167b83a757340ad4bd4a5c7a1d6437b +Author: Field G. Van Zee +Date: Thu Jan 24 18:01:15 2019 -0600 + + Fixed undefined behavior trsm ukr bug in bdd46f9. + + Details: + - Fixed a bug that mainfested anytime a configuration was used in which + optimized microkernels were registered and the trsm operation (or + kernel) was invoked. The bug resulted from the optimized microkernels' + register blocksizes conflicting with the hard-coded values--expressed + in the form of constant loop bounds--used in the new reference trsm + ukernels that were introduced in bdd46f9. The fix was easy: reverting + back to the implementation that uses variable-bound loops, which + amounted to changing an #if 0 to #if 1 (since I preserved the older + implementation in the file alongside the new code based on constant- + bound loops). It should be noted that this fix must be permanent, + since the trsm kernel code with constant-bound loops can never work + with gemm ukernels that use different register blocksizes. + +commit bdd46f9ee88057d52610161966a11c224e5a026c +Author: Field G. Van Zee +Date: Thu Jan 24 17:23:18 2019 -0600 + + Rewrote reference kernels to use #pragma omp simd. + + Details: + - Rewrote level-1v, -1f, and -3 reference kernels in terms of simplified + indexing annotated by the #pragma omp simd directive, which a compiler + can use to vectorize certain constant-bounded loops. (The new kernels + actually use _Pragma("omp simd") since the kernels are defined via + templatizing macros.) Modest speedup was observed in most cases using + gcc 5.4.0, which may improve with newer versions. Thanks to Devin + Matthews for suggesting this via issue #286 and #259. + - Updated default blocksizes defined in ref_kernels/bli_cntx_ref.c to + be 4x16, 4x8, 4x8, and 4x4 for single, double, scomplex and dcomplex, + respectively, with a default row preference for the gemm ukernel. Also + updated axpyf, dotxf, and dotxaxpyf fusing factors to 8, 6, and 4, + respectively, for all datatypes. + - Modified configure to verify that -fopenmp-simd is a valid compiler + option (via a new detect/omp_simd/omp_simd_detect.c file). + - Added a new header in which prefetch macros are defined according to + which compiler is detected (via macros such as __GNUC__). These + prefetch macros are not yet employed anywhere, though. + - Updated the year in copyrights of template license headers in + build/templates and removed AMD as a default copyright holder. + +commit 63de2b0090829677755eb5cdb27e73bc738da32d +Author: Field G. Van Zee +Date: Wed Jan 23 12:16:27 2019 -0600 + + Prevent redef of ftnlen in blastest f2c_types.h. + + Details: + - Guard typedef of ftnlen in f2c_types.h with a #ifndef HAVE_BLIS_H + directive to prevent the redefinition of that type. Thanks to Jeff + Diamond for reporting this compiler warning (and apologies for the + delay in committing a fix). + +commit eec2e183a7b7d67702dbd1f39c153f38148b2446 +Author: Field G. Van Zee +Date: Mon Jan 21 12:12:18 2019 -0600 + + Added escaping to '/' in os_name in configure. + + Details: + - Add os_name to the list of variables into which the '/' character is + escaped. This is meant to address (or at least make progress toward + addressing) #293. Thanks to Isuru Fernando for spotting this as the + potential fix, and also thanks to M. Zhou for the original report. + +commit adf5c17f0839fdbc1f4a1780f637928b1e78e389 +Author: Field G. Van Zee +Date: Fri Jan 18 15:14:45 2019 -0600 + + Formally registered thunderx2 subconfiguration. + + Details: + - Added a separate subconfiguration for thunderx2, which now uses + different optimization flags than cortexa57/cortexa53. + +commit 094cfdf7df6c2764c25fcbfce686ba29b933942c +Author: M. Zhou <5723047+cdluminate@users.noreply.github.com> +Date: Fri Jan 18 18:46:13 2019 +0000 + + Port BLIS to GNU Hurd OS. (#294) + + Prevent blis.h from misidentifying Hurd as OSX. + +commit 5d7d616e8e591c2f3c7c2d73220eb27ea484f9c9 +Author: Field G. Van Zee +Date: Tue Jan 15 20:52:51 2019 -0600 + + README.md update re: mixeddt TOMS paper. + +commit 58c7fb4788177487f73a3964b7a910fe4dc75941 +Author: Field G. Van Zee +Date: Tue Jan 8 17:00:27 2019 -0600 + + Added more matlab scripts for mixeddt paper. + + Details: + - Added a variant set of matlab scripts geared to producing plots that + reflect performance data gathered with and without extra memory + optimizations enabled. These scripts reside (for now) in + test/mixeddt/matlab/wawoxmem. + +commit 34286eb914b48b56cdda4dfce192608b9f86d053 +Author: Field G. Van Zee +Date: Tue Jan 8 11:41:20 2019 -0600 + + Minor update to docs/HardwareSupport.md. + +commit 108b04dc5b1b1288db95f24088d1e40407d7bc88 +Author: Field G. Van Zee +Date: Mon Jan 7 20:16:31 2019 -0600 + + Regenerated symbols in build/libblis-symbols.def. + + Details: + - Reran ./build/regen-symbols.sh after running + 'configure --enable-cblas auto' to reflect removal of + bli_malloc_pool() and bli_free_pool(). + +commit 706cbd9d5622f4690e6332a89cf41ab5c8771899 +Author: Field G. Van Zee +Date: Mon Jan 7 18:28:19 2019 -0600 + + Minor tweaks/cleanups to bli_malloc.c, _apool.c. + + Details: + - Removed malloc_ft and free_ft function pointer arguments from the + interface to bli_apool_init() after deciding that there is no need to + specify the malloc()/free() for blocks within the apool. (The apool + blocks are actually just array_t structs.) Instead, we simply call + bli_malloc_intl()/_free_intl() directly. This has the added benefit + of allowing additional output when memory tracing is enabled via + --enable-mem-tracing. Also made corresponding changes elsewhere in + the apool API. + - Changed the inner pools (elements of the array_t within the apool_t) + to use BLIS_MALLOC_POOL and BLIS_FREE_POOL instead of BLIS_MALLOC_INTL + and BLIS_FREE_INTL. + - Disabled definitions of bli_malloc_pool() and bli_free_pool() since + there are no longer any consumers of these functions. + - Very minor comment / printf() updates. + +commit 579145039d945adbcad1177b1d53fb2d3f2e6573 +Author: Minh Quan Ho <1337056+hominhquan@users.noreply.github.com> +Date: Mon Jan 7 23:00:15 2019 +0100 + + Initialize error messages at compile time (#289) + + * Initialize error messages at compile time + + - Assigning strings directly to the bli_error_string array, instead of + snprintf() at execution-time. + + * Retired bli_error_init(), _finalize(). + + Details: + - Removed functions obviated by changes in 80e8dc6: bli_error_init(), + bli_error_finalize(), and bli_error_init_msgs(), as well as calls to + the former two in bli_init.c. + + * Regenerated symbols in build/libblis-symbols.def. + + Details: + - Reran ./build/regen-symbols.sh after running + 'configure --enable-cblas auto'. + +commit aafbca086e36b6727d7be67e21fef5bd9ff7bfd9 +Author: Field G. Van Zee +Date: Mon Jan 7 12:38:21 2019 -0600 + + Updated external package language in README.md. + + Details: + - Updated/added comments about Fedora, OpenSUSE, and GNU Guix under the + newly-renamed "External GNU/Linux packages" section. Thanks to Dave + Love for providing these revisions. + +commit daacfe68404c9cc8078e5e7ba49a8c7d93e8cda3 +Author: Field G. Van Zee +Date: Mon Jan 7 12:12:47 2019 -0600 + + Allow running configure with python 3.4. + + Details: + - Relax version blacklisting of python3 to allow 3.4 or later instead + of 3.5 or later. Thanks to Dave Love for pointing out that 3.4 was + sufficient for the purpose of BLIS's build system. (It should be + noted that we're not sure which, if any, python3 versions prior to + 3.4 are insufficient, and that the only thing stopping us from + determining this is the fact that these earlier versions of python3 + are not readily available for us to test with.) + - Updated docs/BuildSystem.md to be explicit about current python2 vs + python3 version requirements. + +commit cdbf16aa93234e0d6a80f0d0e385ec81e7b75465 +Author: prangana +Date: Fri Jan 4 15:59:21 2019 +0530 + + Update version 1.3 + + Change-Id: I32a7d24af860e87a60396614075236afb65a28a9 + +commit cf9c1150515b8e9cc4f12e0d4787b3471b12ba4a +Author: kdevraje +Date: Thu Jan 3 09:51:46 2019 +0530 + + This commit adds a macro, which is to be enabled when BLIS is working on single instance mode + + Change-Id: I7f3fd654b78e64c4e6e24e9f0e245b1a30c492b0 + +commit ad8d9adb09a7dd267bbdeb2bd1fbbf9daf64ee76 +Author: Field G. Van Zee +Date: Thu Jan 3 16:08:24 2019 -0600 + + README.md, CREDITS update. + + Details: + - Added "What's New" and "What People Are Saying About BLIS" sections to + README.md. + - Added missing github handles to various individuals' entries in the + CREDITS file. + +commit 7052fca5aef430241278b67d24cef6fe33106904 +Author: Field G. Van Zee +Date: Wed Jan 2 13:48:40 2019 -0600 + + Apply f272c289 to bli_fmalloc_noalign(). + + Details: + - Perform the same check for NULL return values and error message output + in bli_fmalloc_noalign() as is performed by bli_fmalloc_align(). (This + change was intended for f272c289.) + +commit 528e3ad16a42311a852a8376101959b4ccd801a5 +Merge: 3126c52e f272c289 +Author: Field G. Van Zee +Date: Wed Jan 2 13:39:19 2019 -0600 + + Merge branch 'amd' + +commit 3126c52ea795ffb7d30b16b7f7ccc2a288a6158d +Merge: 61441b24 8091998b +Author: Field G. Van Zee +Date: Wed Jan 2 13:37:37 2019 -0600 + + Merge branch 'amd' + +commit f272c2899a6764eedbe05cea874ee3bd258dbff3 +Author: Field G. Van Zee +Date: Wed Jan 2 12:34:15 2019 -0600 + + Add error message to malloc() check for NULL. + + Details: + - Output an error message if and when the malloc()-equivalent called by + bli_fmalloc_align() ever returns NULL. Everything was already in place + for this to happen, including the error return code, the error string + sprintf(), the error checking function bli_check_valid_malloc_buf() + definition, and its prototype. Thanks to Minh Quan Ho for pointing out + the missing error message. + - Increased the default block_ptrs_len for each inner pool stored in the + small block allocator from 10 to 25. Under normal execution, each + thread uses only 21 blocks, so this change will prevent the sba from + needing to resize the block_ptrs array of any given inner pool as + threads initially populate the pool with small blocks upon first + execution of a level-3 operation. + - Nix stray newline echo in configure. + +commit eb97f778a1e13ee8d3b3aade05e479c4dfcfa7c0 +Author: Field G. Van Zee +Date: Tue Dec 25 20:17:09 2018 -0600 + + Added missing AMD copyrights to previous commit. + + Details: + - Forgot to add AMD copyrights to several touched files that did not + already have them in 2f31743. + +commit 2f3174330fb29164097d664b7c84e05c7ced7d95 +Author: Field G. Van Zee +Date: Tue Dec 25 19:35:01 2018 -0600 + + Implemented a pool-based small block allocator. + + Details: + - Implemented a sophisticated data structure and set of APIs that track + the small blocks of memory (around 80-100 bytes each) used when + creating nodes for control and thread trees (cntl_t and thrinfo_t) as + well as thread communicators (thrcomm_t). The purpose of the small + block allocator, or sba, is to allow the library to transition into a + runtime state in which it does not perform any calls to malloc() or + free() during normal execution of level-3 operations, regardless of + the threading environment (potentially multiple application threads + as well as multiple BLIS threads). The functionality relies on a new + data structure, apool_t, which is (roughly speaking) a pool of + arrays, where each array element is a pool of small blocks. The outer + pool, which is protected by a mutex, provides separate arrays for each + application thread while the arrays each handle multiple BLIS threads + for any given application thread. The design minimizes the potential + for lock contention, as only concurrent application threads would + need to fight for the apool_t lock, and only if they happen to begin + their level-3 operations at precisely the same time. Thanks to Kiran + Varaganti and AMD for requesting this feature. + - Added a configure option to disable the sba pools, which are enabled + by default; renamed the --[dis|en]able-packbuf-pools option to + --[dis|en]able-pba-pools; and rewrote the --help text associated with + this new option and consolidated it with the --help text for the + option associated with the sba (--[dis|en]able-sba-pools). + - Moved the membrk field from the cntx_t to the rntm_t. We now pass in + a rntm_t* to the bli_membrk_acquire() and _release() APIs, just as we + do for bli_sba_acquire() and _release(). + - Replaced all calls to bli_malloc_intl() and bli_free_intl() that are + used for small blocks with calls to bli_sba_acquire(), which takes a + rntm (in addition to the bytes requested), and bli_sba_release(). + These latter two functions reduce to the former two when the sba pools + are disabled at configure-time. + - Added rntm_t* arguments to various cntl_t and thrinfo_t functions, as + required by the new usage of bli_sba_acquire() and _release(). + - Moved the freeing of "old" blocks (those allocated prior to a change + in the block_size) from bli_membrk_acquire_m() to the implementation + of the pool_t checkout function. + - Miscellaneous improvements to the pool_t API. + - Added a block_size field to the pblk_t. + - Harmonized the way that the trsm_ukr testsuite module performs packing + relative to that of gemmtrsm_ukr, in part to avoid the need to create + a packm control tree node, which now requires a rntm_t that has been + initialized with an sba and membrk. + - Re-enable explicit call bli_finalize() in testsuite so that users who + run the testsuite with memory tracing enabled can check for memory + leaks. + - Manually imported the compact/minor changes from 61441b24 that cause + the rntm to be copied locally when it is passed in via one of the + expert APIs. + - Reordered parameters to various bli_thrcomm_*() functions so that the + thrcomm_t* to the comm being modified is last, not first. + - Added more descriptive tracing for allocating/freeing small blocks and + formalized via a new configure option: --[dis|en]able-mem-tracing. + - Moved some unused scalm code and headers into frame/1m/other. + - Whitespace changes to bli_pthread.c. + - Regenerated build/libblis-symbols.def. + +commit 61441b24f3244a4b202c29611a4899dd5c51d3a1 +Author: Field G. Van Zee +Date: Thu Dec 20 19:38:11 2018 -0600 + + Make local copy of user's rntm_t in level-3 ops. + + Details: + - In the case that the caller passes in a non-NULL rntm_t pointer into + one of the expert APIs for a level-3 operation (e.g. bli_gemm_ex()), + make a local copy of the rntm_t and use the address of that local copy + in all subsequent execution (which may change the contents of the + rntm_t). This prevents a potentially confusing situation whereby a + user-initialized rntm_t is used once (in, say, gemm), and then found + by the user to be in a different state before it is used a second + time. + +commit e809b5d2f1023b4249969e2f516291c9a3a00b80 +Merge: 76016691 0476f706 +Author: Field G. Van Zee +Date: Thu Dec 20 16:27:26 2018 -0600 + + Merge branch 'master' into amd + +commit 1f4eeee5175a8fc9ac312847c796ce6db5fe75b9 +Author: sraut +Date: Wed Dec 19 21:21:10 2018 +0530 + + Fixed BLAS test failures of small matrix SYRK for single and double precision. + + Details: + - SYRK for small matrix was implemented by reusing small GEMM routine. This was + resulting in output written to the full C matrix, and C being symmetric the + lower and upper triangles of C matrix contained same results. BLAS SYRK API + spec demands either lower or upper triangle of C matrix to be written with + results. So, this was resulting in BLAS test failures, even though testsuite + of BLIS was passing small SYRK operation. + - To fix BLAS test failures of small matrix SYRK, separate kernel routines are + implemented for small SYRK for both single and double precision. The newly + added small SYRK routines are in file kernels/zen/3/bli_syrk_small.c. + Now the intermediate results of matrix C are written to a scratch buffer. + Final results are written from scratch buffer to matrix C using SIMD + copy to either lower or upper traingle part of matrix C. + - Source and header files frame/3/syrk/bli_syrk_front.c and + frame/3/syrk/bli_syrk_front.h are changed to invoke new small SYRK routines. + + Change-Id: I9cfb1116c93d150aefac673fca033952ecac97cb + +commit 6d267375c3a0543f20604d74cc678ad91db3b6f1 +Author: sraut +Date: Wed Dec 19 14:22:21 2018 +0530 + + This commit improves the performance of multi-instance DGEMM when these multiple threads are binded to a CCX. + Multi-Instance: Each thread runs a sequential DGEMM. + Change-Id: I306920c8061b6dad61efac1dae68727f4ac27df6 + +commit 0476f706b93e83f6b74a3d7b7e6e9cc9a1a52c3b +Author: Field G. Van Zee +Date: Tue Dec 18 14:56:20 2018 -0600 + + CHANGELOG update (0.5.1) + +commit e0408c3ca3d53bc8e6fedac46ea42c86e06c922d (tag: 0.5.1) Author: Field G. Van Zee Date: Tue Dec 18 14:56:16 2018 -0600 Version file update (0.5.1) -commit 3ab231afc9f69d14493908c53c85a84c5fba58aa (origin/master, origin/HEAD) +commit 3ab231afc9f69d14493908c53c85a84c5fba58aa Author: Field G. Van Zee Date: Tue Dec 18 14:53:37 2018 -0600 @@ -24,6 +5530,16 @@ Date: Tue Dec 18 14:52:40 2018 -0600 into Debian package universe. Thanks to M. Zhou for sponsoring BLIS in Debian. +commit 7bf901e9265a1acd78e44c06f7178c8152c7e267 +Author: sraut +Date: Tue Dec 18 14:39:16 2018 +0530 + + Fix on EPYC machine for multi instance performance issue, + Issue: For the default values of mc, kc and nc with multi instance mode the performance across the cores dip drastically. + Fix: After experimentation found different set of values (mc, kc and nc) which fits in the cache size, and performance across the remains same across all the cores. + + Change-Id: I98265e3b7e61cd7602a0cc5596240e86c08c03fe + commit d2b2a0819a2fccad9165bc48c0e172d79a87542c Author: Field G. Van Zee Date: Mon Dec 17 19:26:35 2018 -0600 @@ -53,6 +5569,55 @@ Date: Mon Dec 17 19:17:30 2018 -0600 OpenMP. - CREDITS file update. +commit 76016691e2c514fcb59f940c092475eda968daa2 +Author: Field G. Van Zee +Date: Thu Dec 13 17:23:09 2018 -0600 + + Improvements to bli_pool; malloc()/free() tracing. + + Details: + - Added malloc_ft and free_ft fields to pool_t, which are provided when + the pool is initialized, to allow bli_pool_alloc_block() and + bli_pool_free_block() to call bli_fmalloc_align()/bli_ffree_align() + with arbitrary align_size values (according to how the pool_t was + initialized). + - Added a block_ptrs_len argument to bli_pool_init(), which allows the + caller to specify an initial length for the block_ptrs array, which + previously suffered the cost of being reallocated, copied, and freed + each time a new block was added to the pool. + - Consolidated the "buf_sys" and "buf_align" pointer fields in pblk_t + into a single "buf" field. Consolidated the bli_pblk API accordingly + and also updated the bli_mem API implementation. This was done + because I'd previously already implemented opaque alignment via + bli_malloc_align(), which allocates extra space and stores the + original pointer returned by malloc() one element before the element + whose address is aligned. + - Tweaked bli_membrk_acquire_m() and bli_membrk_release() to call + bli_fmalloc_align() and bli_ffree_align(), which required adding an + align_size field to the membrk_t struct. + - Pass the pack schemas directly into bli_l3_cntl_create_if() rather + than transmit them via objects for A and B. + - Simplified bli_l3_cntl_free_if() and renamed to bli_l3_cntl_free(). + The function had not been conditionally freeing control trees for + quite some time. Also, removed obj_t* parameters since they aren't + needed anymore (or never were). + - Spun-off OpenMP nesting code in bli_l3_thread_decorator() to a + separate function, bli_l3_thread_decorator_thread_check(). + - Renamed: + bli_malloc_align() -> bli_fmalloc_align() + bli_free_align() -> bli_ffree_align() + bli_malloc_noalign() -> bli_fmalloc_noalign() + bli_free_noalign() -> bli_ffree_noalign() + The 'f' is for "function" since they each take a malloc_ft or free_ft + function pointer argument. + - Inserted various printf() calls for the purposes of tracing memory + allocation and freeing, guarded by cpp macro ENABLE_MEM_DEBUG, which, + for now, is intended to be a "hidden" feature rather than one hooked + up to a configure-time option. + - Defined bli_rntm_equals(), which compares two rntm_t for equality. + (There are no use cases for this function yet, but there may be soon.) + - Whitespace changes to function parameter lists in bli_pool.c, .h. + commit f808d829c58dc4194cc3ebc3825fbdde12cd3f93 Author: Field G. Van Zee Date: Wed Dec 12 15:22:59 2018 -0600 @@ -105,6 +5670,13 @@ Date: Wed Dec 12 15:22:59 2018 -0600 - Fixed a minor bug in the testsuite that prevented non-1m-based induced method implementations of trsm from executing. +commit 02ec0be3ba0b0d6b4186386ae140906a96de919b +Merge: e275def3 c534da62 +Author: Field G. Van Zee +Date: Wed Dec 5 19:33:53 2018 -0600 + + Merge branch 'master' into amd + commit c534da62c0015f91391983da5376c9e091378010 Author: Field G. Van Zee Date: Wed Dec 5 15:51:05 2018 -0600 @@ -149,7 +5721,7 @@ Date: Wed Dec 5 20:06:32 2018 +0000 (That is, when native complex microkernels are missing, we usually want to test performance of 1m.) -commit 0645f239fbdf37ee9d2096ee3bb0e76b3302cfff (origin/dev, dev) +commit 0645f239fbdf37ee9d2096ee3bb0e76b3302cfff Author: Field G. Van Zee Date: Tue Dec 4 14:31:06 2018 -0600 @@ -238,6 +5810,13 @@ Date: Mon Dec 3 17:49:52 2018 -0600 frame/3/gemm/ind/bli_gemm_ind_opt.h. - Various whitespace/comment updates. +commit e275def30ac41cadce296560fa67282704f20a02 +Merge: 8091998b dc184095 +Author: Field G. Van Zee +Date: Fri Nov 30 15:39:50 2018 -0600 + + Merge branch 'master' into amd + commit dc18409551f341125169fe8d4d43ac45e81bdf28 Author: Field G. Van Zee Date: Wed Nov 28 11:58:40 2018 -0600 @@ -489,6 +6068,13 @@ Date: Wed Nov 14 13:47:45 2018 -0600 Isuru Fernando for suggesting this fix, and also to Costas Yamin for originally reporting the issue (#277). +commit 8091998b6500e343c2024561c2b1aa73c3bafb0b +Merge: 333d8562 7b5ba731 +Author: Field G. Van Zee +Date: Wed Nov 14 12:36:35 2018 -0600 + + Merge branch 'master' into amd + commit 7b5ba7319b3901ad0e6c6b4fa3c1d96b579efbe9 Merge: ce719f81 52392932 Author: Field G. Van Zee @@ -548,6 +6134,18 @@ Date: Tue Nov 13 13:03:15 2018 -0600 datatype contains a different value. Thanks to Devangi Parikh for helping in isolating this bug. +commit 333d8562f04eea0676139a10cb80a97f107b45b0 +Author: Field G. Van Zee +Date: Sun Nov 11 14:28:53 2018 -0600 + + Added debug output to bli_malloc.c. + + Details: + - Added debug output to bli_malloc.c in order to debug certain kinds of + memory behavior in BLIS. The printf() statements are disabled and must + be enabled manually. + - Whitespace/comment updates in bli_membrk.c. + commit ce719f816d1237f5277527d7f61123e77180be54 Author: Field G. Van Zee Date: Sat Nov 10 14:48:43 2018 -0600 @@ -644,7 +6242,7 @@ Date: Fri Oct 26 17:07:15 2018 -0500 output. - Very minor edits to docs/MixedDatatypes.md. -commit e90e7f309b3f2760a01e8e09a29bf702754fa2b5 (origin/win-pthreads, win-pthreads) +commit e90e7f309b3f2760a01e8e09a29bf702754fa2b5 (origin/win-pthreads) Author: Field G. Van Zee Date: Thu Oct 25 14:09:43 2018 -0500 @@ -754,6 +6352,14 @@ Date: Tue Oct 23 19:16:54 2018 -0500 - Removed temporary play-test code for shiftd that accidentally got committed into test/3m4m/test_gemm.c. +commit 0ae9585da1e3db1cf8034d4b16305a5883beb0d3 +Author: pradeeptrgit +Date: Tue Oct 23 09:36:23 2018 +0530 + + Update version number to 1.2 + + Change-Id: Ibb31f6683cdecca6b218bc2f0c14701d7e92ebf3 + commit eac7d267a017d646a2c5b4fa565f4637ebfd9da7 Author: Field G. Van Zee Date: Mon Oct 22 18:10:59 2018 -0500 @@ -1229,6 +6835,14 @@ Date: Thu Oct 11 10:45:07 2018 -0500 Detect when OpenMP uses fewer threads than requested and correct accordingly, so that we don't wait forever for nonexistent threads. Fixes #267. +commit 78a6935483409ae277c766406e175772e820b1de +Author: sraut +Date: Thu Oct 11 10:49:40 2018 +0530 + + Added comments for the change in syrk small matrix change. + + Change-Id: I958939e9953323730da49ef07d1b10e578837d82 + commit 53a9ab1c85be14dcfd2560f5b16e898e3e258797 Author: Field G. Van Zee Date: Wed Oct 10 15:11:09 2018 -0500 @@ -1369,6 +6983,17 @@ Date: Thu Oct 4 20:39:06 2018 -0500 bli_clock_min_diff() is called. Thanks to Kiran Varaganti for reporting this issue. +commit f0c3ef359f7c6c1687fb2671cb35deb346e00597 +Author: Kiran V +Date: Thu Oct 4 16:32:21 2018 +0530 + + This is a fix to floating-point exception error for BLIS SGEMM with larger matrix sizes. + BUG No: CPUPL-197 fixed by Thangaraj Santanu + The bli_clock_min_diff() function in BLIS assumed that if the time taken is greater than 1 hour then the reading must be wrong. However this is not the case in general, while the other checks such as time taken closer to zero or nsec is ofcourse valid. + gerrit review: http://git.amd.com:8080/#/c/118694/1/frame/base/bli_clock.c + + Change-Id: I9dc313d7c5fdc20684f67a516bf3237de3e0694a + commit 8bf30eb4735872388b5317883d99b775a344ce25 Author: Devangi N. Parikh Date: Wed Oct 3 22:22:29 2018 -0400 @@ -1414,6 +7039,14 @@ Date: Wed Oct 3 13:57:25 2018 -0500 long to fit on a single line. - Changed some links from http to https. +commit 80a8b3dd8034ec8bc03d31be3f9c837c3f6fc94b +Author: sraut +Date: Wed Oct 3 15:30:33 2018 +0530 + + Review comments incorporated for small TRSM. + + Change-Id: Ia64b7b2c0375cc501c2cb0be8a1af93111808cd9 + commit b8dfd82e0d1afda4ee5436662d63515a59b2dee3 Author: Devin Matthews Date: Tue Oct 2 15:37:12 2018 -0500 @@ -1498,6 +7131,22 @@ Date: Mon Oct 1 14:04:30 2018 -0500 Details: - Added language mentioning SHPC group to Introduction. +commit ee46fa3efb6e920fa6c3d0b0601007f5de31deb5 +Author: sraut +Date: Mon Oct 1 16:30:30 2018 +0530 + + Small TRSM optimization changes :- 1) single precision small trsm kernels for XAt=B case are further optimized for performance. 2) double precision small trsm kernels for AX=B and XAtB cases are implemented. 3) single precision small trsm kernels for AutX=B are implemented in intrinsics to improve the current performance. + + Change-Id: Ic9d67ae6d8522615257dde018903f049dcffa2cf + +commit 08045a6c52b6e025652c5b18eb120c0f4e61cf6f +Author: sraut +Date: Mon Oct 1 15:38:23 2018 +0530 + + Corrected the fix made for blastest level-3 failure to check m,n,k non-zero condition in bli_gemm_small.c + + Change-Id: Idaf9f2327c3127b04a2738ae8a058b83d6c57934 + commit ac18949a4b9613741b9ea8e5026d8083acef6fe4 Author: Field G. Van Zee Date: Sun Sep 30 18:54:56 2018 -0500 @@ -1562,6 +7211,23 @@ Date: Fri Sep 28 11:25:54 2018 -0500 no longer needed (issue #257). Thanks to M. Zhou and Nico Schlömer for their contributions. +commit 9814cfdf3157ef4726ee604fc895d56e8063d765 +Author: Meghana +Date: Fri Sep 28 11:02:39 2018 +0530 + + fixed blastest level-3 failure by adding ((M&N&K) != 0) to check condition in bli_gemm_small.c + + Change-Id: I85e4a32996ebb880f3c00bd293edc38f74700fe6 + +commit 86330953b14c180862deef3ccdcc6431259be27b +Merge: 7af5283d 807a6548 +Author: praveeng +Date: Fri Sep 28 10:08:06 2018 +0530 + + Resolved conflicts and modified bli_trsm_small.c + + Change-Id: I578d419cff658003e0fdd4c4cdc93145d951ce31 + commit 60b2650d7406d266feffe232c2d5692a9e3886d0 Author: Field G. Van Zee Date: Mon Sep 24 15:04:45 2018 -0500 @@ -3485,6 +9151,14 @@ Date: Mon Jun 11 12:32:54 2018 -0500 functions were the ones to inherit the 1r functionality. The kernels have now been renamed to use a _1er suffix. +commit 7af5283dcc3dded114852d6013d33134021b81aa +Author: sraut +Date: Mon Jun 11 15:00:22 2018 +0530 + + added check condition on n-dimension for XA'=B intrinsic code to process till 128 size + + Change-Id: I95d020a5ca3ea21d446b8c2e379d56e1eea18530 + commit 712de9b371a8727682352a2f52cd4880de905f0b Author: Field G. Van Zee Date: Sat Jun 9 14:36:30 2018 -0500 @@ -3659,6 +9333,22 @@ Date: Wed Jun 6 15:35:05 2018 -0500 comment (I used "sed //d" to remove the lines). This fixes the broken 'make checkblis-fast' (and 'make check') targets. +commit 695cd520e2f5eab938f66afe9fe36201ab2700c5 +Author: sraut +Date: Wed Jun 6 11:48:56 2018 +0530 + + AMD Copyright information changed to 2018 + + Change-Id: Idfd11afd5d252f8063d0158680d24bf7e2854469 + +commit df1dd24fd896821de60917b429f303bab7fd0d4b +Author: sraut +Date: Wed Jun 6 11:24:33 2018 +0530 + + small matrix trsm intrinsics optimization code for AX=B and XA'=B + + Change-Id: I90123c4d9adbd314c867995cd19dc975150b448c + commit 3f48c38164b4135515b5c752c506fdccc4480be2 Author: Field G. Van Zee Date: Tue Jun 5 16:52:35 2018 -0500 @@ -3699,6 +9389,22 @@ Date: Tue Jun 5 14:17:39 2018 +0200 Make bli_auxinfo_next_b() return b_next, not a_next (#216) +commit d4c24ea5f644eb635046e7fe249d3e8e58b4c98a +Author: sraut +Date: Tue Jun 5 15:42:59 2018 +0530 + + copyright message changed to 2018 + + Change-Id: I33c1ebda41bc7f1973ff19e3b1947bdad62b4d44 + +commit 3f1ba4e646776699ebfaa042fe24691d9e2f55d0 +Author: sraut +Date: Tue Jun 5 14:21:13 2018 +0530 + + copyright changed to 2018 + + Change-Id: Ie916c7cd6f95aedc3cab6eec3a703c9ddb333bc3 + commit bd02c4e9f7fe07487276e61507335d48c8e05f35 Author: Field G. Van Zee Date: Mon Jun 4 13:42:17 2018 -0500 @@ -5332,6 +11038,22 @@ Date: Tue Mar 20 13:54:58 2018 -0500 - Renamed some targets in the top-level Makefile to be consistent between BLAS and BLIS. +commit fc53ad6c5b2e39238b1bbbf625cc0c638b9da4e1 +Author: Nisanth M P +Date: Mon Mar 19 12:49:26 2018 +0530 + + Re-enabling the small matrix gemm optimization for target zen + + Change-Id: I13872784586984634d728cd99a00f71c3f904395 + +commit d12d34e167d7dc32732c0ed135f8065a55088106 +Author: Nisanth M P +Date: Mon Mar 19 11:34:32 2018 +0530 + + Re-enabling Zen optimized cache block sizes for config target zen + + Change-Id: I8191421b876755b31590323c66156d4a814575f1 + commit 40fa10396c0a3f9601cf49f6b6cd9922185c932e Author: Field G. Van Zee Date: Mon Mar 19 18:19:43 2018 -0500 @@ -5536,6 +11258,15 @@ Date: Sun Mar 11 16:59:50 2018 -0500 bli_dgemm_cortexa57_asm_6x8 -> bli_dgemm_armv8a_asm_6x8 Thanks to Jacob Gorm Hansen for reporting this issue. +commit 28bcea37dfcf0eb99a99da6f46de2a2830393d1d +Merge: b1ea3092 8b0475a8 +Author: praveeng +Date: Fri Mar 9 19:13:08 2018 +0530 + + Merge master code till 06_mar_2018 to amd-staging + + Change-Id: I12267e5999c92417e3715fef4f36ac2131d00f1a + commit 48da9f5805f0a49f6ad181ae2bf57b4fde8e1b0a Author: Field G. Van Zee Date: Wed Mar 7 12:54:06 2018 -0600 @@ -5607,6 +11338,14 @@ Date: Sat Mar 3 13:13:39 2018 -0600 kernels in kernels/knl/1m/. Thanks to Dave Love for reporting this issue. +commit b1ea30925dff751eced23dfa94ff578a20ea0b94 +Author: Field G. Van Zee +Date: Fri Feb 23 17:42:48 2018 -0600 + + CHANGELOG update (0.3.0) + + Change-Id: Id038b00a62de51c9818ad249651ec5dc662f4415 + commit 1ef9360b1fd0209fbeb5766f7a35402fbd080fcb Author: Field G. Van Zee Date: Thu Mar 1 14:36:39 2018 -0600 @@ -5659,17 +11398,17 @@ Date: Wed Feb 28 15:30:14 2018 -0600 bli_cgemm_zen_asm_3x8() and bli_zgemm_zen_asm_3x4(), in bli_cntx_init_zen.c. This was actually intended for 1681333. -commit d9079655c9cbb903c6761d79194a21b7c0a322bc +commit 709f8361ebc90b96b02ebe5c5ffb6fc3b1b25e58 (tag: 0.3.0) Author: Field G. Van Zee Date: Fri Feb 23 17:42:48 2018 -0600 - CHANGELOG update (0.3.0) + Version file update (0.3.0) -commit 709f8361ebc90b96b02ebe5c5ffb6fc3b1b25e58 (tag: 0.3.0) +commit d9079655c9cbb903c6761d79194a21b7c0a322bc Author: Field G. Van Zee Date: Fri Feb 23 17:42:48 2018 -0600 - Version file update (0.3.0) + CHANGELOG update (0.3.0) commit 3defc7265c12cf85e9de2d7a1f243c5e090a6f9d Author: Field G. Van Zee diff --git a/CREDITS b/CREDITS index 52efd782fb..b701598cff 100644 --- a/CREDITS +++ b/CREDITS @@ -5,81 +5,117 @@ Acknowledgements The BLIS framework was primarily authored by - Field Van Zee @fgvanzee (The University of Texas at Austin) + Field Van Zee @fgvanzee (The University of Texas at Austin) but many others have contributed code and feedback, including - Murtaza Ali (Texas Instruments) - Sajid Ali @s-sajid-ali (Northwestern University) - Erling Andersen @erling-d-andersen - Alex Arslan @ararslan - Vernon Austel (IBM, T.J. Watson Research Center) - Matthew Brett @matthew-brett (University of Birmingham) - Jed Brown @jedbrown (Argonne National Laboratory) - Robin Christ @robinchrist - Kay Dewhurst @jkd2016 (Max Planck Institute, Halle, Germany) - Jeff Diamond (Oracle) - Johannes Dieterich @iotamudelta - Krzysztof Drewniak @krzysz00 - Marat Dukhan @Maratyszcza (Google) - Victor Eijkhout @VictorEijkhout (Texas Advanced Computing Center) - Isuru Fernando @isuruf - Roman Gareev @gareevroman - Richard Goldschmidt @SuperFluffy + Sameer Agarwal @sandwichmaker (Google) + Murtaza Ali (Texas Instruments) + Sajid Ali @s-sajid-ali (Northwestern University) + Erling Andersen @erling-d-andersen + Alex Arslan @ararslan + Vernon Austel (IBM, T.J. Watson Research Center) + Satish Balay @balay (Argonne National Laboratory) + Matthew Brett @matthew-brett (University of Birmingham) + Jérémie du Boisberranger @jeremiedbb + Jed Brown @jedbrown (Argonne National Laboratory) + Robin Christ @robinchrist + Dilyn Corner @dilyn-corner + Mat Cross @matcross (NAG) + @decandia50 + Kay Dewhurst @jkd2016 (Max Planck Institute, Halle, Germany) + Jeff Diamond (Oracle) + Johannes Dieterich @iotamudelta + Krzysztof Drewniak @krzysz00 + Marat Dukhan @Maratyszcza (Google) + Victor Eijkhout @VictorEijkhout (Texas Advanced Computing Center) + Evgeny Epifanovsky @epifanovsky (Q-Chem) + Isuru Fernando @isuruf + Roman Gareev @gareevroman + Richard Goldschmidt @SuperFluffy Chris Goodyer - John Gunnels @jagunnels (IBM, T.J. Watson Research Center) - Ali Emre Gülcü @Lephar - Jeff Hammond @jeffhammond (Intel) - Jacob Gorm Hansen @jacobgorm - Jean-Michel Hautbois @jhautbois - Ian Henriksen @insertinterestingnamehere (The University of Texas at Austin) - Minh Quan Ho @hominhquan - Matthew Honnibal @honnibal - Stefan Husmann @stefanhusmann - Francisco Igual @figual (Universidad Complutense de Madrid) - Tony Kelman @tkelman - Lee Killough @leekillough (Cray) - Mike Kistler @mkistler (IBM, Austin Research Laboratory) - Michael Lehn @michael-lehn - Dave Love @loveshack - Tze Meng Low (The University of Texas at Austin) - Ye Luo @ye-luo (Argonne National Laboratory) - Ricardo Magana @magania (Hewlett Packard Enterprise) - Bryan Marker @bamarker (The University of Texas at Austin) - Devin Matthews @devinamatthews (The University of Texas at Austin) - Stefanos Mavros @smavros - Nisanth Padinharepatt (AMD) - Devangi Parikh @dnparikh (The University of Texas at Austin) - Elmar Peise @elmar-peise (RWTH-Aachen) - Clément Pernet @ClementPernet + John Gunnels @jagunnels (IBM, T.J. Watson Research Center) + Ali Emre Gülcü @Lephar + Jeff Hammond @jeffhammond (Intel) + Jacob Gorm Hansen @jacobgorm + Shivaprashanth H (Global Edge) + Jean-Michel Hautbois @jhautbois + Ian Henriksen @insertinterestingnamehere (The University of Texas at Austin) + Greg Henry (Intel) + Minh Quan Ho @hominhquan + Matthew Honnibal @honnibal + Stefan Husmann @stefanhusmann + Francisco Igual @figual (Universidad Complutense de Madrid) + Tony Kelman @tkelman + Lee Killough @leekillough (Cray) + Mike Kistler @mkistler (IBM, Austin Research Laboratory) + Ivan Korostelev @ivan23kor (University of Alberta) + Kyungmin Lee @kyungminlee (Ohio State University) + Michael Lehn @michael-lehn + Shmuel Levine @ShmuelLevine + @lschork2 + Dave Love @loveshack + Tze Meng Low (The University of Texas at Austin) + Ye Luo @ye-luo (Argonne National Laboratory) + Ricardo Magana @magania (Hewlett Packard Enterprise) + Madan mohan Manokar @madanm3 (AMD) + Giorgos Margaritis + Bryan Marker @bamarker (The University of Texas at Austin) + Simon Lukas Märtens @ACSimon33 (RWTH Aachen University) + Devin Matthews @devinamatthews (The University of Texas at Austin) + Stefanos Mavros @smavros + Mithun Mohan @MithunMohanKadavil (AMD) + Ilknur Mustafazade @Runkli + @nagsingh + Bhaskar Nallani @BhaskarNallani (AMD) + Stepan Nassyr @stepannassyr (Jülich Supercomputing Centre) + Nisanth Padinharepatt (AMD) + Ajay Panyala @ajaypanyala + Marc-Antoine Parent @maparent (Conversence) + Devangi Parikh @dnparikh (The University of Texas at Austin) + Elmar Peise @elmar-peise (RWTH-Aachen) + Clément Pernet @ClementPernet Ilya Polkovnichenko - Jack Poulson @poulson (Stanford) - Mathieu Poumeyrol @kali - @qnerd - Michael Rader @mrader1248 - Pradeep Rao @pradeeptrgit (AMD) + Jack Poulson @poulson (Stanford) + Mathieu Poumeyrol @kali + Christos Psarras @ChrisPsa (RWTH Aachen University) + @pkubaj + @qnerd + Michael Rader @mrader1248 + Pradeep Rao @pradeeptrgit (AMD) Aleksei Rechinskii - Karl Rupp @karlrupp - Martin Schatz (The University of Texas at Austin) - Nico Schlömer @nschloe + Karl Rupp @karlrupp + Martin Schatz (The University of Texas at Austin) + Nico Schlömer @nschloe Rene Sitt - Tony Skjellum @tonyskjellum (The University of Tennessee at Chattanooga) - Mikhail Smelyanskiy (Intel, Parallel Computing Lab) - Nathaniel Smith @njsmith - Shaden Smith @ShadenSmith - Tyler Smith @tlrmchlsmth (The University of Texas at Austin) - Paul Springer @springer13 (RWTH-Aachen) + Tony Skjellum @tonyskjellum (The University of Tennessee at Chattanooga) + Mikhail Smelyanskiy (Intel, Parallel Computing Lab) + Nathaniel Smith @njsmith + Shaden Smith @ShadenSmith + Tyler Smith @tlrmchlsmth (The University of Texas at Austin) + Snehith @ArcadioN09 + Paul Springer @springer13 (RWTH Aachen University) + Adam J. Stewart @adamjstewart (University of Illinois at Urbana-Champaign) Vladimir Sukarev - Santanu Thangaraj (AMD) - Rhys Ulerich @RhysU (The University of Texas at Austin) - Robert van de Geijn @rvdg (The University of Texas at Austin) - Kiran Varaganti @kvaragan (AMD) - Natalia Vassilieva (Hewlett Packard Enterprise) - Zhang Xianyi @xianyi (Chinese Academy of Sciences) - Benda Xu @heroxbd - Costas Yamin @cosstas - Chenhan Yu @ChenhanYu (The University of Texas at Austin) - M. Zhou @cdluminate + Chengguo Sun @chengguosun + Santanu Thangaraj (AMD) + Nicholai Tukanov @nicholaiTukanov (The University of Texas at Austin) + Rhys Ulerich @RhysU (The University of Texas at Austin) + Robert van de Geijn @rvdg (The University of Texas at Austin) + Meghana Vankadari @Meghana-vankadari (AMD) + Kiran Varaganti @kvaragan (AMD) + Natalia Vassilieva (Hewlett Packard Enterprise) + @h-vetinari + Andrew Wildman @awild82 (University of Washington) + Zhang Xianyi @xianyi (Chinese Academy of Sciences) + Benda Xu @heroxbd + Guodong Xu @docularxu (Linaro.org) + RuQing Xu @xrq-phys (The University of Tokyo) + Costas Yamin @cosstas + Chenhan Yu @ChenhanYu (The University of Texas at Austin) + Roman Yurchak @rth (Symerio) + Stefano Zampini @stefanozampini + M. Zhou @cdluminate BLIS's development was partially funded by grants from industry partners, including diff --git a/LICENSE b/LICENSE index 1c68453758..b9cde54b85 100644 --- a/LICENSE +++ b/LICENSE @@ -15,7 +15,7 @@ copyright info. All parties provide their portions of the code under the Copyright (C) 2018, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP -Copyright (C) 2018, Advanced Micro Devices, Inc. +Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/Makefile b/Makefile index 2d31feeb04..5605dd8fc3 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,11 @@ # # -# BLIS +# BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2022, Advanced Micro Devices, Inc. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -36,7 +37,7 @@ # Makefile # # Field G. Van Zee -# +# # Top-level makefile for libflame linear algebra library. # # @@ -114,6 +115,7 @@ BASE_OBJ_CONFIG_PATH := $(BASE_OBJ_PATH)/$(CONFIG_DIR) BASE_OBJ_FRAME_PATH := $(BASE_OBJ_PATH)/$(FRAME_DIR) BASE_OBJ_REFKERN_PATH := $(BASE_OBJ_PATH)/$(REFKERN_DIR) BASE_OBJ_KERNELS_PATH := $(BASE_OBJ_PATH)/$(KERNELS_DIR) +BASE_OBJ_ADDON_PATH := $(BASE_OBJ_PATH)/$(ADDON_DIR) BASE_OBJ_SANDBOX_PATH := $(BASE_OBJ_PATH)/$(SANDBOX_DIR) # --- Define install target names for static libraries --- @@ -166,6 +168,7 @@ MK_INCL_DIR_INST := $(INSTALL_INCDIR)/blis # Set the path to the subdirectory of the share installation directory. MK_SHARE_DIR_INST := $(INSTALL_SHAREDIR)/blis +PC_SHARE_DIR_INST := $(INSTALL_SHAREDIR)/pkgconfig # @@ -209,15 +212,42 @@ MK_REFKERN_OBJS := $(foreach arch, $(CONFIG_LIST), \ # Generate object file paths for all of the portable framework source code. MK_FRAME_OBJS := $(call gen-obj-paths-from-src,$(FRAME_SRC_SUFS),$(MK_FRAME_SRC),$(FRAME_PATH),$(BASE_OBJ_FRAME_PATH)) +# Generate object file paths for the addon source code. If one or more addons +# were not enabled a configure-time, this variable will we empty. +MK_ADDON_OBJS := $(call gen-obj-paths-from-src,$(ADDON_SRC_SUFS),$(MK_ADDON_SRC),$(ADDON_PATH),$(BASE_OBJ_ADDON_PATH)) + # Generate object file paths for the sandbox source code. If a sandbox was not # enabled a configure-time, this variable will we empty. MK_SANDBOX_OBJS := $(call gen-obj-paths-from-src,$(SANDBOX_SRC_SUFS),$(MK_SANDBOX_SRC),$(SANDBOX_PATH),$(BASE_OBJ_SANDBOX_PATH)) +# AMD has chosen to introduce AOCL-specific optimizations to certain BLIS +# framework files that are otherwise intended to remain generic. Upstream +# developers of vanilla BLIS have agreed to integrate some of these +# optimizations, but in a way that keeps the AOCL-specific code segregated +# in separate files containing the suffix '_amd'. For example, the BLAS +# compatibility layer in vanilla BLIS contains a generic file named +# 'bla_gemm.c'. AMD's version of this file is named 'bla_gemm_amd.c'. +# Only one or the other is ever built and included in libblis. Currently, +# these files are chosen automatically based on the target configuration. +ifeq ($(ENABLE_AMD_FRAME_TWEAKS),yes) +# Build is being done for AMD platforms; remove the objects which DO NOT have +# an "_amd" suffix. +MK_FRAME_AMD_OBJS := $(filter $(BASE_OBJ_FRAME_PATH)/%amd.o, $(MK_FRAME_OBJS)) +FILES_TO_REMOVE := $(subst _amd.o,.o, $(MK_FRAME_AMD_OBJS)) +MK_FRAME_OBJS := $(filter-out $(FILES_TO_REMOVE), $(MK_FRAME_OBJS)) +else +# Build is being done for non-AMD platforms; remove the objects which DO have +# an "_amd" suffix. +MK_FRAME_AMD_OBJS := $(filter $(BASE_OBJ_FRAME_PATH)/%amd.o, $(MK_FRAME_OBJS)) +MK_FRAME_OBJS := $(filter-out $(MK_FRAME_AMD_OBJS), $(MK_FRAME_OBJS)) +endif + # Combine all of the object files into some readily-accessible variables. MK_BLIS_OBJS := $(MK_CONFIG_OBJS) \ $(MK_KERNELS_OBJS) \ $(MK_REFKERN_OBJS) \ $(MK_FRAME_OBJS) \ + $(MK_ADDON_OBJS) \ $(MK_SANDBOX_OBJS) # Optionally filter out the BLAS and CBLAS compatibility layer object files. @@ -249,6 +279,12 @@ ifeq ($(MK_ENABLE_CBLAS),yes) HEADERS_TO_INSTALL += $(CBLAS_H_FLAT) endif +# If requested, include AMD's C++ template header files in the list of headers +# to install. +ifeq ($(INSTALL_HH),yes) +HEADERS_TO_INSTALL += $(wildcard $(VEND_CPP_PATH)/*.hh) +endif + # @@ -259,6 +295,8 @@ endif FRAGS_TO_INSTALL := $(CONFIG_MK_FILE) \ $(COMMON_MK_FILE) +PC_IN_FILE := blis.pc.in +PC_OUT_FILE := blis.pc # @@ -385,23 +423,22 @@ ifeq ($(IS_CONFIGURED),yes) # named with three .so version numbers. UNINSTALL_OLD_LIBS := -UNINSTALL_OLD_LIBS += $(shell $(FIND) $(INSTALL_LIBDIR)/ -name "$(LIBBLIS_SO).?.?.?" 2> /dev/null | $(GREP) -v "$(LIBBLIS).$(LIBBLIS_SO_MMB_EXT)") +UNINSTALL_OLD_LIBS += $(filter-out $(INSTALL_LIBDIR)/$(LIBBLIS).$(LIBBLIS_SO_MMB_EXT),$(wildcard $(INSTALL_LIBDIR)/$(LIBBLIS_SO).?.?.?)) # These shell commands gather the filepaths to any library symlink in the # current LIBDIR that might be left over from an old installation. We start # with symlinks named using the .so major version number. -UNINSTALL_OLD_SYML := $(shell $(FIND) $(INSTALL_LIBDIR)/ -name "$(LIBBLIS_SO).?" 2> /dev/null | $(GREP) -v "$(LIBBLIS_SO).$(SO_MAJOR)") +UNINSTALL_OLD_SYML := $(filter-out $(INSTALL_LIBDIR)/$(LIBBLIS_SO).$(SO_MAJOR),$(wildcard $(INSTALL_LIBDIR)/$(LIBBLIS_SO).?)) # We also prepare to uninstall older-style symlinks whose names contain the # BLIS version number and configuration family. -UNINSTALL_OLD_SYML += $(shell $(FIND) $(INSTALL_LIBDIR)/ -name "$(LIBBLIS)-*.a" 2> /dev/null | $(GREP) -v "$(LIBBLIS)-$(VERS_CONF).a") - -UNINSTALL_OLD_SYML += $(shell $(FIND) $(INSTALL_LIBDIR)/ -name "$(LIBBLIS)-*.$(SHLIB_EXT)" 2> /dev/null | $(GREP) -v "$(LIBBLIS)-$(VERS_CONF).$(SHLIB_EXT)") +UNINSTALL_OLD_SYML += $(wildcard $(INSTALL_LIBDIR)/$(LIBBLIS)-*.a) +UNINSTALL_OLD_SYML += $(wildcard $(INSTALL_LIBDIR)/$(LIBBLIS)-*.$(SHLIB_EXT)) # This shell command grabs all files named "*.h" that are not blis.h or cblas.h # in the installation directory. We consider this set of headers to be "old" and # eligible for removal upon running of the uninstall-old-headers target. -UNINSTALL_OLD_HEADERS := $(shell $(FIND) $(INSTALL_INCDIR)/blis/ -name "*.h" 2> /dev/null | $(GREP) -v "$(BLIS_H)" | $(GREP) -v "$(CBLAS_H)") +UNINSTALL_OLD_HEADERS := $(filter-out $(BLIS_H),$(filter-out $(CBLAS_H),$(wildcard $(INSTALL_INCDIR)/blis/*.h))) endif # IS_CONFIGURED @@ -453,7 +490,7 @@ endif flat-header: check-env $(BLIS_H_FLAT) -$(BLIS_H_FLAT): $(FRAME_H99_FILES) +$(BLIS_H_FLAT): $(ALL_H99_FILES) ifeq ($(ENABLE_VERBOSE),yes) $(FLATTEN_H) -c -v1 $(BLIS_H_SRC_PATH) $@ "./$(INCLUDE_DIR)" "$(ALL_H99_DIRPATHS)" else @@ -541,6 +578,28 @@ else endif endef +# first argument: a configuration name from the union of config_list and +# config_name, used to look up the CFLAGS to use during compilation. +define make-c99-addon-rule +$(BASE_OBJ_ADDON_PATH)/%.o: $(ADDON_PATH)/%.$(2) $(BLIS_H_FLAT) $(ADDON_H99_FILES) $(MAKE_DEFS_MK_PATHS) +ifeq ($(ENABLE_VERBOSE),yes) + $(CC) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ +else + @echo "Compiling $$@" $(call get-addon-c99text-for,$(1)) + @$(CC) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ +endif +endef + +define make-cxx-addon-rule +$(BASE_OBJ_ADDON_PATH)/%.o: $(ADDON_PATH)/%.$(2) $(BLIS_H_FLAT) $(ADDON_HXX_FILES) $(MAKE_DEFS_MK_PATHS) +ifeq ($(ENABLE_VERBOSE),yes) + $(CXX) $(call get-addon-cxxflags-for,$(1)) -c $$< -o $$@ +else + @echo "Compiling $$@" $(call get-addon-cxxtext-for,$(1)) + @$(CXX) $(call get-addon-cxxflags-for,$(1)) -c $$< -o $$@ +endif +endef + # first argument: a configuration name from the union of config_list and # config_name, used to look up the CFLAGS to use during compilation. define make-c99-sandbox-rule @@ -593,6 +652,16 @@ $(foreach conf, $(CONFIG_LIST), $(eval $(call make-refkern-rule,$(conf)))) $(foreach suf, $(KERNELS_SRC_SUFS), \ $(foreach kset, $(KERNEL_LIST), $(eval $(call make-kernels-rule,$(kset),$(call get-config-for-kset,$(kset)),$(suf))))) +# Instantiate the build rule for C addon files. Use the CFLAGS for the +# configuration family. +$(foreach suf, $(ADDON_C99_SUFS), \ +$(foreach conf, $(CONFIG_NAME), $(eval $(call make-c99-addon-rule,$(conf),$(suf))))) + +# Instantiate the build rule for C++ addon files. Use the CFLAGS for the +# configuration family. +$(foreach suf, $(ADDON_CXX_SUFS), \ +$(foreach conf, $(CONFIG_NAME), $(eval $(call make-cxx-addon-rule,$(conf),$(suf))))) + # Instantiate the build rule for C sandbox files. Use the CFLAGS for the # configuration family. $(foreach suf, $(SANDBOX_C99_SUFS), \ @@ -646,7 +715,7 @@ ifeq ($(ARG_MAX_HACK),yes) $(LINKER) $(SOFLAGS) -o $(LIBBLIS_SO_OUTPUT_NAME) @$@.in $(LDFLAGS) $(RM_F) $@.in else - $(LINKER) $(SOFLAGS) -o $(LIBBLIS_SO_OUTPUT_NAME) $? $(LDFLAGS) + $(LINKER) $(SOFLAGS) -o $(LIBBLIS_SO_OUTPUT_NAME) $^ $(LDFLAGS) endif else # ifeq ($(ENABLE_VERBOSE),no) ifeq ($(ARG_MAX_HACK),yes) @@ -656,7 +725,7 @@ ifeq ($(ARG_MAX_HACK),yes) @$(RM_F) $@.in else @echo "Dynamically linking $@" - @$(LINKER) $(SOFLAGS) -o $(LIBBLIS_SO_OUTPUT_NAME) $? $(LDFLAGS) + @$(LINKER) $(SOFLAGS) -o $(LIBBLIS_SO_OUTPUT_NAME) $^ $(LDFLAGS) endif endif @@ -680,7 +749,7 @@ endif # --- BLAS test suite rules --- -testblas: blastest-run +testblas: blastest-run blastest-f2c: check-env $(BLASTEST_F2C_LIB) @@ -689,7 +758,7 @@ blastest-bin: check-env blastest-f2c $(BLASTEST_DRV_BIN_PATHS) blastest-run: $(BLASTEST_DRV_BINS_R) # f2c object file rule. -$(BASE_OBJ_BLASTEST_PATH)/%.o: $(BLASTEST_F2C_SRC_PATH)/%.c +$(BASE_OBJ_BLASTEST_PATH)/%.o: $(BLASTEST_F2C_SRC_PATH)/%.c $(BLIS_H_FLAT) ifeq ($(ENABLE_VERBOSE),yes) $(CC) $(call get-user-cflags-for,$(CONFIG_NAME)) $(BLAT_CFLAGS) -c $< -o $@ else @@ -698,7 +767,7 @@ else endif # driver object file rule. -$(BASE_OBJ_BLASTEST_PATH)/%.o: $(BLASTEST_DRV_SRC_PATH)/%.c +$(BASE_OBJ_BLASTEST_PATH)/%.o: $(BLASTEST_DRV_SRC_PATH)/%.c $(BLIS_H_FLAT) ifeq ($(ENABLE_VERBOSE),yes) $(CC) $(call get-user-cflags-for,$(CONFIG_NAME)) $(BLAT_CFLAGS) -c $< -o $@ else @@ -717,19 +786,13 @@ else @$(RANLIB) $@ endif -# first argument: the base name of the BLAS test driver. -define make-blat-rule -$(BASE_OBJ_BLASTEST_PATH)/$(1).x: $(BASE_OBJ_BLASTEST_PATH)/$(1).o $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) +$(BASE_OBJ_BLASTEST_PATH)/%.x: $(BASE_OBJ_BLASTEST_PATH)/%.o $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) ifeq ($(ENABLE_VERBOSE),yes) - $(LINKER) $(BASE_OBJ_BLASTEST_PATH)/$(1).o $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $$@ + $(LINKER) $< $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ else - @echo "Linking $$(@F) against '$(notdir $(BLASTEST_F2C_LIB)) $(LIBBLIS_LINK) $(LDFLAGS)'" - @$(LINKER) $(BASE_OBJ_BLASTEST_PATH)/$(1).o $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $$@ + @echo "Linking $@ against '$(notdir $(BLASTEST_F2C_LIB)) $(LIBBLIS_LINK) "$(LDFLAGS)"'" + @$(LINKER) $< $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ endif -endef - -# Instantiate the rule above for each driver file. -$(foreach name, $(BLASTEST_DRV_BASES), $(eval $(call make-blat-rule,$(name)))) # A rule to run ?blat1.x driver files. define make-run-blat1-rule @@ -785,7 +848,7 @@ testsuite: testsuite-run testsuite-bin: check-env $(TESTSUITE_BIN) # Object file rule. -$(BASE_OBJ_TESTSUITE_PATH)/%.o: $(TESTSUITE_SRC_PATH)/%.c +$(BASE_OBJ_TESTSUITE_PATH)/%.o: $(TESTSUITE_SRC_PATH)/%.c $(BLIS_H_FLAT) ifeq ($(ENABLE_VERBOSE),yes) $(CC) $(call get-user-cflags-for,$(CONFIG_NAME)) -c $< -o $@ else @@ -798,7 +861,7 @@ $(TESTSUITE_BIN): $(MK_TESTSUITE_OBJS) $(LIBBLIS_LINK) ifeq ($(ENABLE_VERBOSE),yes) $(LINKER) $(MK_TESTSUITE_OBJS) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ else - @echo "Linking $@ against '$(LIBBLIS_LINK) $(LDFLAGS)'" + @echo "Linking $@ against '$(LIBBLIS_LINK) "$(LDFLAGS)"'" @$(LINKER) $(MK_TESTSUITE_OBJS) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ endif @@ -893,6 +956,19 @@ else @- $(TESTSUITE_CHECK_PATH) $(TESTSUITE_OUT_FILE) endif + +# --- AMD's C++ template header test rules --- + +# NOTE: The targets below won't work as intended for an out-of-tree build, +# and so it's disabled for now. + +#testcpp: testvendcpp + +# Recursively run the test for AMD's C++ template header. +#testvendcpp: +# $(MAKE) -C $(VEND_TESTCPP_PATH) + + # --- Install header rules --- install-headers: check-env $(MK_INCL_DIR_INST) @@ -910,7 +986,7 @@ endif # --- Install share rules --- -install-share: check-env $(MK_SHARE_DIR_INST) +install-share: check-env $(MK_SHARE_DIR_INST) $(PC_SHARE_DIR_INST) $(MK_SHARE_DIR_INST): $(FRAGS_TO_INSTALL) $(CONFIG_MK_FILE) ifeq ($(ENABLE_VERBOSE),yes) @@ -929,6 +1005,20 @@ else $(@)/$(CONFIG_DIR)/$(CONFIG_NAME)/ endif +$(PC_SHARE_DIR_INST): $(PC_IN_FILE) + $(MKDIR) $(@) +ifeq ($(ENABLE_VERBOSE),no) + @echo "Installing $(PC_OUT_FILE) into $(@)/" +endif + $(shell cat "$(PC_IN_FILE)" \ + | sed -e "s#@PACKAGE_VERSION@#$(VERSION)#g" \ + | sed -e "s#@prefix@#$(prefix)#g" \ + | sed -e "s#@exec_prefix@#$(exec_prefix)#g" \ + | sed -e "s#@libdir@#$(libdir)#g" \ + | sed -e "s#@includedir@#$(includedir)#g" \ + | sed -e "s#@LDFLAGS@#$(LDFLAGS)#g" \ + > "$(PC_OUT_FILE)" ) + $(INSTALL) -m 0644 $(PC_OUT_FILE) $(@) # --- Install library rules --- @@ -954,11 +1044,11 @@ ifeq ($(IS_WIN),no) $(INSTALL_LIBDIR)/%.$(LIBBLIS_SO_MMB_EXT): $(BASE_LIB_PATH)/%.$(SHLIB_EXT) $(CONFIG_MK_FILE) ifeq ($(ENABLE_VERBOSE),yes) $(MKDIR) $(@D) - $(INSTALL) -m 0644 $< $@ + $(INSTALL) -m 0755 $< $@ else @echo "Installing $(@F) into $(INSTALL_LIBDIR)/" @$(MKDIR) $(@D) - @$(INSTALL) -m 0644 $< $@ + @$(INSTALL) -m 0755 $< $@ endif else # ifeq ($(IS_WIN),yes) @@ -1020,23 +1110,24 @@ endif # ifeq ($(IS_WIN),no) # --- Query current configuration --- showconfig: check-env - @echo "configuration family: $(CONFIG_NAME)" - @echo "sub-configurations: $(CONFIG_LIST)" - @echo "requisite kernels: $(KERNEL_LIST)" - @echo "kernel-to-config map: $(KCONFIG_MAP)" - @echo "-----------------------" - @echo "BLIS version string: $(VERSION)" - @echo ".so major version: $(SO_MAJOR)" - @echo ".so minor.build vers: $(SO_MINORB)" - @echo "install libdir: $(INSTALL_LIBDIR)" - @echo "install includedir: $(INSTALL_INCDIR)" - @echo "debugging status: $(DEBUG_TYPE)" - @echo "multithreading status: $(THREADING_MODEL)" - @echo "enable BLAS API? $(MK_ENABLE_BLAS)" - @echo "enable CBLAS API? $(MK_ENABLE_CBLAS)" - @echo "build static library? $(MK_ENABLE_STATIC)" - @echo "build shared library? $(MK_ENABLE_SHARED)" - @echo "ARG_MAX hack enabled? $(ARG_MAX_HACK)" + @echo "configuration family: $(CONFIG_NAME)" + @echo "sub-configurations: $(CONFIG_LIST)" + @echo "requisite kernels sets: $(KERNEL_LIST)" + @echo "kernel-to-config map: $(KCONFIG_MAP)" + @echo "-------------------------" + @echo "BLIS version string: $(VERSION)" + @echo ".so major version: $(SO_MAJOR)" + @echo ".so minor.build vers: $(SO_MINORB)" + @echo "install libdir: $(INSTALL_LIBDIR)" + @echo "install includedir: $(INSTALL_INCDIR)" + @echo "install sharedir: $(INSTALL_SHAREDIR)" + @echo "debugging status: $(DEBUG_TYPE)" + @echo "multithreading status: $(THREADING_MODEL)" + @echo "enable BLAS API? $(MK_ENABLE_BLAS)" + @echo "enable CBLAS API? $(MK_ENABLE_CBLAS)" + @echo "build static library? $(MK_ENABLE_STATIC)" + @echo "build shared library? $(MK_ENABLE_SHARED)" + @echo "ARG_MAX hack enabled? $(ARG_MAX_HACK)" # --- Clean rules --- @@ -1048,20 +1139,27 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(FIND) $(FRAME_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) - $(FIND) $(REFKERN_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) - $(FIND) $(KERNELS_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) +ifneq ($(ADDON_LIST),) + - $(FIND) $(ADDON_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) +endif ifneq ($(SANDBOX),) - $(FIND) $(SANDBOX_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) endif else - @echo "Removing makefile fragments from $(CONFIG_FRAG_PATH)." + @echo "Removing makefile fragments from $(CONFIG_FRAG_PATH)" @- $(FIND) $(CONFIG_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) - @echo "Removing makefile fragments from $(FRAME_FRAG_PATH)." + @echo "Removing makefile fragments from $(FRAME_FRAG_PATH)" @- $(FIND) $(FRAME_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) - @echo "Removing makefile fragments from $(REFKERN_FRAG_PATH)." + @echo "Removing makefile fragments from $(REFKERN_FRAG_PATH)" @- $(FIND) $(REFKERN_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) - @echo "Removing makefile fragments from $(KERNELS_FRAG_PATH)." + @echo "Removing makefile fragments from $(KERNELS_FRAG_PATH)" @- $(FIND) $(KERNELS_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) +ifneq ($(ADDON_LIST),) + @echo "Removing makefile fragments from $(ADDON_FRAG_PATH)" + @- $(FIND) $(ADDON_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) +endif ifneq ($(SANDBOX),) - @echo "Removing makefile fragments from $(SANDBOX_FRAG_PATH)." + @echo "Removing makefile fragments from $(SANDBOX_FRAG_PATH)" @- $(FIND) $(SANDBOX_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) endif endif @@ -1073,7 +1171,7 @@ ifeq ($(ENABLE_VERBOSE),yes) $(RM_F) $(BLIS_H_FLAT) $(RM_F) $(CBLAS_H_FLAT) else - @echo "Removing flattened header files from $(BASE_INC_PATH)." + @echo "Removing flattened header files from $(BASE_INC_PATH)" @$(RM_F) $(BLIS_H_FLAT) @$(RM_F) $(CBLAS_H_FLAT) endif @@ -1086,9 +1184,9 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(LIBBLIS_A_PATH) - $(RM_F) $(LIBBLIS_SO_PATH) else - @echo "Removing object files from $(BASE_OBJ_PATH)." + @echo "Removing object files from $(BASE_OBJ_PATH)" @- $(FIND) $(BASE_OBJ_PATH) -name "*.o" | $(XARGS) $(RM_F) - @echo "Removing libraries from $(BASE_LIB_PATH)." + @echo "Removing libraries from $(BASE_LIB_PATH)" @- $(RM_F) $(LIBBLIS_A_PATH) @- $(RM_F) $(LIBBLIS_SO_PATH) endif @@ -1110,13 +1208,13 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(BLASTEST_DRV_BIN_PATHS) - $(RM_F) $(addprefix out.,$(BLASTEST_DRV_BASES)) else - @echo "Removing object files from $(BASE_OBJ_BLASTEST_PATH)." + @echo "Removing object files from $(BASE_OBJ_BLASTEST_PATH)" @- $(RM_F) $(BLASTEST_F2C_OBJS) $(BLASTEST_DRV_OBJS) - @echo "Removing libf2c.a from $(BASE_OBJ_BLASTEST_PATH)." + @echo "Removing libf2c.a from $(BASE_OBJ_BLASTEST_PATH)" @- $(RM_F) $(BLASTEST_F2C_LIB) - @echo "Removing binaries from $(BASE_OBJ_BLASTEST_PATH)." + @echo "Removing binaries from $(BASE_OBJ_BLASTEST_PATH)" @- $(RM_F) $(BLASTEST_DRV_BIN_PATHS) - @echo "Removing driver output files 'out.*'." + @echo "Removing driver output files 'out.*'" @- $(RM_F) $(addprefix out.,$(BLASTEST_DRV_BASES)) endif # ENABLE_VERBOSE endif # IS_CONFIGURED @@ -1129,13 +1227,13 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(BLASTEST_DIR)/$(BLASTEST_F2C_LIB_NAME) - $(RM_F) $(addprefix $(BLASTEST_DIR)/out.,$(BLASTEST_DRV_BASES)) else - @echo "Removing object files from ./$(BLASTEST_DIR)/$(OBJ_DIR)." + @echo "Removing object files from ./$(BLASTEST_DIR)/$(OBJ_DIR)" @- $(FIND) $(BLASTEST_DIR)/$(OBJ_DIR) -name "*.o" | $(XARGS) $(RM_F) - @echo "Removing libf2c.a from ./$(BLASTEST_DIR)." + @echo "Removing libf2c.a from ./$(BLASTEST_DIR)" @- $(RM_F) $(BLASTEST_DIR)/$(BLASTEST_F2C_LIB_NAME) - @echo "Removing binaries from ./$(BLASTEST_DIR)." + @echo "Removing binaries from ./$(BLASTEST_DIR)" @- $(FIND) $(BLASTEST_DIR) -name "*.x" | $(XARGS) $(RM_F) - @echo "Removing driver output files 'out.*' from ./$(BLASTEST_DIR)." + @echo "Removing driver output files 'out.*' from ./$(BLASTEST_DIR)" @- $(RM_F) $(addprefix $(BLASTEST_DIR)/out.,$(BLASTEST_DRV_BASES)) endif # ENABLE_VERBOSE endif # IS_CONFIGURED @@ -1153,11 +1251,11 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(TESTSUITE_BIN) - $(RM_F) $(TESTSUITE_OUT_FILE) else - @echo "Removing object files from $(BASE_OBJ_TESTSUITE_PATH)." + @echo "Removing object files from $(BASE_OBJ_TESTSUITE_PATH)" @- $(RM_F) $(MK_TESTSUITE_OBJS) - @echo "Removing binary $(TESTSUITE_BIN)." + @echo "Removing binary $(TESTSUITE_BIN)" @- $(RM_F) $(TESTSUITE_BIN) - @echo "Removing $(TESTSUITE_OUT_FILE)." + @echo "Removing $(TESTSUITE_OUT_FILE)" @- $(RM_F) $(TESTSUITE_OUT_FILE) endif # ENABLE_VERBOSE endif # IS_CONFIGURED @@ -1167,32 +1265,40 @@ ifeq ($(IS_CONFIGURED),yes) ifeq ($(ENABLE_VERBOSE),yes) - $(FIND) $(TESTSUITE_DIR)/$(OBJ_DIR) -name "*.o" | $(XARGS) $(RM_F) - $(RM_F) $(TESTSUITE_DIR)/$(TESTSUITE_BIN) +# - $(MAKE) -C $(VEND_TESTCPP_DIR) clean else - @echo "Removing object files from $(TESTSUITE_DIR)/$(OBJ_DIR)." + @echo "Removing object files from $(TESTSUITE_DIR)/$(OBJ_DIR)" @- $(FIND) $(TESTSUITE_DIR)/$(OBJ_DIR) -name "*.o" | $(XARGS) $(RM_F) - @echo "Removing binary $(TESTSUITE_DIR)/$(TESTSUITE_BIN)." + @echo "Removing binary $(TESTSUITE_DIR)/$(TESTSUITE_BIN)" @- $(RM_F) $(TESTSUITE_DIR)/$(TESTSUITE_BIN) +# @$(MAKE) -C $(VEND_TESTCPP_DIR) clean endif # ENABLE_VERBOSE endif # IS_CONFIGURED distclean: cleanmk cleanh cleanlib cleantest ifeq ($(IS_CONFIGURED),yes) ifeq ($(ENABLE_VERBOSE),yes) + - $(RM_F) $(BLIS_ADDON_H) - $(RM_F) $(BLIS_CONFIG_H) - $(RM_F) $(CONFIG_MK_FILE) + - $(RM_F) $(PC_OUT_FILE) - $(RM_RF) $(OBJ_DIR) - $(RM_RF) $(LIB_DIR) - $(RM_RF) $(INCLUDE_DIR) else - @echo "Removing $(BLIS_CONFIG_H)." + @echo "Removing $(BLIS_ADDON_H)" + @$(RM_F) $(BLIS_ADDON_H) + @echo "Removing $(BLIS_CONFIG_H)" @$(RM_F) $(BLIS_CONFIG_H) - @echo "Removing $(CONFIG_MK_FILE)." + @echo "Removing $(CONFIG_MK_FILE)" @- $(RM_F) $(CONFIG_MK_FILE) - @echo "Removing $(OBJ_DIR)." + @echo "Removing $(PC_OUT_FILE)" + @- $(RM_F) $(PC_OUT_FILE) + @echo "Removing $(OBJ_DIR)" @- $(RM_RF) $(OBJ_DIR) - @echo "Removing $(LIB_DIR)." + @echo "Removing $(LIB_DIR)" @- $(RM_RF) $(LIB_DIR) - @echo "Removing $(INCLUDE_DIR)." + @echo "Removing $(INCLUDE_DIR)" @- $(RM_RF) $(INCLUDE_DIR) endif endif @@ -1201,8 +1307,8 @@ endif # --- CHANGELOG rules --- changelog: - @echo "Updating '$(DIST_PATH)/$(CHANGELOG)' via '$(GIT_LOG)'." - @$(GIT_LOG) > $(DIST_PATH)/$(CHANGELOG) + @echo "Updating '$(DIST_PATH)/$(CHANGELOG)' via '$(GIT_LOG)'" + @$(GIT_LOG) > $(DIST_PATH)/$(CHANGELOG) # --- Uninstall rules --- @@ -1216,7 +1322,7 @@ uninstall-libs: check-env ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(MK_LIBS_INST) else - @echo "Uninstalling libraries $(notdir $(MK_LIBS_INST)) from $(dir $(firstword $(MK_LIBS_INST)))." + @echo "Uninstalling libraries $(notdir $(MK_LIBS_INST)) from $(dir $(firstword $(MK_LIBS_INST)))" @- $(RM_F) $(MK_LIBS_INST) endif @@ -1224,7 +1330,7 @@ uninstall-lib-symlinks: check-env ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(MK_LIBS_SYML) else - @echo "Uninstalling symlinks $(notdir $(MK_LIBS_SYML)) from $(dir $(firstword $(MK_LIBS_SYML)))." + @echo "Uninstalling symlinks $(notdir $(MK_LIBS_SYML)) from $(dir $(firstword $(MK_LIBS_SYML)))" @- $(RM_F) $(MK_LIBS_SYML) endif @@ -1232,7 +1338,7 @@ uninstall-headers: check-env ifeq ($(ENABLE_VERBOSE),yes) - $(RM_RF) $(MK_INCL_DIR_INST) else - @echo "Uninstalling directory '$(notdir $(MK_INCL_DIR_INST))' from $(dir $(MK_INCL_DIR_INST))." + @echo "Uninstalling directory '$(notdir $(MK_INCL_DIR_INST))' from $(dir $(MK_INCL_DIR_INST))" @- $(RM_RF) $(MK_INCL_DIR_INST) endif @@ -1240,7 +1346,7 @@ uninstall-share: check-env ifeq ($(ENABLE_VERBOSE),yes) - $(RM_RF) $(MK_SHARE_DIR_INST) else - @echo "Uninstalling directory '$(notdir $(MK_SHARE_DIR_INST))' from $(dir $(MK_SHARE_DIR_INST))." + @echo "Uninstalling directory '$(notdir $(MK_SHARE_DIR_INST))' from $(dir $(MK_SHARE_DIR_INST))" @- $(RM_RF) $(MK_SHARE_DIR_INST) endif @@ -1256,7 +1362,7 @@ $(UNINSTALL_OLD_LIBS) $(UNINSTALL_OLD_SYML) $(UNINSTALL_OLD_HEADERS): check-env ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $@ else - @echo "Uninstalling $(@F) from $(@D)/." + @echo "Uninstalling $(@F) from $(@D)/" @- $(RM_F) $@ endif diff --git a/README.md b/README.md index 13acd96ec2..211ebd6d52 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,22 @@ ![The BLIS cat is sleeping.](http://www.cs.utexas.edu/users/field/blis_cat.png) -[![Build Status](https://travis-ci.org/flame/blis.svg?branch=master)](https://travis-ci.org/flame/blis) +[![Build Status](https://api.travis-ci.com/flame/blis.svg?branch=master)](https://app.travis-ci.com/github/flame/blis) +[![Build Status](https://ci.appveyor.com/api/projects/status/github/flame/blis?branch=master&svg=true)](https://ci.appveyor.com/project/shpc/blis/branch/master) Contents -------- * **[Introduction](#introduction)** +* **[Education and Learning](#education-and-learning)** * **[What's New](#whats-new)** * **[What People Are Saying About BLIS](#what-people-are-saying-about-blis)** * **[Key Features](#key-features)** +* **[How to Download BLIS](#how-to-download-blis)** * **[Getting Started](#getting-started)** +* **[Example Code](#example-code)** * **[Documentation](#documentation)** -* **[External GNU/Linux Packages](#external-gnulinux-packages)** +* **[Performance](#performance)** +* **[External Packages](#external-packages)** * **[Discussion](#discussion)** * **[Contributing](#contributing)** * **[Citations](#citations)** @@ -61,12 +66,12 @@ compared to conventional approaches to developing BLAS libraries, as well as a much-needed refinement of the BLAS interface, and thus constitutes a major advance in dense linear algebra computation. While BLIS remains a work-in-progress, we are excited to continue its development and further -cultivate its use within the community. +cultivate its use within the community. The BLIS framework is primarily developed and maintained by individuals in the [Science of High-Performance Computing](http://shpc.ices.utexas.edu/) (SHPC) group in the -[Institute for Computational Engineering and Sciences](https://www.ices.utexas.edu/) +[Oden Institute for Computational Engineering and Sciences](https://www.oden.utexas.edu/) at [The University of Texas at Austin](https://www.utexas.edu/). Please visit the [SHPC](http://shpc.ices.utexas.edu/) website for more information about our research group, such as a list of @@ -76,9 +81,77 @@ and [collaborators](http://shpc.ices.utexas.edu/collaborators.html), [publications](http://shpc.ices.utexas.edu/publications.html), and [other educational projects](http://www.ulaff.net/) (such as MOOCs). +Education and Learning +---------------------- + +Want to understand what's under the hood? +Many of the same concepts and principles employed when developing BLIS are +introduced and taught in a basic pedagogical setting as part of +[LAFF-On Programming for High Performance (LAFF-On-PfHP)](http://www.ulaff.net/), +one of several massive open online courses (MOOCs) in the +[Linear Algebra: Foundations to Frontiers](http://www.ulaff.net/) series, +all of which are available for free via the [edX platform](http://www.edx.org/). + What's New ---------- + * **Addons feature now available!** Have you ever wanted to quickly extend BLIS's +operation support or define new custom BLIS APIs for your application, but were +unsure of how to add your source code to BLIS? Do you want to isolate your custom +code so that it only gets enabled when the user requests it? Do you like +[sandboxes](docs/Sandboxes.md), but wish you didn't have to provide an +implementation of `gemm`? If so, you should check out our new +[addons](docs/Addons.md) feature. Addons act like optional extensions that can be +created, enabled, and combined to suit your application's needs, all without +formally integrating your code into the core BLIS framework. + + * **Multithreaded small/skinny matrix support for sgemm now available!** Thanks to +funding and hardware support from Oracle, we have now accelerated `gemm` for +single-precision real matrix problems where one or two dimensions is exceedingly +small. This work is similar to the `gemm` optimization announced last year. +For now, we have only gathered performance results on an AMD Epyc Zen2 system, but +we hope to publish additional graphs for other architectures in the future. You may +find these Zen2 graphs via the [PerformanceSmall](docs/PerformanceSmall.md) document. + + * **BLIS awarded SIAM Activity Group on Supercomputing Best Paper Prize for 2020!** +We are thrilled to announce that the paper that we internally refer to as the +second BLIS paper, + + "The BLIS Framework: Experiments in Portability." Field G. Van Zee, Tyler Smith, Bryan Marker, Tze Meng Low, Robert A. van de Geijn, Francisco Igual, Mikhail Smelyanskiy, Xianyi Zhang, Michael Kistler, Vernon Austel, John A. Gunnels, Lee Killough. ACM Transactions on Mathematical Software (TOMS), 42(2):12:1--12:19, 2016. + + was selected for the [SIAM Activity Group on Supercomputing Best Paper Prize](https://www.siam.org/prizes-recognition/activity-group-prizes/detail/siag-sc-best-paper-prize) +for 2020. The prize is awarded once every two years to a paper judged to be +the most outstanding paper in the field of parallel scientific and engineering +computing, and has only been awarded once before (in 2016) since its inception +in 2015 (the committee did not award the prize in 2018). The prize +[was awarded](https://www.oden.utexas.edu/about/news/ScienceHighPerfomanceComputingSIAMBestPaperPrize/) +at the [2020 SIAM Conference on Parallel Processing for Scientific Computing](https://www.siam.org/conferences/cm/conference/pp20) in Seattle. Robert was present at +the conference to give +[a talk on BLIS](https://meetings.siam.org/sess/dsp_programsess.cfm?SESSIONCODE=68266) and accept the prize alongside other coauthors. +The selection committee sought to recognize the paper, "which validates BLIS, +a framework relying on the notion of microkernels that enables both productivity +and high performance." Their statement continues, "The framework will continue +having an important influence on the design and the instantiation of dense linear +algebra libraries." + + * **Multithreaded small/skinny matrix support for dgemm now available!** Thanks to +contributions made possible by our partnership with AMD, we have dramatically +accelerated `gemm` for double-precision real matrix problems where one or two +dimensions is exceedingly small. A natural byproduct of this optimization is +that the traditional case of small _m = n = k_ (i.e. square matrices) is also +accelerated, even though it was not targeted specifically. And though only +`dgemm` was optimized for now, support for other datatypes and/or other operations +may be implemented in the future. We've also added new graphs to the +[PerformanceSmall](docs/PerformanceSmall.md) document to showcase multithreaded +performance when one or more matrix dimensions are small. + + * **Performance comparisons now available!** We recently measured the +performance of various level-3 operations on a variety of hardware architectures, +as implemented within BLIS and other BLAS libraries for all four of the standard +floating-point datatypes. The results speak for themselves! Check out our +extensive performance graphs and background info in our new +[Performance](docs/Performance.md) document. + * **BLIS is now in Debian Unstable!** Thanks to Debian developer-maintainers [M. Zhou](https://github.com/cdluminate) and [Nico Schlömer](https://github.com/nschloe) for sponsoring our package in Debian. @@ -87,7 +160,7 @@ the second-most popular Linux distribution (behind Ubuntu, which Debian packages feed into). The Debian tracker page may be found [here](https://tracker.debian.org/pkg/blis). - * **BLIS now supports mixed-datatype gemm.** The `gemm` operation may now be + * **BLIS now supports mixed-datatype gemm!** The `gemm` operation may now be executed on operands of mixed domains and/or mixed precisions. Any combination of storage datatype for A, B, and C is now supported, along with a separate computation precision that can differ from the storage precision of A and B. @@ -107,6 +180,9 @@ draft](http://www.cs.utexas.edu/users/flame/pubs/blis6_toms_rev2.pdf)). What People Are Saying About BLIS --------------------------------- +*["I noticed a substantial increase in multithreaded performance on my own +machine, which was extremely satisfying."](https://groups.google.com/d/msg/blis-discuss/8iu9B5KCxpA/uftpjgIsBwAJ)* ... *["[I was] happy it worked so well!"](https://groups.google.com/d/msg/blis-discuss/8iu9B5KCxpA/uftpjgIsBwAJ)* (Justin Shea) + *["This is an awesome library."](https://github.com/flame/blis/issues/288#issuecomment-447488637)* ... *["I want to thank you and the blis team for your efforts."](https://github.com/flame/blis/issues/288#issuecomment-448074704)* ([@Lephar](https://github.com/Lephar)) *["Any time somebody outside Intel beats MKL by a nontrivial amount, I report it to the MKL team. It is fantastic for any open-source project to get within 10% of MKL... [T]his is why Intel funds BLIS development."](https://github.com/flame/blis/issues/264#issuecomment-428673275)* ([@jeffhammond](https://github.com/jeffhammond)) @@ -155,7 +231,7 @@ seeking to implement tensor contractions on multidimensional arrays.) Furthermore, since BLIS tracks stride information for each matrix, operands of different storage formats can be used within the same operation invocation. By contrast, BLAS requires column-major storage. And while the CBLAS interface -supports row-major storage, it does not allow mixing storage formats. +supports row-major storage, it does not allow mixing storage formats. * **Rich support for the complex domain.** BLIS operations are developed and expressed in their most general form, which is typically in the complex domain. @@ -197,22 +273,15 @@ of BLIS's native APIs directly. BLIS's typed API will feel familiar to many veterans of BLAS since these interfaces use BLAS-like calling sequences. And many will find BLIS's object-based APIs a delight to use when customizing or writing their own BLIS operations. (Objects are relatively lightweight -`structs` and passed by address, which helps tame function calling overhead.) +`structs` and passed by address, which helps tame function calling overhead.) - * **Multilayered API, exposed kernels, and sandboxes.** The BLIS framework -exposes its + * **Multilayered API and exposed kernels.** The BLIS framework exposes its implementations in various layers, allowing expert developers to access exactly the functionality desired. This layered interface includes that of the lowest-level kernels, for those who wish to bypass the bulk of the framework. Optimizations can occur at various levels, in part thanks to exposed packing and unpacking facilities, which by default are highly parameterized and -flexible. And more recently, BLIS introduced sandboxes--a way to provide -alternative implementations of `gemm` that do not use any more of the BLIS -infrastructure than is desired. Sandboxes provide a convenient and -straightforward way of modifying the `gemm` implementation without disrupting -any other level-3 operation or any other part of the framework. This works -especially well when the developer wants to experiment with new optimizations -or try a different algorithm. +flexible. * **Functionality that grows with the community's needs.** As its name suggests, the BLIS framework is not a single library or static API, but rather @@ -220,7 +289,9 @@ a nearly-complete template for instantiating high-performance BLAS-like libraries. Furthermore, the framework is extensible, allowing developers to leverage existing components to support new operations as they are identified. If such operations require new kernels for optimal efficiency, the framework -and its APIs will be adjusted and extended accordingly. +and its APIs will be adjusted and extended accordingly. Community developers +who wish to experiment with creating new operations or APIs in BLIS can quickly +and easily do so via the [Addons](docs/Addons.md) feature. * **Code re-use.** Auto-generation approaches to achieving the aforementioned goals tend to quickly lead to code bloat due to the multiple dimensions of @@ -250,9 +321,62 @@ details, please see the documentation on [mixed datatype](docs/MixedDatatypes.md support and/or our [ACM TOMS](https://toms.acm.org/) journal paper on mixed-domain/mixed-precision `gemm` ([linked below](#citations)). +How to Download BLIS +-------------------- + +There are a few ways to download BLIS. We list the most common four ways below. +We **highly recommend** using either Option 1 or 2. Otherwise, we recommend +Option 3 (over Option 4) so your compiler can perform optimizations specific +to your hardware. + +1. **Download a source repository with `git clone`.** +Generally speaking, we prefer using `git clone` to clone a `git` repository. +Having a repository allows the user to periodically pull in the latest changes +and quickly rebuild BLIS whenever they wish. Also, implicit in cloning a +repository is that the repository defaults to using the `master` branch, which +contains the latest "stable" commits since the most recent release. (This is +in contrast to Option 3 in which the user is opting for code that may be +slightly out of date.) + + In order to clone a `git` repository of BLIS, please obtain a repository +URL by clicking on the green button above the file/directory listing near the +top of this page (as rendered by GitHub). Generally speaking, it will amount +to executing the following command in your terminal shell: + ``` + git clone https://github.com/flame/blis.git + ``` + +2. **Download a source repository via a zip file.** +If you are uncomfortable with using `git` but would still like the latest +stable commits, we recommend that you download BLIS as a zip file. + + In order to download a zip file of the BLIS source distribution, please +click on the green button above the file listing near the top of this page. +This should reveal a link for downloading the zip file. + +3. **Download a source release via a tarball/zip file.** +Alternatively, if you would like to stick to the code that is included in +official releases, you may download either a tarball or zip file of any of +BLIS's previous [tagged releases](https://github.com/flame/blis/releases). +We consider this option to be less than ideal for most people since it will +likely mean you miss out on the latest bugfix or feature commits (in contrast +to Options 1 or 2), and you also will not be able to update your code with a +simple `git pull` command (in contrast to Option 1). + +4. **Download a binary package specific to your OS.** +While we don't recommend this as the first choice for most users, we provide +links to community members who generously maintain BLIS packages for various +Linux distributions such as Debian Unstable and EPEL/Fedora. Please see the +[External Packages](#external-packages) section below for more information. + Getting Started --------------- +*NOTE: This section assumes you've either cloned a BLIS source code repository +via `git`, downloaded the latest source code via a zip file, or downloaded the +source code for a tagged version release---Options 1, 2, or 3, respectively, +as discussed in [the previous section](#how-to-download-blis).* + If you just want to build a sequential (not parallelized) version of BLIS in a hurry and come back and explore other topics later, you can configure and build BLIS as follows: @@ -276,6 +400,42 @@ If/when you have time, we *strongly* encourage you to read the detailed walkthrough of the build system found in our [Build System](docs/BuildSystem.md) guide. +Example Code +------------ + +The BLIS source distribution provides example code in the `examples` directory. +Example code focuses on using BLIS APIs (not BLAS or CBLAS), and resides in +two subdirectories: [examples/oapi](examples/oapi) (which demonstrates the +[object API](docs/BLISObjectAPI.md)) and [examples/tapi](examples/tapi) (which +demonstrates the [typed API](docs/BLISTypedAPI.md)). + +Either directory contains several files, each containing various pieces of +code that exercise core functionality of the BLIS API in question (object or +typed). These example files should be thought of collectively like a tutorial, +and therefore it is recommended to start from the beginning (the file that +starts in `00`). + +You can build all of the examples by simply running `make` from either example +subdirectory (`examples/oapi` or `examples/tapi`). (You can also run +`make clean`.) The local `Makefile` assumes that you've already configured and +built (but not necessarily installed) BLIS two directories up, in `../..`. If +you have already installed BLIS to some permanent directory, you may refer to +that installation by setting the environment variable `BLIS_INSTALL_PATH` prior +to running make: +``` +export BLIS_INSTALL_PATH=/usr/local; make +``` +or by setting the same variable as part of the make command: +``` +make BLIS_INSTALL_PATH=/usr/local +``` +**Once the executable files have been built, we recommend reading the code and +the corresponding executable output side by side. This will help you see the +effects of each section of code.** + +This tutorial is not exhaustive or complete; several object API functions were +omitted (mostly for brevity's sake) and thus more examples could be written. + Documentation ------------- @@ -296,16 +456,12 @@ included BLAS test drivers. * **[BLIS Typed API Reference](docs/BLISTypedAPI.md).** Here we document the so-called "typed" (or BLAS-like) API. This is the API that many users who are -already familiar with the BLAS will likely want to use. You can find lots of -example code for the typed API in the [examples/tapi](examples/tapi) directory -included in the BLIS source distribution. +already familiar with the BLAS will likely want to use. * **[BLIS Object API Reference](docs/BLISObjectAPI.md).** Here we document the object API. This is API abstracts away properties of vectors and matrices within `obj_t` structs that can be queried with accessor functions. Many -developers and experts prefer this API over the typed API. You can find lots of -example code for the object API in the [examples/oapi](examples/oapi) directory -included in the BLIS source distribution. +developers and experts prefer this API over the typed API. * **[Hardware Support](docs/HardwareSupport.md).** This document maintains a table of supported microarchitectures. @@ -313,10 +469,20 @@ table of supported microarchitectures. * **[Multithreading](docs/Multithreading.md).** This document describes how to use the multithreading features of BLIS. - * **[Mixed-Datatype](docs/MixedDatatype.md).** This document provides an + * **[Mixed-Datatypes](docs/MixedDatatypes.md).** This document provides an overview of BLIS's mixed-datatype functionality and provides a brief example of how to take advantage of this new code. + * **[Performance](docs/Performance.md).** This document reports empirically +measured performance of a representative set of level-3 operations on a variety +of hardware architectures, as implemented within BLIS and other BLAS libraries +for all four of the standard floating-point datatypes. + + * **[PerformanceSmall](docs/PerformanceSmall.md).** This document reports +empirically measured performance of `gemm` on select hardware architectures +within BLIS and other BLAS libraries when performing matrix problems where one +or two dimensions is exceedingly small. + * **[Release Notes](docs/ReleaseNotes.md).** This document tracks a summary of changes included with each new version of BLIS, along with contributor credits for key features. @@ -326,14 +492,14 @@ about BLIS, please read this FAQ. If you can't find the answer to your question, please feel free to join the [blis-devel](https://groups.google.com/group/blis-devel) mailing list and post a question. We also have a [blis-discuss](https://groups.google.com/group/blis-discuss) mailing list that -anyone can post to (even without joining). +anyone can post to (even without joining). **Documents for github contributors:** * **[Contributing bug reports, feature requests, PRs, etc](CONTRIBUTING.md).** Interested in contributing to BLIS? Please read this document before getting started. It provides a general overview of how best to report bugs, propose new -features, and offer code patches. +features, and offer code patches. * **[Coding Conventions](docs/CodingConventions.md).** If you are interested or planning on contributing code to BLIS, please read this document so that you can @@ -351,12 +517,35 @@ learn how to add new sub-configurations or configuration families, or are simply interested in learning how BLIS organizes its configurations and kernel sets, please read this thorough walkthrough of the configuration system. + * **[Addon Guide](docs/Addons.md).** If you are interested in learning +about using BLIS addons--that is, enabling existing (or creating new) bundles +of operation or API code that are built into a BLIS library--please read this +document. + * **[Sandbox Guide](docs/Sandboxes.md).** If you are interested in learning about using sandboxes in BLIS--that is, providing alternative implementations of the `gemm` operation--please read this document. -External GNU/Linux packages ---------------------------- +Performance +----------- + +We provide graphs that report performance of several implementations across a +range of hardware types, multithreading configurations, problem sizes, +operations, and datatypes. These pages also document most of the details needed +to reproduce these experiments. + + * **[Performance](docs/Performance.md).** This document reports empirically +measured performance of a representative set of level-3 operations on a variety +of hardware architectures, as implemented within BLIS and other BLAS libraries +for all four of the standard floating-point datatypes. + + * **[PerformanceSmall](docs/PerformanceSmall.md).** This document reports +empirically measured performance of `gemm` on select hardware architectures +within BLIS and other BLAS libraries when performing matrix problems where one +or two dimensions is exceedingly small. + +External Packages +----------------- Generally speaking, we **highly recommend** building from source whenever possible using the latest `git` clone. (Tarballs of each @@ -374,6 +563,12 @@ Debian package tracker can be found [here](https://tracker.debian.org/pkg/blis). (Also, thanks to [Nico Schlömer](https://github.com/nschloe) for previously volunteering his time to set up a standalone PPA.) + * **Gentoo**. [M. Zhou](https://github.com/cdluminate) also maintains the +[BLIS package](https://packages.gentoo.org/packages/sci-libs/blis) entry for +[Gentoo](https://www.gentoo.org/), a Linux distribution known for its +source-based [portage](https://wiki.gentoo.org/wiki/Portage) package manager +and distribution system. + * **EPEL/Fedora**. There are official BLIS packages in Fedora and EPEL (for RHEL7+ and compatible distributions) with versions for 64-bit integers, OpenMP, and pthreads, and shims which can be dynamically linked instead of reference @@ -392,6 +587,9 @@ the source rpms may build for others. * **GNU Guix**. Guix has BLIS packages, provides builds only for the generic target and some specific x86_64 micro-architectures. + * **Conda**. conda channel [conda-forge](https://github.com/conda-forge/blis-feedstock) +has Linux, OSX and Windows binary packages for x86_64. + Discussion ---------- @@ -416,7 +614,7 @@ Contributing ------------ For information on how to contribute to our project, including preferred -[coding conventions](docs/CodingConventions), please refer to the +[coding conventions](docs/CodingConventions.md), please refer to the [CONTRIBUTING](CONTRIBUTING.md) file at the top-level of the BLIS source distribution. @@ -425,8 +623,8 @@ Citations For those of you looking for the appropriate article to cite regarding BLIS, we recommend citing our -[first ACM TOMS journal paper](http://dl.acm.org/authorize?N91172) -([unofficial backup link](http://www.cs.utexas.edu/users/flame/pubs/blis1_toms_rev3.pdf)): +[first ACM TOMS journal paper](https://dl.acm.org/doi/10.1145/2764454?cid=81314495332) +([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/blis1_toms_rev3.pdf)): ``` @article{BLIS1, @@ -436,16 +634,16 @@ recommend citing our volume = {41}, number = {3}, pages = {14:1--14:33}, - month = jun, + month = {June}, year = {2015}, issue_date = {June 2015}, - url = {http://doi.acm.org/10.1145/2764454}, + url = {https://doi.acm.org/10.1145/2764454}, } -``` +``` You may also cite the -[second ACM TOMS journal paper](http://dl.acm.org/authorize?N16240) -([unofficial backup link](http://www.cs.utexas.edu/users/flame/pubs/blis2_toms_rev3.pdf)): +[second ACM TOMS journal paper](https://dl.acm.org/doi/10.1145/2755561?cid=81314495332) +([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/blis2_toms_rev3.pdf)): ``` @article{BLIS2, @@ -458,15 +656,16 @@ You may also cite the volume = {42}, number = {2}, pages = {12:1--12:19}, - month = jun, + month = {June}, year = {2016}, issue_date = {June 2016}, - url = {http://doi.acm.org/10.1145/2755561}, + url = {https://doi.acm.org/10.1145/2755561}, } -``` +``` We also have a third paper, submitted to IPDPS 2014, on achieving -[multithreaded parallelism in BLIS](http://www.cs.utexas.edu/users/flame/pubs/blis3_ipdps14.pdf): +[multithreaded parallelism in BLIS](https://dl.acm.org/doi/10.1109/IPDPS.2014.110) +([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/blis3_ipdps14.pdf)): ``` @inproceedings{BLIS3, @@ -475,14 +674,15 @@ We also have a third paper, submitted to IPDPS 2014, on achieving title = {Anatomy of High-Performance Many-Threaded Matrix Multiplication}, booktitle = {28th IEEE International Parallel \& Distributed Processing Symposium (IPDPS 2014)}, - year = 2014, + year = {2014}, + url = {https://doi.org/10.1109/IPDPS.2014.110}, } ``` A fourth paper, submitted to ACM TOMS, also exists, which proposes an -[analytical model](http://dl.acm.org/citation.cfm?id=2925987) -([unofficial backup link](http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf)) -for determining blocksize parameters in BLIS: +[analytical model](https://dl.acm.org/doi/10.1145/2925987) +for determining blocksize parameters in BLIS +([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf)): ``` @article{BLIS4, @@ -493,15 +693,16 @@ for determining blocksize parameters in BLIS: volume = {43}, number = {2}, pages = {12:1--12:18}, - month = aug, + month = {August}, year = {2016}, issue_date = {August 2016}, - url = {http://doi.acm.org/10.1145/2925987}, + url = {https://doi.acm.org/10.1145/2925987}, } ``` A fifth paper, submitted to ACM TOMS, begins the study of so-called -[induced methods for complex matrix multiplication](http://www.cs.utexas.edu/users/flame/pubs/blis5_toms_rev2.pdf): +[induced methods for complex matrix multiplication](https://dl.acm.org/doi/10.1145/3086466?cid=81314495332) +([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/blis5_toms_rev2.pdf)): ``` @article{BLIS5, @@ -511,27 +712,36 @@ A fifth paper, submitted to ACM TOMS, begins the study of so-called volume = {44}, number = {1}, pages = {7:1--7:36}, - month = jul, + month = {July}, year = {2017}, issue_date = {July 2017}, - url = {http://doi.acm.org/10.1145/3086466}, + url = {https://doi.acm.org/10.1145/3086466}, } -``` +``` A sixth paper, submitted to ACM TOMS, revisits the topic of the previous -article and derives a [superior induced method](http://www.cs.utexas.edu/users/flame/pubs/blis6_toms_rev2.pdf): +article and derives a +[superior induced method](https://epubs.siam.org/doi/10.1137/19M1282040) +([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/blis6_sisc_rev3.pdf)): ``` @article{BLIS6, author = {Field G. {V}an~{Z}ee}, title = {Implementing High-Performance Complex Matrix Multiplication via the 1m Method}, - journal = {ACM Transactions on Mathematical Software}, - note = {submitted} + journal = {SIAM Journal on Scientific Computing}, + volume = {42}, + number = {5}, + pages = {C221--C244}, + month = {September} + year = {2020}, + issue_date = {September 2020}, + url = {https://doi.org/10.1137/19M1282040} } -``` +``` A seventh paper, submitted to ACM TOMS, explores the implementation of `gemm` for -[mixed-domain and/or mixed-precision](http://www.cs.utexas.edu/users/flame/pubs/blis7_toms_rev0.pdf) operands: +[mixed-domain and/or mixed-precision](https://dl.acm.org/doi/10.1145/3402225?cid=81314495332) operands +([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/blis7_toms_rev0.pdf)): ``` @article{BLIS7, @@ -539,7 +749,13 @@ A seventh paper, submitted to ACM TOMS, explores the implementation of `gemm` fo title = {Supporting Mixed-domain Mixed-precision Matrix Multiplication within the BLIS Framework}, journal = {ACM Transactions on Mathematical Software}, - note = {submitted} + volume = {47}, + number = {2}, + pages = {12:1--12:26}, + month = {April}, + year = {2021}, + issue_date = {April 2021}, + url = {https://doi.org/10.1145/3402225}, } ``` @@ -547,15 +763,18 @@ Funding ------- This project and its associated research were partially sponsored by grants from -[Microsoft](http://www.microsoft.com/), -[Intel](http://www.intel.com/), -[Texas Instruments](http://www.ti.com/), -[AMD](http://www.amd.com/), -[Oracle](http://www.oracle.com/), +[Microsoft](https://www.microsoft.com/), +[Intel](https://www.intel.com/), +[Texas Instruments](https://www.ti.com/), +[AMD](https://www.amd.com/), +[HPE](https://www.hpe.com/), +[Oracle](https://www.oracle.com/), +[Huawei](https://www.huawei.com/), +[Facebook](https://www.facebook.com/), and -[Huawei](http://www.huawei.com/), +[ARM](https://www.arm.com/), as well as grants from the -[National Science Foundation](http://www.nsf.gov/) (Awards +[National Science Foundation](https://www.nsf.gov/) (Awards CCF-0917167, ACI-1148125/1340293, CCF-1320112, and ACI-1550493). _Any opinions, findings and conclusions or recommendations expressed in this diff --git a/RELEASING b/RELEASING index bc2c9dc591..351594c49d 100644 --- a/RELEASING +++ b/RELEASING @@ -26,14 +26,19 @@ Here are the steps to follow to create a new release (version) of BLIS: 6. Update docs/ReleaseNotes.md file with body of finalized announcement and the date of the release. -7. Bump the version number: +7. Commit changes from steps 5 and 6. + +8. Bump the version number: $ ./build/bump-version.sh "0.3.2" -8. Push the new commits and new tag associated with the new version: + This will result in two new commits: a version file update and a CHANGELOG + file update. + +9. Push the new commits and new tag associated with the new version: $ git push $ git push --tag -9. Send finalized announcement to blis-devel. +10. Send finalized announcement to blis-devel. diff --git a/addon/gemmd/attic/bao_gemmd_bp_var2.c b/addon/gemmd/attic/bao_gemmd_bp_var2.c new file mode 100644 index 0000000000..a0040fec06 --- /dev/null +++ b/addon/gemmd/attic/bao_gemmd_bp_var2.c @@ -0,0 +1,602 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmd_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict d, inc_t incd, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- gemmd-like block-panel algorithm (object interface) ---------------------- +// + +// Define a function pointer array named ftypes and initialize its contents with +// the addresses of the typed functions defined below, bao_?gemmd_bp_var2(). +static FUNCPTR_T GENARRAY_PREF(ftypes,bao_,gemmd_bp_var2); + +void bao_gemmd_bp_var2 + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + const inc_t rs_a = bli_obj_row_stride( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + + void* restrict buf_d = bli_obj_buffer_at_off( d ); + const inc_t incd = bli_obj_vector_inc( d ); + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t cs_b = bli_obj_col_stride( b ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + + // Index into the function pointer array to extract the correct + // typed function pointer based on the chosen datatype. + FUNCPTR_T f = ftypes[dt]; + + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_d, incd, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread + ); +} + +// +// -- gemmd-like block-panel algorithm (typed interface) ----------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict d, inc_t incd, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + /* + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ + */ \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + /* + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ + */ \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_d = incd; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict d_00 = d; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of the scalars to prevent any unnecessary sharing of + cache lines between the cores' caches. */ \ + ctype alpha_local = *alpha_cast; \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ + /*ctype zero_local = *PASTEMAC(ch,0);*/ \ +\ + auxinfo_t aux; \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ \ + bszid_t bszids[8] = { BLIS_NC, /* 5th loop */ \ + BLIS_KC, /* 4th loop */ \ + BLIS_NO_PART, /* pack B */ \ + BLIS_MC, /* 3rd loop */ \ + BLIS_NO_PART, /* pack A */ \ + BLIS_NR, /* 2nd loop */ \ + BLIS_MR, /* 1st loop */ \ + BLIS_KR }; /* microkernel loop */ \ +\ + bszid_t* restrict bszids_jc = &bszids[0]; \ + bszid_t* restrict bszids_pc = &bszids[1]; \ + /*bszid_t* restrict bszids_pb = &bszids[2];*/ \ + bszid_t* restrict bszids_ic = &bszids[3]; \ + /*bszid_t* restrict bszids_pa = &bszids[4];*/ \ + bszid_t* restrict bszids_jr = &bszids[5]; \ + /*bszid_t* restrict bszids_ir = &bszids[6];*/ \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ + thrinfo_t* restrict thread_ir = NULL; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + const dim_t n_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n_local % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. */ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict d_pc = d_00 + pp * pcstep_d; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pb = bli_thrinfo_sub_node( thread_pc ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pb, thread_pb );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + B. Then call the packm implementation. */ \ + PASTECH2(bao_,ch,packm_b) \ + ( \ + conjb, \ + KC, NC, \ + kc_cur, nc_cur, NR, \ + &one_local, \ + d_pc, incd, \ + b_pc, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_pc_use = b_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_ic = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + const dim_t m_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m_local % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pa = bli_thrinfo_sub_node( thread_ic ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pa, thread_pa );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + A. Then call the packm implementation. */ \ + PASTECH2(bao_,ch,packm_a) \ + ( \ + conja, \ + MC, KC, \ + mc_cur, kc_cur, MR, \ + &one_local, \ + d_pc, incd, \ + a_ic, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. */ \ + ctype* restrict a_ic_use = a_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jr = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Query the number of threads and thread ids for the JR loop. + NOTE: These values are only needed when computing the next + micropanel of B. */ \ + const dim_t jr_nt = bli_thread_n_way( thread_jr ); \ + const dim_t jr_tid = bli_thread_work_id( thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur \ + = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Assume for now that our next panel of B to be the current panel + of B. */ \ + ctype* restrict b2 = b_jr; \ +\ + /* Identify the current thrinfo_t node. */ \ + thread_ir = bli_thrinfo_sub_node( thread_jr ); \ +\ + /* Query the number of threads and thread ids for the IR loop. + NOTE: These values are only needed when computing the next + micropanel of A. */ \ + const dim_t ir_nt = bli_thread_n_way( thread_ir ); \ + const dim_t ir_tid = bli_thread_work_id( thread_ir ); \ +\ + /* Compute number of primary and leftover components of the IR loop. */ \ + dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + dim_t ir_left = mc_cur % MR; \ +\ + /* Compute the IR loop thread range for the current thread. */ \ + dim_t ir_start, ir_end; \ + bli_thread_range_sub( thread_ir, ir_iter, 1, FALSE, &ir_start, &ir_end ); \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += 1 ) \ + { \ + const dim_t mr_cur \ + = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + ctype* restrict a2; \ +\ + /* Compute the addresses of the next micropanels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a_ir, ps_a_use, 1 ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_ic_use; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, ps_b_use, 1 ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_pc_use; \ + } \ +\ + /* Save the addresses of next micropanels of A and B to the + auxinfo_t object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Call a wrapper to the kernel (which handles edge cases). */ \ + PASTECH2(bao_,ch,gemm_kernel) \ + ( \ + MR, \ + NR, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + &alpha_local, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +\ + /* This barrier is needed to prevent threads from starting to pack + the next row panel of B before the current row panel is fully + computed upon. */ \ + bli_thread_barrier( thread_pb ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. */ \ + PASTECH2(bao_,ch,packm_finalize_mem_a) \ + ( \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTECH2(bao_,ch,packm_finalize_mem_b) \ + ( \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var2: a1_packed", mr_cur, kc_cur, a_ir, rs_a_use, cs_a_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var2: b1_packed", kc_cur, nr_cur, b_jr, rs_b_use, cs_b_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%5.2f", "" ); \ +*/ \ +} + +//INSERT_GENTFUNC_BASIC0( gemmd_bp_var2 ) +GENTFUNC( float, s, gemmd_bp_var2 ) +GENTFUNC( double, d, gemmd_bp_var2 ) +GENTFUNC( scomplex, c, gemmd_bp_var2 ) +GENTFUNC( dcomplex, z, gemmd_bp_var2 ) + +// +// -- gemm-like microkernel wrapper -------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + const dim_t MR, \ + const dim_t NR, \ + dim_t mr_cur, \ + dim_t nr_cur, \ + dim_t kc_cur, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict aux, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* Infer the datatype from the ctype. */ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype zero = *PASTEMAC(ch,0); \ +\ + /* Handle interior and edge cases separately. */ \ + if ( mr_cur == MR && nr_cur == NR ) \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + alpha, \ + a, \ + b, \ + beta, \ + c, rs_c, cs_c, \ + aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + alpha, \ + a, \ + b, \ + &zero, \ + ct, rs_ct, cs_ct, \ + aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. */ \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + mr_cur, \ + nr_cur, \ + ct, rs_ct, cs_ct, \ + beta, \ + c, rs_c, cs_c \ + ); \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( gemm_kernel ) +GENTFUNC( float, s, gemm_kernel ) +GENTFUNC( double, d, gemm_kernel ) +GENTFUNC( scomplex, c, gemm_kernel ) +GENTFUNC( dcomplex, z, gemm_kernel ) + diff --git a/addon/gemmd/attic/bli_gemm_ex.c b/addon/gemmd/attic/bli_gemm_ex.c new file mode 100644 index 0000000000..0f40d1cb39 --- /dev/null +++ b/addon/gemmd/attic/bli_gemm_ex.c @@ -0,0 +1,88 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_gemm_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // A switch to easily toggle whether we use the addon implementation + // of bao_gemmd() as the implementation for bli_gemm(). (This allows for + // easy testing of bao_gemmd() via the testsuite.) + if ( 1 ) + { + const dim_t k = bli_obj_width_after_trans( a ); + const num_t dt = bli_obj_dt( c ); + obj_t d; + + bli_obj_create( dt, k, 1, 1, k, &d ); + bli_setv( &BLIS_ONE, &d ); + //bli_randv( &d ); + + bao_gemmd_ex( alpha, a, &d, b, beta, c, cntx, rntm ); + + bli_obj_free( &d ); + return; + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Obtain a valid (native) context from the gks if necessary. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_gemm_check( alpha, a, b, beta, c, cntx ); + + // Invoke the operation's front end. + bli_gemm_front + ( + alpha, a, b, beta, c, cntx, rntm, NULL + ); +} + diff --git a/addon/gemmd/bao_gemmd.c b/addon/gemmd/bao_gemmd.c new file mode 100644 index 0000000000..fadc526918 --- /dev/null +++ b/addon/gemmd/bao_gemmd.c @@ -0,0 +1,287 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// -- Define the gemmd operation's object API ---------------------------------- +// + +void bao_gemmd + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c + ) +{ + bao_gemmd_ex + ( + alpha, + a, + d, + b, + beta, + c, + NULL, + NULL + ); +} + +void bao_gemmd_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Obtain a valid (native) context from the gks if necessary. + // NOTE: This must be done before calling the _check() function, since + // that function assumes the context pointer is valid. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bao_gemmd_check( alpha, a, d, b, beta, c, cntx ); + + // -- bli_gemmd_front() ---------------------------------------------------- + + obj_t a_local; + obj_t b_local; + obj_t c_local; + + // If C has a zero dimension, return early. + if ( bli_obj_has_zero_dim( c ) ) + { + return; + } + + // If alpha is zero, or if A or B has a zero dimension, scale C by beta + // and return early. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || + bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) + { + bli_scalm( beta, c ); + return; + } + + // Alias A, B, and C in case we need to apply transformations. + bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); + bli_obj_alias_to( c, &c_local ); + + // Induce a transposition of A if it has its transposition property set. + // Then clear the transposition bit in the object. + if ( bli_obj_has_trans( &a_local ) ) + { + bli_obj_induce_trans( &a_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &a_local ); + } + + // Induce a transposition of B if it has its transposition property set. + // Then clear the transposition bit in the object. + if ( bli_obj_has_trans( &b_local ) ) + { + bli_obj_induce_trans( &b_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &b_local ); + } + + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop, and then make any + // additional modifications necessary for the current operation. + bli_rntm_set_ways_for_op + ( + BLIS_GEMM, + BLIS_LEFT, // ignored for gemm/hemm/symm + bli_obj_length( &c_local ), + bli_obj_width( &c_local ), + bli_obj_width( &a_local ), + rntm + ); + + // Spawn threads (if applicable), where bao_gemmd_int() is the thread entry + // point function for each thread. This also begins the process of creating + // the thrinfo_t tree, which contains thread communicators. + bao_l3_thread_decorator + ( + bao_gemmd_int, + BLIS_GEMM, // operation family id + alpha, + &a_local, + d, + &b_local, + beta, + &c_local, + cntx, + rntm + ); +} + +// +// -- Define the gemmd operation's thread entry point -------------------------- +// + +void bao_gemmd_int + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + // In this function, we choose the gemmd implementation that is executed + // on each thread. + + // Call the block-panel algorithm. + bao_gemmd_bp_var1 + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm, + thread + ); +} + +// +// -- Define the gemmd operation's typed API ----------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* d, inc_t incd, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ) \ +{ \ + bli_init_once(); \ +\ + /* Determine the datatype (e.g. BLIS_FLOAT, BLIS_DOUBLE, etc.) based on + the macro parameter 'ch' (e.g. s, d, etc). */ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao, ao, dd, bo, betao, co; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + /* Adjust the dimensions of matrices A and B according to the transa and + transb parameters. */ \ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, k, n, &m_b, &n_b ); \ +\ + /* Create bufferless scalar objects and attach the provided scalar pointers + to those scalar objects. */ \ + bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ + bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ +\ + /* Create bufferless matrix objects and attach the provided matrix pointers + to those matrix objects. */ \ + bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_create_with_attached_buffer( dt, k, 1, d, incd, k, &dd ); \ + bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ +\ + /* Set the transposition/conjugation properties of the objects for matrices + A and B. */ \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + /* Call the object interface. */ \ + PASTECH(bao_,opname) \ + ( \ + &alphao, \ + &ao, \ + &dd, \ + &bo, \ + &betao, \ + &co \ + ); \ +} + +//INSERT_GENTFUNC_BASIC0( gemmd ) +GENTFUNC( float, s, gemmd ) +GENTFUNC( double, d, gemmd ) +GENTFUNC( scomplex, c, gemmd ) +GENTFUNC( dcomplex, z, gemmd ) + diff --git a/addon/gemmd/bao_gemmd.h b/addon/gemmd/bao_gemmd.h new file mode 100644 index 0000000000..7c7466494d --- /dev/null +++ b/addon/gemmd/bao_gemmd.h @@ -0,0 +1,105 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// +// -- Prototype the gemmd operation's object API ------------------------------- +// + +BLIS_EXPORT_ADDON void bao_gemmd + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c + ); + +BLIS_EXPORT_ADDON void bao_gemmd_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +// +// -- Prototype the gemmd operation's thread entry point ----------------------- +// + +void bao_gemmd_int + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +// +// -- Prototype the gemmd operation's typed API -------------------------------- +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +BLIS_EXPORT_ADDON void PASTECH2(bao_,ch,opname) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* d, inc_t incd, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ); + +//INSERT_GENTPROT_BASIC0( gemmd ) +GENTPROT( float, s, gemmd ) +GENTPROT( double, d, gemmd ) +GENTPROT( scomplex, c, gemmd ) +GENTPROT( dcomplex, z, gemmd ) + diff --git a/addon/gemmd/bao_gemmd_bp_var1.c b/addon/gemmd/bao_gemmd_bp_var1.c new file mode 100644 index 0000000000..09e4df09e4 --- /dev/null +++ b/addon/gemmd/bao_gemmd_bp_var1.c @@ -0,0 +1,491 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmd_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict d, inc_t incd, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- gemmd-like block-panel algorithm (object interface) ---------------------- +// + +// Define a function pointer array named ftypes and initialize its contents with +// the addresses of the typed functions defined below, bao_?gemmd_bp_var1(). +static FUNCPTR_T GENARRAY_PREF(ftypes,bao_,gemmd_bp_var1); + +void bao_gemmd_bp_var1 + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + const inc_t rs_a = bli_obj_row_stride( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + + void* restrict buf_d = bli_obj_buffer_at_off( d ); + const inc_t incd = bli_obj_vector_inc( d ); + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t cs_b = bli_obj_col_stride( b ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + + // Index into the function pointer array to extract the correct + // typed function pointer based on the chosen datatype. + FUNCPTR_T f = ftypes[dt]; + + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_d, incd, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread + ); +} + +// +// -- gemmd-like block-panel algorithm (typed interface) ----------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict d, inc_t incd, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_d = incd; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict d_00 = d; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of the scalars to prevent any unnecessary sharing of + cache lines between the cores' caches. */ \ + ctype alpha_local = *alpha_cast; \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ +\ + auxinfo_t aux; \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ \ + bszid_t bszids[8] = { BLIS_NC, /* 5th loop */ \ + BLIS_KC, /* 4th loop */ \ + BLIS_NO_PART, /* pack B */ \ + BLIS_MC, /* 3rd loop */ \ + BLIS_NO_PART, /* pack A */ \ + BLIS_NR, /* 2nd loop */ \ + BLIS_MR, /* 1st loop */ \ + BLIS_KR }; /* microkernel loop */ \ +\ + bszid_t* restrict bszids_jc = &bszids[0]; \ + bszid_t* restrict bszids_pc = &bszids[1]; \ + /*bszid_t* restrict bszids_pb = &bszids[2];*/ \ + bszid_t* restrict bszids_ic = &bszids[3]; \ + /*bszid_t* restrict bszids_pa = &bszids[4];*/ \ + bszid_t* restrict bszids_jr = &bszids[5]; \ + /*bszid_t* restrict bszids_ir = &bszids[6];*/ \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ + thrinfo_t* restrict thread_ir = NULL; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + const dim_t n_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n_local % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. */ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict d_pc = d_00 + pp * pcstep_d; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pb = bli_thrinfo_sub_node( thread_pc ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pb, thread_pb );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + B. Then call the packm implementation. */ \ + PASTECH2(bao_,ch,packm_b) \ + ( \ + conjb, \ + KC, NC, \ + kc_cur, nc_cur, NR, \ + &one_local, \ + d_pc, incd, \ + b_pc, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_pc_use = b_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_ic = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + const dim_t m_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m_local % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pa = bli_thrinfo_sub_node( thread_ic ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pa, thread_pa );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + A. Then call the packm implementation. */ \ + PASTECH2(bao_,ch,packm_a) \ + ( \ + conja, \ + MC, KC, \ + mc_cur, kc_cur, MR, \ + &one_local, \ + d_pc, incd, \ + a_ic, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. */ \ + ctype* restrict a_ic_use = a_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jr = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Query the number of threads and thread ids for the JR loop. + NOTE: These values are only needed when computing the next + micropanel of B. */ \ + const dim_t jr_nt = bli_thread_n_way( thread_jr ); \ + const dim_t jr_tid = bli_thread_work_id( thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur \ + = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Assume for now that our next panel of B to be the current panel + of B. */ \ + ctype* restrict b2 = b_jr; \ +\ + /* Identify the current thrinfo_t node. */ \ + thread_ir = bli_thrinfo_sub_node( thread_jr ); \ +\ + /* Query the number of threads and thread ids for the IR loop. + NOTE: These values are only needed when computing the next + micropanel of A. */ \ + const dim_t ir_nt = bli_thread_n_way( thread_ir ); \ + const dim_t ir_tid = bli_thread_work_id( thread_ir ); \ +\ + /* Compute number of primary and leftover components of the IR loop. */ \ + dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + dim_t ir_left = mc_cur % MR; \ +\ + /* Compute the IR loop thread range for the current thread. */ \ + dim_t ir_start, ir_end; \ + bli_thread_range_sub( thread_ir, ir_iter, 1, FALSE, &ir_start, &ir_end ); \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += 1 ) \ + { \ + const dim_t mr_cur \ + = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + ctype* restrict a2; \ +\ + /* Compute the addresses of the next micropanels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a_ir, ps_a_use, 1 ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_ic_use; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, ps_b_use, 1 ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_pc_use; \ + } \ +\ + /* Save the addresses of next micropanels of A and B to the + auxinfo_t object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + &alpha_local, \ + a_ir, \ + b_jr, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +\ + /* This barrier is needed to prevent threads from starting to pack + the next row panel of B before the current row panel is fully + computed upon. */ \ + bli_thread_barrier( thread_pb ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. */ \ + PASTECH2(bao_,ch,packm_finalize_mem_a) \ + ( \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTECH2(bao_,ch,packm_finalize_mem_b) \ + ( \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var1: a1_packed", mr_cur, kc_cur, a_ir, rs_a_use, cs_a_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var1: b1_packed", kc_cur, nr_cur, b_jr, rs_b_use, cs_b_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var1: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%5.2f", "" ); \ +*/ \ +} + +//INSERT_GENTFUNC_BASIC0( gemmd_bp_var1 ) +GENTFUNC( float, s, gemmd_bp_var1 ) +GENTFUNC( double, d, gemmd_bp_var1 ) +GENTFUNC( scomplex, c, gemmd_bp_var1 ) +GENTFUNC( dcomplex, z, gemmd_bp_var1 ) + diff --git a/addon/gemmd/bao_gemmd_check.c b/addon/gemmd/bao_gemmd_check.c new file mode 100644 index 0000000000..864e9a1acb --- /dev/null +++ b/addon/gemmd/bao_gemmd_check.c @@ -0,0 +1,131 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bao_gemmd_check + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx + ) +{ + err_t e_val; + + // Check object datatypes. + + e_val = bli_check_noninteger_object( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_noninteger_object( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( d ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( c ); + bli_check_error_code( e_val ); + + // Check scalar/vector/matrix type. + + e_val = bli_check_scalar_object( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_scalar_object( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_vector_object( d ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( c ); + bli_check_error_code( e_val ); + + // Check object buffers (for non-NULLness). + + e_val = bli_check_object_buffer( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( d ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( c ); + bli_check_error_code( e_val ); + + // Check object dimensions. + + e_val = bli_check_level3_dims( a, b, c ); + bli_check_error_code( e_val ); + + e_val = bli_check_vector_dim_equals( d, bli_obj_width_after_trans( a ) ); + bli_check_error_code( e_val ); + + // Check for consistent datatypes. + // NOTE: We only perform these tests when mixed datatype support is + // disabled. + + e_val = bli_check_consistent_object_datatypes( c, a ); + bli_check_error_code( e_val ); + + e_val = bli_check_consistent_object_datatypes( c, d ); + bli_check_error_code( e_val ); + + e_val = bli_check_consistent_object_datatypes( c, b ); + bli_check_error_code( e_val ); +} + diff --git a/frame/3/gemm/bli_gemm_int.h b/addon/gemmd/bao_gemmd_check.h similarity index 93% rename from frame/3/gemm/bli_gemm_int.h rename to addon/gemmd/bao_gemmd_check.h index 2bbe5480a6..243ec70c8c 100644 --- a/frame/3/gemm/bli_gemm_int.h +++ b/addon/gemmd/bao_gemmd_check.h @@ -32,16 +32,19 @@ */ -void bli_gemm_int + +// +// Prototype object-based check functions. +// + +void bao_gemmd_check ( obj_t* alpha, obj_t* a, + obj_t* d, obj_t* b, obj_t* beta, obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ); + cntx_t* cntx + ); diff --git a/addon/gemmd/bao_gemmd_var.h b/addon/gemmd/bao_gemmd_var.h new file mode 100644 index 0000000000..05ec45e07e --- /dev/null +++ b/addon/gemmd/bao_gemmd_var.h @@ -0,0 +1,119 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype the object-based variant interfaces. +// + +#undef GENPROT +#define GENPROT( opname ) \ +\ +void PASTECH(bao_,opname) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* d, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ); + +GENPROT( gemmd_bp_var1 ) + + +// +// Prototype the typed variant interfaces. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict d, inc_t incd, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ); + +//INSERT_GENTPROT_BASIC0( gemmd_bp_var1 ) +GENTPROT( float, s, gemmd_bp_var1 ) +GENTPROT( double, d, gemmd_bp_var1 ) +GENTPROT( scomplex, c, gemmd_bp_var1 ) +GENTPROT( dcomplex, z, gemmd_bp_var1 ) + + +// +// Prototype the typed kernel interfaces. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + const dim_t MR, \ + const dim_t NR, \ + dim_t mr_cur, \ + dim_t nr_cur, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict aux, \ + cntx_t* restrict cntx \ + ); + +//INSERT_GENTPROT_BASIC0( gemm_kernel ) +GENTPROT( float, s, gemm_kernel ) +GENTPROT( double, d, gemm_kernel ) +GENTPROT( scomplex, c, gemm_kernel ) +GENTPROT( dcomplex, z, gemm_kernel ) + diff --git a/addon/gemmd/bao_l3_packm_a.c b/addon/gemmd/bao_l3_packm_a.c new file mode 100644 index 0000000000..49bb34664c --- /dev/null +++ b/addon/gemmd/bao_l3_packm_a.c @@ -0,0 +1,330 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Set the pack buffer type so that we are obtaining memory blocks from + the pool dedicated to blocks of A. */ \ + const packbuf_t pack_buf_type = BLIS_BUFFER_FOR_A_BLOCK; \ +\ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + const dim_t m_pack = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + const dim_t k_pack = k; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin the + packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. */ \ + siz_t size_needed = sizeof( ctype ) * m_pack * k_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the packed block allocator. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was passed in. + It needs to be that mem_t struct, and not a local (temporary) + mem_t, since there is no barrier until after packing is finished, + which could allow a race condition whereby the chief thread exits + the current function before the other threads have a chance to + copy from it. (A barrier would fix that race condition, but then + again, I prefer to keep barriers to a minimum.) */ \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t to all + threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the packed + block allocator and cached by the caller. */ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. */ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_mem_a ) +GENTFUNC( float, s, packm_init_mem_a ) +GENTFUNC( double, d, packm_init_mem_a ) +GENTFUNC( scomplex, c, packm_init_mem_a ) +GENTFUNC( dcomplex, z, packm_init_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. */ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_finalize_mem_a ) +GENTFUNC( float, s, packm_finalize_mem_a ) +GENTFUNC( double, d, packm_finalize_mem_a ) +GENTFUNC( scomplex, c, packm_finalize_mem_a ) +GENTFUNC( dcomplex, z, packm_finalize_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ) \ +{ \ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + *m_max = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + *k_max = k; \ +\ + /* Determine the dimensions and strides for the packed matrix A. */ \ + { \ + /* Pack A to column-stored row-panels. */ \ + *rs_p = 1; \ + *cs_p = mr; \ +\ + *pd_p = mr; \ + *ps_p = mr * k; \ +\ + /* Set the schema to "packed row panels" to indicate packing to + conventional column-stored row panels. */ \ + *schema = BLIS_PACKED_ROW_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the memory + associated with the mem_t entry acquired from the memory pool. */ \ + *p = bli_mem_buffer( mem ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_a ) +GENTFUNC( float, s, packm_init_a ) +GENTFUNC( double, d, packm_init_a ) +GENTFUNC( scomplex, c, packm_init_a ) +GENTFUNC( dcomplex, z, packm_init_a ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t m_max; \ + dim_t k_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. */ \ + PASTECH2(bao_,ch,packm_init_mem_a) \ + ( \ + m_alloc, k_alloc, mr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix A. */ \ + PASTECH2(bao_,ch,packm_init_a) \ + ( \ + &schema, \ + m, k, mr, \ + &m_max, &k_max, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + mem \ + ); \ +\ + /* Pack matrix A to the destination buffer chosen above. Here, the packed + matrix is stored to column-stored MR x k micropanels. */ \ + PASTECH2(bao_,ch,packm_var1) \ + ( \ + conj, \ + schema, \ + m, \ + k, \ + m_max, \ + k_max, \ + kappa, \ + d, incd, \ + a, rs_a, cs_a, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ +\ + /* Barrier so that packing is done before computation. */ \ + bli_thread_barrier( thread ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_a ) +GENTFUNC( float, s, packm_a ) +GENTFUNC( double, d, packm_a ) +GENTFUNC( scomplex, c, packm_a ) +GENTFUNC( dcomplex, z, packm_a ) + diff --git a/addon/gemmd/bao_l3_packm_a.h b/addon/gemmd/bao_l3_packm_a.h new file mode 100644 index 0000000000..b683b79d4a --- /dev/null +++ b/addon/gemmd/bao_l3_packm_a.h @@ -0,0 +1,123 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_mem_a ) +GENTPROT( float, s, packm_init_mem_a ) +GENTPROT( double, d, packm_init_mem_a ) +GENTPROT( scomplex, c, packm_init_mem_a ) +GENTPROT( dcomplex, z, packm_init_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_finalize_mem_a ) +GENTPROT( float, s, packm_finalize_mem_a ) +GENTPROT( double, d, packm_finalize_mem_a ) +GENTPROT( scomplex, c, packm_finalize_mem_a ) +GENTPROT( dcomplex, z, packm_finalize_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_a ) +GENTPROT( float, s, packm_init_a ) +GENTPROT( double, d, packm_init_a ) +GENTPROT( scomplex, c, packm_init_a ) +GENTPROT( dcomplex, z, packm_init_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_a ) +GENTPROT( float, s, packm_a ) +GENTPROT( double, d, packm_a ) +GENTPROT( scomplex, c, packm_a ) +GENTPROT( dcomplex, z, packm_a ) + diff --git a/addon/gemmd/bao_l3_packm_b.c b/addon/gemmd/bao_l3_packm_b.c new file mode 100644 index 0000000000..c41b062b6e --- /dev/null +++ b/addon/gemmd/bao_l3_packm_b.c @@ -0,0 +1,330 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Set the pack buffer type so that we are obtaining memory blocks from + the pool dedicated to panels of B. */ \ + const packbuf_t pack_buf_type = BLIS_BUFFER_FOR_B_PANEL; \ +\ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + const dim_t k_pack = k; \ + const dim_t n_pack = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin the + packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. */ \ + siz_t size_needed = sizeof( ctype ) * k_pack * n_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the packed block allocator. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was passed in. + It needs to be that mem_t struct, and not a local (temporary) + mem_t, since there is no barrier until after packing is finished, + which could allow a race condition whereby the chief thread exits + the current function before the other threads have a chance to + copy from it. (A barrier would fix that race condition, but then + again, I prefer to keep barriers to a minimum.) */ \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t to all + threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the packed + block allocator and cached by the caller. */ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. */ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_mem_b ) +GENTFUNC( float, s, packm_init_mem_b ) +GENTFUNC( double, d, packm_init_mem_b ) +GENTFUNC( scomplex, c, packm_init_mem_b ) +GENTFUNC( dcomplex, z, packm_init_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. */ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_finalize_mem_b ) +GENTFUNC( float, s, packm_finalize_mem_b ) +GENTFUNC( double, d, packm_finalize_mem_b ) +GENTFUNC( scomplex, c, packm_finalize_mem_b ) +GENTFUNC( dcomplex, z, packm_finalize_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ) \ +{ \ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + *k_max = k; \ + *n_max = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Determine the dimensions and strides for the packed matrix B. */ \ + { \ + /* Pack B to row-stored column-panels. */ \ + *rs_p = nr; \ + *cs_p = 1; \ +\ + *pd_p = nr; \ + *ps_p = k * nr; \ +\ + /* Set the schema to "packed column panels" to indicate packing to + conventional row-stored column panels. */ \ + *schema = BLIS_PACKED_COL_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the memory + associated with the mem_t entry acquired from the memory pool. */ \ + *p = bli_mem_buffer( mem ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_b ) +GENTFUNC( float, s, packm_init_b ) +GENTFUNC( double, d, packm_init_b ) +GENTFUNC( scomplex, c, packm_init_b ) +GENTFUNC( dcomplex, z, packm_init_b ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t k_max; \ + dim_t n_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. */ \ + PASTECH2(bao_,ch,packm_init_mem_b) \ + ( \ + k_alloc, n_alloc, nr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix B. */ \ + PASTECH2(bao_,ch,packm_init_b) \ + ( \ + &schema, \ + k, n, nr, \ + &k_max, &n_max, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + mem \ + ); \ +\ + /* Pack matrix B to the destination buffer chosen above. Here, the packed + matrix is stored to row-stored k x NR micropanels. */ \ + PASTECH2(bao_,ch,packm_var1) \ + ( \ + conj, \ + schema, \ + k, \ + n, \ + k_max, \ + n_max, \ + kappa, \ + d, incd, \ + b, rs_b, cs_b, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ +\ + /* Barrier so that packing is done before computation. */ \ + bli_thread_barrier( thread ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_b ) +GENTFUNC( float, s, packm_b ) +GENTFUNC( double, d, packm_b ) +GENTFUNC( scomplex, c, packm_b ) +GENTFUNC( dcomplex, z, packm_b ) + diff --git a/addon/gemmd/bao_l3_packm_b.h b/addon/gemmd/bao_l3_packm_b.h new file mode 100644 index 0000000000..9161604ce9 --- /dev/null +++ b/addon/gemmd/bao_l3_packm_b.h @@ -0,0 +1,123 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_mem_b ) +GENTPROT( float, s, packm_init_mem_b ) +GENTPROT( double, d, packm_init_mem_b ) +GENTPROT( scomplex, c, packm_init_mem_b ) +GENTPROT( dcomplex, z, packm_init_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_finalize_mem_b ) +GENTPROT( float, s, packm_finalize_mem_b ) +GENTPROT( double, d, packm_finalize_mem_b ) +GENTPROT( scomplex, c, packm_finalize_mem_b ) +GENTPROT( dcomplex, z, packm_finalize_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_b ) +GENTPROT( float, s, packm_init_b ) +GENTPROT( double, d, packm_init_b ) +GENTPROT( scomplex, c, packm_init_b ) +GENTPROT( dcomplex, z, packm_init_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_b ) +GENTPROT( float, s, packm_b ) +GENTPROT( double, d, packm_b ) +GENTPROT( scomplex, c, packm_b ) +GENTPROT( dcomplex, z, packm_b ) + diff --git a/addon/gemmd/bao_l3_packm_var.h b/addon/gemmd/bao_l3_packm_var.h new file mode 100644 index 0000000000..063e59e5f8 --- /dev/null +++ b/addon/gemmd/bao_l3_packm_var.h @@ -0,0 +1,69 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// +// Prototype BLAS-like interfaces to the variants. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ); + +//INSERT_GENTPROT_BASIC0( packm_var1 ) +GENTPROT( float, s, packm_var1 ) +GENTPROT( double, d, packm_var1 ) +GENTPROT( scomplex, c, packm_var1 ) +GENTPROT( dcomplex, z, packm_var1 ) + +//INSERT_GENTPROT_BASIC0( packm_var2 ) +GENTPROT( float, s, packm_var2 ) +GENTPROT( double, d, packm_var2 ) +GENTPROT( scomplex, c, packm_var2 ) +GENTPROT( dcomplex, z, packm_var2 ) diff --git a/addon/gemmd/bao_l3_packm_var1.c b/addon/gemmd/bao_l3_packm_var1.c new file mode 100644 index 0000000000..24c0a2cc13 --- /dev/null +++ b/addon/gemmd/bao_l3_packm_var1.c @@ -0,0 +1,195 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Variant 1 provides basic support for packing by calling packm_cxk(). +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len; \ + dim_t panel_len_max; \ + dim_t panel_dim; \ + dim_t panel_dim_max; \ + inc_t incc; \ + inc_t ldc; \ + inc_t ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + incc = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. */ \ + iter_dim = m; \ + panel_len = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + incc = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + ctype* restrict p_begin = p_cast; \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t it_start, it_end, it_inc; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_begin = c_cast + (ic )*incc; \ +\ + ctype* restrict c_use = c_begin; \ + ctype* restrict p_use = p_begin; \ +\ + /* The definition of bli_packm_my_iter() will depend on whether slab + or round-robin partitioning was requested at configure-time. (The + default is slab.) */ \ + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + { \ + PASTECH2(bao_,ch,packm_cxk) \ + ( \ + conjc, \ + schema, \ + panel_dim, \ + panel_dim_max, \ + panel_len, \ + panel_len_max, \ + kappa_cast, \ + d, incd, \ + c_use, incc, ldc, \ + p_use, ldp, \ + cntx \ + ); \ + } \ +\ +/* +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ +\ + p_begin += ps_p; \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_var1 ) +GENTFUNC( float, s, packm_var1 ) +GENTFUNC( double, d, packm_var1 ) +GENTFUNC( scomplex, c, packm_var1 ) +GENTFUNC( dcomplex, z, packm_var1 ) + diff --git a/addon/gemmd/bao_l3_packm_var2.c b/addon/gemmd/bao_l3_packm_var2.c new file mode 100644 index 0000000000..830e499b31 --- /dev/null +++ b/addon/gemmd/bao_l3_packm_var2.c @@ -0,0 +1,245 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Variant 2 is similar to variant 1, but inlines the contents of packm_cxk(). +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len; \ + dim_t panel_len_max; \ + dim_t panel_dim; \ + dim_t panel_dim_max; \ + inc_t incc; \ + inc_t ldc; \ + inc_t ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + incc = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. */ \ + iter_dim = m; \ + panel_len = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + incc = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + ctype* restrict p_begin = p_cast; \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t it_start, it_end, it_inc; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_begin = c_cast + (ic )*incc; \ +\ + ctype* restrict c_use = c_begin; \ + ctype* restrict p_use = p_begin; \ +\ + /* The definition of bli_packm_my_iter() will depend on whether slab + or round-robin partitioning was requested at configure-time. (The + default is slab.) */ \ + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + { \ + /* NOTE: We assume here that kappa = 1 and therefore ignore it. If + we're wrong, this will get someone's attention. */ \ + if ( !PASTEMAC(ch,eq1)( *kappa_cast ) ) \ + bli_abort(); \ +\ + /* Perform the packing, taking conjc into account. */ \ + if ( bli_is_conj( conjc ) ) \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t d = 0; d < panel_dim; ++d ) \ + { \ + ctype* cld = c_use + (l )*ldc + (d )*incc; \ + ctype* pld = p_use + (l )*ldp + (d )*1; \ +\ + PASTEMAC(ch,copyjs)( *cld, *pld ); \ + } \ + } \ + } \ + else \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t d = 0; d < panel_dim; ++d ) \ + { \ + ctype* cld = c_use + (l )*ldc + (d )*incc; \ + ctype* pld = p_use + (l )*ldp + (d )*1; \ +\ + PASTEMAC(ch,copys)( *cld, *pld ); \ + } \ + } \ + } \ +\ + /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ + if ( panel_dim < panel_dim_max ) \ + { \ + const dim_t i = panel_dim; \ + const dim_t m_edge = panel_dim_max - panel_dim; \ + const dim_t n_edge = panel_len_max; \ + ctype* restrict p_edge = p_use + (i )*1; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ +\ + /* If panel_len < panel_len_max, then we zero those unused columns. */ \ + if ( panel_len < panel_len_max ) \ + { \ + const dim_t j = panel_len; \ + const dim_t m_edge = panel_dim_max; \ + const dim_t n_edge = panel_len_max - panel_len; \ + ctype* restrict p_edge = p_use + (j )*ldp; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ + } \ +\ +/* +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ +\ + p_begin += ps_p; \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_var1 ) +GENTFUNC( float, s, packm_var2 ) +GENTFUNC( double, d, packm_var2 ) +GENTFUNC( scomplex, c, packm_var2 ) +GENTFUNC( dcomplex, z, packm_var2 ) + diff --git a/frame/1m/packm/bli_packm_cxk_3mis.c b/addon/gemmd/bao_packm_cxk.c similarity index 52% rename from frame/1m/packm/bli_packm_cxk_3mis.c rename to addon/gemmd/bao_packm_cxk.c index 9435f6a736..645f09d798 100644 --- a/frame/1m/packm/bli_packm_cxk_3mis.c +++ b/addon/gemmd/bao_packm_cxk.c @@ -34,19 +34,21 @@ #include "blis.h" -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname ) \ +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +void PASTECH2(bao_,ch,opname) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t panel_dim, \ dim_t panel_dim_max, \ dim_t panel_len, \ dim_t panel_len_max, \ ctype* kappa, \ + ctype* d, inc_t incd, \ ctype* a, inc_t inca, inc_t lda, \ - ctype* p, inc_t is_p, inc_t ldp, \ + ctype* p, inc_t ldp, \ cntx_t* cntx \ ) \ { \ @@ -65,140 +67,133 @@ void PASTEMAC(ch,opname) \ \ /* If there exists a kernel implementation for the micro-panel dimension provided, we invoke the implementation. Otherwise, we use scal2m. */ \ - if ( f != NULL ) \ + /* NOTE: We've disabled calling packm micro-kernels from the context for + this implementation. To re-enable, change FALSE to TRUE in the + conditional below. */ \ + if ( f != NULL && FALSE ) \ { \ f \ ( \ conja, \ + schema, \ panel_dim, \ panel_len, \ panel_len_max, \ kappa, \ a, inca, lda, \ - p, is_p, ldp, \ + p, ldp, \ cntx \ ); \ } \ else \ { \ - /* Treat the micro-panel as panel_dim x panel_len and column-stored - (unit row stride). */ \ + /* NOTE: We assume here that kappa = 1 and therefore ignore it. If + we're wrong, this will get someone's attention. */ \ + if ( !PASTEMAC(ch,eq1)( *kappa ) ) \ + bli_abort(); \ \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - panel_dim, \ - panel_len, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ + if ( d == NULL ) \ + { \ + /* Perform the packing, taking conja into account. */ \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copyjs)( *ali, *pli ); \ + } \ + } \ + } \ + else \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copys)( *ali, *pli ); \ + } \ + } \ + } \ + } \ + else /* if ( d != NULL ) */ \ + { \ + /* Perform the packing, taking conja into account. */ \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* dl = d + (l )*incd; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + /* Note that ali must be the second operand here since + that is what is conjugated by scal2js. */ \ + PASTEMAC(ch,scal2js)( *dl, *ali, *pli ); \ + } \ + } \ + } \ + else \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* dl = d + (l )*incd; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,scal2s)( *ali, *dl, *pli ); \ + } \ + } \ + } \ + } \ \ /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ if ( panel_dim < panel_dim_max ) \ { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = panel_dim; \ - const dim_t m_edge = panel_dim_max - i; \ - const dim_t n_edge = panel_len_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ + const dim_t i = panel_dim; \ + const dim_t m_edge = panel_dim_max - panel_dim; \ + const dim_t n_edge = panel_len_max; \ + ctype* restrict p_edge = p + (i )*1; \ \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ + PASTEMAC(ch,set0s_mxn) \ ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ m_edge, \ n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ + p_edge, 1, ldp \ ); \ } \ \ /* If panel_len < panel_len_max, then we zero those unused columns. */ \ if ( panel_len < panel_len_max ) \ { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = panel_len; \ - const dim_t m_edge = panel_dim_max; \ - const dim_t n_edge = panel_len_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ + const dim_t j = panel_len; \ + const dim_t m_edge = panel_dim_max; \ + const dim_t n_edge = panel_len_max - panel_len; \ + ctype* restrict p_edge = p + (j )*ldp; \ \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ + PASTEMAC(ch,set0s_mxn) \ ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ m_edge, \ n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ + p_edge, 1, ldp \ ); \ } \ } \ } -INSERT_GENTFUNCCO_BASIC0( packm_cxk_3mis ) +//INSERT_GENTFUNC_BASIC0( packm_cxk ) +GENTFUNC( float, s, packm_cxk ) +GENTFUNC( double, d, packm_cxk ) +GENTFUNC( scomplex, c, packm_cxk ) +GENTFUNC( dcomplex, z, packm_cxk ) diff --git a/frame/1m/packm/bli_packm_cxk_4mi.h b/addon/gemmd/bao_packm_cxk.h similarity index 83% rename from frame/1m/packm/bli_packm_cxk_4mi.h rename to addon/gemmd/bao_packm_cxk.h index 244f2d045e..3e977a7cc2 100644 --- a/frame/1m/packm/bli_packm_cxk_4mi.h +++ b/addon/gemmd/bao_packm_cxk.h @@ -33,21 +33,27 @@ */ -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ \ -void PASTEMAC(ch,varname) \ +void PASTECH2(bao_,ch,varname) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t panel_dim, \ dim_t panel_dim_max, \ dim_t panel_len, \ dim_t panel_len_max, \ ctype* kappa, \ + ctype* d, inc_t incd, \ ctype* a, inc_t inca, inc_t lda, \ - ctype* p, inc_t is_p, inc_t ldp, \ + ctype* p, inc_t ldp, \ cntx_t* cntx \ ); -INSERT_GENTPROTCO_BASIC0( packm_cxk_4mi ) +//INSERT_GENTPROT_BASIC0( packm_cxk ) +GENTPROT( float, s, packm_cxk ) +GENTPROT( double, d, packm_cxk ) +GENTPROT( scomplex, c, packm_cxk ) +GENTPROT( dcomplex, z, packm_cxk ) diff --git a/addon/gemmd/gemmd.h b/addon/gemmd/gemmd.h new file mode 100644 index 0000000000..cab61bd181 --- /dev/null +++ b/addon/gemmd/gemmd.h @@ -0,0 +1,54 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of copyright holder(s) nor the names + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef GEMMD_H +#define GEMMD_H + +// This header should contain (or #include) any definitions that must be +// folded into blis.h. + +#include "bao_gemmd.h" +#include "bao_gemmd_check.h" +#include "bao_gemmd_var.h" + +#include "bao_l3_packm_a.h" +#include "bao_l3_packm_b.h" +#include "bao_l3_packm_var.h" + +#include "bao_packm_cxk.h" + +#include "bao_l3_decor.h" + + +#endif diff --git a/addon/gemmd/thread/bao_l3_decor.h b/addon/gemmd/thread/bao_l3_decor.h new file mode 100644 index 0000000000..b4fd2b9b76 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor.h @@ -0,0 +1,75 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_H +#define BLIS_SBX_L3_DECOR_H + +// -- sup definitions ---------------------------------------------------------- + +// Level-3 sup internal function type. +typedef void (*l3sbxint_t) + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +// Level-3 sup thread decorator prototype. +void bao_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +// Include definitions specific to the method of multithreading. +#include "bao_l3_decor_single.h" +#include "bao_l3_decor_openmp.h" +#include "bao_l3_decor_pthreads.h" + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_openmp.c b/addon/gemmd/thread/bao_l3_decor_openmp.c new file mode 100644 index 0000000000..1aca8de275 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_openmp.c @@ -0,0 +1,140 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_OPENMP + +// Define a dummy thread entry function, which is needed in the pthreads +// version, so that when building Windows DLLs (with OpenMP enabled or with +// no multithreading) we don't risk having an unresolved symbol. +void* bao_l3_thread_entry( void* data_void ) { return NULL; } + +//#define PRINT_THRINFO + +void bao_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // Query the total number of threads from the rntm_t object. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_pba_rntm_set_pba( rntm ); + + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Query the thread's id from OpenMP. + const dim_t tid = omp_get_thread_num(); + + // Check for a somewhat obscure OpenMP thread-mistmatch issue. + // NOTE: This calls the same function used for the conventional/large + // code path. + bli_l3_thread_decorator_thread_check( n_threads, tid, gl_comm, rntm_p ); + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + thrinfo_t* thread = NULL; + + // Create the root node of the thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); + + func + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_openmp.h b/addon/gemmd/thread/bao_l3_decor_openmp.h new file mode 100644 index 0000000000..9c956d7c36 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_openmp.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_OPENMP_H +#define BLIS_SBX_L3_DECOR_OPENMP_H + +// Definitions specific to situations when OpenMP multithreading is enabled. +#ifdef BLIS_ENABLE_OPENMP + +#endif + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_pthreads.c b/addon/gemmd/thread/bao_l3_decor_pthreads.c new file mode 100644 index 0000000000..587b8400f1 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_pthreads.c @@ -0,0 +1,220 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_PTHREADS + +// A data structure to assist in passing operands to additional threads. +typedef struct thread_data +{ + l3sbxint_t func; + opid_t family; + obj_t* alpha; + obj_t* a; + obj_t* d; + obj_t* b; + obj_t* beta; + obj_t* c; + cntx_t* cntx; + rntm_t* rntm; + dim_t tid; + thrcomm_t* gl_comm; + array_t* array; +} thread_data_t; + +// Entry point function for additional threads. +void* bao_l3_thread_entry( void* data_void ) +{ + thread_data_t* data = data_void; + + l3sbxint_t func = data->func; + opid_t family = data->family; + obj_t* alpha = data->alpha; + obj_t* a = data->a; + obj_t* d = data->d; + obj_t* b = data->b; + obj_t* beta = data->beta; + obj_t* c = data->c; + cntx_t* cntx = data->cntx; + rntm_t* rntm = data->rntm; + dim_t tid = data->tid; + array_t* array = data->array; + thrcomm_t* gl_comm = data->gl_comm; + + ( void )family; + + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + thrinfo_t* thread = NULL; + + // Create the root node of the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); + + func + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); + + return NULL; +} + +void bao_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t r_val; + + // Query the total number of threads from the context. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_pba_rntm_set_pba( rntm ); + + // Allocate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + // Allocate an array of pthread objects and auxiliary data structs to pass + // to the thread entry functions. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_pthread_t* pthreads = bli_malloc_intl( sizeof( bli_pthread_t ) * n_threads, &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads, &r_val ); + + // NOTE: We must iterate backwards so that the chief thread (thread id 0) + // can spawn all other threads before proceeding with its own computation. + for ( dim_t tid = n_threads - 1; 0 <= tid; tid-- ) + { + // Set up thread data for additional threads (beyond thread 0). + datas[tid].func = func; + datas[tid].family = family; + datas[tid].alpha = alpha; + datas[tid].a = a; + datas[tid].d = d; + datas[tid].b = b; + datas[tid].beta = beta; + datas[tid].c = c; + datas[tid].cntx = cntx; + datas[tid].rntm = rntm; + datas[tid].tid = tid; + datas[tid].gl_comm = gl_comm; + datas[tid].array = array; + + // Spawn additional threads for ids greater than 1. + if ( tid != 0 ) + bli_pthread_create( &pthreads[tid], NULL, &bao_l3_thread_entry, &datas[tid] ); + else + bao_l3_thread_entry( ( void* )(&datas[0]) ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Thread 0 waits for additional threads to finish. + for ( dim_t tid = 1; tid < n_threads; tid++ ) + { + bli_pthread_join( pthreads[tid], NULL ); + } + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( pthreads ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( datas ); +} + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_pthreads.h b/addon/gemmd/thread/bao_l3_decor_pthreads.h new file mode 100644 index 0000000000..69adec45ee --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_pthreads.h @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_PTHREADS_H +#define BLIS_SBX_L3_DECOR_PTHREADS_H + +// Definitions specific to situations when POSIX multithreading is enabled. +#ifdef BLIS_ENABLE_PTHREADS + +// Thread entry point prototype. +void* bao_l3_thread_entry( void* data_void ); + +#endif + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_single.c b/addon/gemmd/thread/bao_l3_decor_single.c new file mode 100644 index 0000000000..d60891d65b --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_single.c @@ -0,0 +1,143 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifndef BLIS_ENABLE_MULTITHREADING + +#define SKIP_THRINFO_TREE + +void bao_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + //pack_t schema_a, + //pack_t schema_b, + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // For sequential execution, we use only one thread. + const dim_t n_threads = 1; + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. + bli_pba_rntm_set_pba( rntm ); + +#ifndef SKIP_THRINFO_TREE + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); +#endif + + + { + // NOTE: We don't need to create another copy of the rntm_t since + // it was already copied in one of the high-level oapi functions. + rntm_t* restrict rntm_p = rntm; + + // There is only one thread id (for the thief thread). + const dim_t tid = 0; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + // NOTE: This is commented out because, in the single-threaded case, + // this is redundant since it's already been done above. + //bli_sba_rntm_set_pool( tid, array, rntm_p ); + +#ifndef SKIP_THRINFO_TREE + thrinfo_t* thread = NULL; + + // Create the root node of the thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); +#else + // This optimization allows us to use one of the global thrinfo_t + // objects for single-threaded execution rather than grow one from + // scratch. The key is that bli_thrinfo_sup_grow(), which is called + // from within the variants, will immediately return if it detects + // that the thrinfo_t* passed into it is either + // &BLIS_GEMM_SINGLE_THREADED or &BLIS_PACKM_SINGLE_THREADED. + thrinfo_t* thread = &BLIS_GEMM_SINGLE_THREADED; + + ( void )tid; +#endif + + func + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + +#ifndef SKIP_THRINFO_TREE + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); +#endif + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called above). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_single.h b/addon/gemmd/thread/bao_l3_decor_single.h new file mode 100644 index 0000000000..211a43a894 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_single.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_SINGLE_H +#define BLIS_SBX_L3_DECOR_SINGLE_H + +// Definitions specific to situations when multithreading is disabled. +#ifndef BLIS_ENABLE_MULTITHREADING + +#endif + +#endif + diff --git a/blastest/Makefile b/blastest/Makefile index 4659fcfee7..b4b40a714b 100644 --- a/blastest/Makefile +++ b/blastest/Makefile @@ -136,7 +136,7 @@ CFLAGS += -Wno-maybe-uninitialized -Wno-parentheses -Wfatal-errors \ -I$(INC_PATH) -DHAVE_BLIS_H # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Override the location of the check-blastest.sh script. #BLASTEST_CHECK := ./check-blastest.sh diff --git a/blastest/f2c/lread.c b/blastest/f2c/lread.c index 74df12b9ba..cdb4ee2a4b 100644 --- a/blastest/f2c/lread.c +++ b/blastest/f2c/lread.c @@ -350,7 +350,8 @@ static int nmL_getc(void) static int nmL_ungetc(int x, FILE *f) { - f = f; /* banish non-use warning */ + /* f = f;*/ /* banish non-use warning */ + ( void )f; return *--nmL_next = x; } diff --git a/blis.pc.in b/blis.pc.in new file mode 100644 index 0000000000..57dbafec45 --- /dev/null +++ b/blis.pc.in @@ -0,0 +1,11 @@ +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ + +Name: BLIS +Description: BLAS-like Library Instantiation Software Framework +Version: @PACKAGE_VERSION@ +Libs: -L${libdir} -lblis +Libs.private: @LDFLAGS@ +Cflags: -I${includedir}/blis diff --git a/build/bli_addon.h.in b/build/bli_addon.h.in new file mode 100644 index 0000000000..36a8e29bd1 --- /dev/null +++ b/build/bli_addon.h.in @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_ADDON_H +#define BLIS_ADDON_H + +#if @enable_addons@ +#define BLIS_ENABLE_ADDONS +#else +#define BLIS_DISABLE_ADDONS +#endif + +// Enabled addons +@addon_list_includes@ + +#endif diff --git a/build/bli_config.h.in b/build/bli_config.h.in index 7debcfa38b..fa6bbbe12e 100644 --- a/build/bli_config.h.in +++ b/build/bli_config.h.in @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -45,6 +45,12 @@ // Enabled kernel sets (kernel_list) @kernel_list_defines@ +#if @enable_system@ +#define BLIS_ENABLE_SYSTEM +#else +#define BLIS_DISABLE_SYSTEM +#endif + #if @enable_openmp@ #define BLIS_ENABLE_OPENMP #endif @@ -135,12 +141,24 @@ #endif #endif +#if @enable_sup_handling@ +#define BLIS_ENABLE_SUP_HANDLING +#else +#define BLIS_DISABLE_SUP_HANDLING +#endif + #if @enable_memkind@ #define BLIS_ENABLE_MEMKIND #else #define BLIS_DISABLE_MEMKIND #endif +#if @enable_trsm_preinversion@ +#define BLIS_ENABLE_TRSM_PREINVERSION +#else +#define BLIS_DISABLE_TRSM_PREINVERSION +#endif + #if @enable_pragma_omp_simd@ #define BLIS_ENABLE_PRAGMA_OMP_SIMD #else @@ -159,23 +177,11 @@ #define BLIS_DISABLE_SHARED #endif -#if !defined(BLIS_ENABLE_SHARED) - #define BLIS_EXPORT +#if @complex_return_intel@ +#define BLIS_ENABLE_COMPLEX_RETURN_INTEL #else - #if defined(_WIN32) || defined(__CYGWIN__) - #ifdef BLIS_IS_BUILDING_LIBRARY - #define BLIS_EXPORT __declspec(dllexport) - #else - #define BLIS_EXPORT __declspec(dllimport) - #endif - #elif defined(__GNUC__) && __GNUC__ >= 4 - #define BLIS_EXPORT __attribute__ ((visibility ("default"))) - #else - #define BLIS_EXPORT - #endif +#define BLIS_DISABLE_COMPLEX_RETURN_INTEL #endif -#define BLIS_EXPORT_BLIS BLIS_EXPORT -#define BLIS_EXPORT_BLAS BLIS_EXPORT #endif diff --git a/build/config.mk.in b/build/config.mk.in index c68601c670..56d6211c24 100644 --- a/build/config.mk.in +++ b/build/config.mk.in @@ -5,6 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2022, Advanced Micro Devices, Inc. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -89,6 +90,16 @@ endif CC_VENDOR := @CC_VENDOR@ CC := @CC@ +# Important C compiler ranges. +GCC_OT_4_9_0 := @gcc_older_than_4_9_0@ +GCC_OT_6_1_0 := @gcc_older_than_6_1_0@ +GCC_OT_9_1_0 := @gcc_older_than_9_1_0@ +GCC_OT_10_1_0 := @gcc_older_than_10_1_0@ +CLANG_OT_9_0_0 := @clang_older_than_9_0_0@ +CLANG_OT_12_0_0 := @clang_older_than_12_0_0@ +AOCC_OT_2_0_0 := @aocc_older_than_2_0_0@ +AOCC_OT_3_0_0 := @aocc_older_than_3_0_0@ + # The C++ compiler. NOTE: A C++ is typically not needed. CXX := @CXX@ @@ -98,6 +109,9 @@ RANLIB := @RANLIB@ # Archiver. AR := @AR@ +# Python Interpreter +PYTHON := @PYTHON@ + # Preset (required) CFLAGS and LDFLAGS. These variables capture the value # of the CFLAGS and LDFLAGS environment variables at configure-time (and/or # the value of CFLAGS/LDFLAGS if either was specified on the command line). @@ -109,19 +123,42 @@ LDFLAGS_PRESET := @ldflags_preset@ # The level of debugging info to generate. DEBUG_TYPE := @debug_type@ +# Whether operating system support was requested via --enable-system. +ENABLE_SYSTEM := @enable_system@ + # The requested threading model. THREADING_MODEL := @threading_model@ # Whether the compiler supports "#pragma omp simd" via the -fopenmp-simd option. PRAGMA_OMP_SIMD := @pragma_omp_simd@ -# The install libdir, includedir, and shareddir values from configure tell -# us where to install the libraries, header files, and public makefile -# fragments, respectively. Notice that we support the use of DESTDIR so that -# advanced users may install to a temporary location. -INSTALL_LIBDIR := $(DESTDIR)@install_libdir@ -INSTALL_INCDIR := $(DESTDIR)@install_incdir@ -INSTALL_SHAREDIR := $(DESTDIR)@install_sharedir@ +# The installation prefix, exec_prefix, libdir, includedir, and shareddir +# values from configure tell us where to install the libraries, header files, +# and public makefile fragments. We must first assign each substituted +# @anchor@ to its own variable. Why? Because the subsitutions may contain +# unevaluated variable expressions. For example, '@libdir@' may be replaced +# with '${exec_prefix}/lib'. By assigning the anchors to variables first, and +# then assigning them to their final INSTALL_* variables, we allow prefix and +# exec_prefix to be used in the definitions of exec_prefix, libdir, +# includedir, and sharedir. +prefix := @prefix@ +exec_prefix := @exec_prefix@ +libdir := @libdir@ +includedir := @includedir@ +sharedir := @sharedir@ + +# Notice that we support the use of DESTDIR so that advanced users may install +# to a temporary location. +INSTALL_LIBDIR := $(DESTDIR)$(libdir) +INSTALL_INCDIR := $(DESTDIR)$(includedir) +INSTALL_SHAREDIR := $(DESTDIR)$(sharedir) + +#$(info prefix = $(prefix) ) +#$(info exec_prefix = $(exec_prefix) ) +#$(info libdir = $(libdir) ) +#$(info includedir = $(includedir) ) +#$(info sharedir = $(sharedir) ) +#$(error .) # Whether to output verbose command-line feedback as the Makefile is # processed. @@ -135,11 +172,18 @@ BUILDING_OOT := @configured_oot@ ARG_MAX_HACK := @enable_arg_max_hack@ # Whether to build the static and shared libraries. -# Note the "MK_" prefix, which helps differentiate these variables from +# NOTE: The "MK_" prefix, which helps differentiate these variables from # their corresonding cpp macros that use the BLIS_ prefix. MK_ENABLE_STATIC := @enable_static@ MK_ENABLE_SHARED := @enable_shared@ +# Whether to use an install_name based on @rpath. +MK_ENABLE_RPATH := @enable_rpath@ + +# Whether to export all symbols within the shared library, even those symbols +# that are considered to be for internal use only. +EXPORT_SHARED := @export_shared@ + # Whether to enable either the BLAS or CBLAS compatibility layers. MK_ENABLE_BLAS := @enable_blas@ MK_ENABLE_CBLAS := @enable_cblas@ @@ -147,13 +191,21 @@ MK_ENABLE_CBLAS := @enable_cblas@ # Whether libblis will depend on libmemkind for certain memory allocations. MK_ENABLE_MEMKIND := @enable_memkind@ +# The names of the addons to include when building BLIS. If empty, no addons +# will be included. +ADDON_LIST := @addon_list@ + # The name of a sandbox defining an alternative gemm implementation. If empty, # no sandbox will be used and the conventional gemm implementation will remain # enabled. SANDBOX := @sandbox@ -# The name of the pthreads library. +# The name of the pthreads library. If --disable-system was given, then this +# variable is set to the empty value. LIBPTHREAD := @libpthread@ +# Whether we should use AMD-customized versions of certain framework files. +ENABLE_AMD_FRAME_TWEAKS := @enable_amd_frame_tweaks@ + # end of ifndef CONFIG_MK_INCLUDED conditional block endif diff --git a/build/detect/config/config_detect.c b/build/detect/config/config_detect.c index 12b93162af..5e29defe15 100644 --- a/build/detect/config/config_detect.c +++ b/build/detect/config/config_detect.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,11 +33,39 @@ */ -#define BLIS_EXPORT_BLIS -#include "bli_system.h" -#include "bli_type_defs.h" -#include "bli_arch.h" -#include "bli_cpuid.h" +// NOTE: This file will likely only ever get compiled as part of the BLIS +// configure script, and therefore BLIS_CONFIGURETIME_CPUID is guaranteed to +// be #defined. However, we preserve the cpp conditional for consistency with +// the other three files mentioned above. +#ifdef BLIS_CONFIGURETIME_CPUID + + // NOTE: If you need to make any changes to this cpp branch, it's probably + // the case that you also need to modify bli_arch.c, bli_cpuid.c, and + // bli_env.c. Don't forget to update these other files as needed! + + // The BLIS_ENABLE_SYSTEM macro must be defined so that the correct cpp + // branch in bli_system.h is processed. (This macro is normally defined in + // bli_config.h.) + #define BLIS_ENABLE_SYSTEM + + // Use C-style static inline functions for any static inline functions that + // happen to be defined by the headers below. (This macro is normally defined + // in bli_config_macro_defs.h.) + #define BLIS_INLINE static + + // Since we're not building a shared library, we can forgo the use of the + // BLIS_EXPORT_BLIS annotations by #defining them to be nothing. (This macro + // is normally defined in bli_config_macro_defs.h.) + #define BLIS_EXPORT_BLIS + + #include "bli_system.h" + #include "bli_type_defs.h" + #include "bli_arch.h" + #include "bli_cpuid.h" + //#include "bli_env.h" +#else + #include "blis.h" +#endif int main( int argc, char** argv ) { diff --git a/build/detect/config/old/cpuid_x86.c b/build/detect/config/old/cpuid_x86.c index 1805d9643a..f4985e3914 100644 --- a/build/detect/config/old/cpuid_x86.c +++ b/build/detect/config/old/cpuid_x86.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2015, The University of Texas at Austin - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/build/flatten-headers.py b/build/flatten-headers.py index 9599278d2a..563725a7e9 100755 --- a/build/flatten-headers.py +++ b/build/flatten-headers.py @@ -244,10 +244,24 @@ def flatten_header( inputfile, header_dirpaths, cursp ): # directive. header_path = get_header_path( header, header_dirpaths ) - # If the header was found, we recurse. Otherwise, we output - # the #include directive with a comment indicating that it - # was skipped. - if header_path: + # First, check if the header is our root header (and if so, ignore it). + # Otherwise, if the header was found, we recurse. Otherwise, we output + # the #include directive with a comment indicating that it as skipped + if header == root_inputfile: + + markl = result.group(1) + markr = result.group(3) + + echov2( "%sthis is the root header '%s'; commenting out / skipping." \ + % ( cursp, header ) ) + + # If the header found is our root header, then we cannot + # recurse into it lest we enter an infinite loop. Output the + # line but make sure it's commented out entirely. + ostring += "%s #include %c%s%c %c" \ + % ( skipstr, markl, header, markr, '\n' ) + + elif header_path: echov2( "%slocated file '%s'; recursing." \ % ( cursp, header_path ) ) @@ -327,6 +341,7 @@ def find_header_dirs( dirpath ): recursive_flag = None verbose_flag = None regex = None +root_inputfile = None def main(): @@ -336,6 +351,7 @@ def main(): global recursive_flag global verbose_flag global regex + global root_inputfile # Obtain the script name. path, script_name = os.path.split(sys.argv[0]) @@ -397,6 +413,10 @@ def main(): temp_dir = args[2] dir_list = args[3] + # Save the filename (basename) part of the input file (or root file) into a + # global variable that we can access later from within flatten_header(). + root_inputfile = os.path.basename( inputfile ) + # Separate the directories into distinct strings. dir_list = dir_list.split() diff --git a/build/gen-make-frags/gen-make-frag.sh b/build/gen-make-frags/gen-make-frag.sh index 4d8cb408d0..e411fa8d95 100755 --- a/build/gen-make-frags/gen-make-frag.sh +++ b/build/gen-make-frags/gen-make-frag.sh @@ -417,8 +417,9 @@ main() # The arguments to this function. They'll get assigned meaningful # values after getopts. - mkfile_frag_tmpl_path="" root_dir="" + frag_dir="" + mkfile_frag_tmpl_path="" suffix_file="" ignore_file="" diff --git a/build/gen-make-frags/ignore_list b/build/gen-make-frags/ignore_list index ccdd18f644..3561710b4f 100644 --- a/build/gen-make-frags/ignore_list +++ b/build/gen-make-frags/ignore_list @@ -5,3 +5,4 @@ other temp tmp test +p10_testsuite \ No newline at end of file diff --git a/build/irun.py b/build/irun.py index d9d1e6b778..429981603c 100755 --- a/build/irun.py +++ b/build/irun.py @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2018, The University of Texas at Austin -# Copyright (C) 2018, Advanced Micro Devices, Inc. +# Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/build/libblis-symbols.def b/build/libblis-symbols.def index e1bfce807e..8d29d73b25 100644 --- a/build/libblis-symbols.def +++ b/build/libblis-symbols.def @@ -1297,17 +1297,16 @@ bli_malloc_user bli_mbool_create bli_mbool_free bli_mbool_init -bli_membrk_acquire_m -bli_membrk_compute_pool_block_sizes -bli_membrk_compute_pool_block_sizes_dt -bli_membrk_finalize -bli_membrk_finalize_pools -bli_membrk_init -bli_membrk_init_pools -bli_membrk_pool_size -bli_membrk_query -bli_membrk_release -bli_membrk_rntm_set_membrk +bli_pba_acquire_m +bli_pba_compute_pool_block_sizes +bli_pba_compute_pool_block_sizes_dt +bli_pba_finalize +bli_pba_finalize_pools +bli_pba_init +bli_pba_init_pools +bli_pba_pool_size +bli_pba_query +bli_pba_release bli_memsys_finalize bli_memsys_init bli_mkherm diff --git a/build/recu-sed.sh b/build/recu-sed.sh new file mode 100755 index 0000000000..e7a1d43db3 --- /dev/null +++ b/build/recu-sed.sh @@ -0,0 +1,488 @@ +#!/bin/bash + +# +# recursive-sed.sh +# +# Field G. Van Zee +# + +print_usage() +{ + # Echo usage info + echo " " + echo " "$script_name + echo " " + echo " Field G. Van Zee" + echo " " + echo " Recusively descend a directory tree and perform sed commands, either on" + echo " the filename or the file contents, or both." + echo " " + echo " Usage:" + echo " ${script_name} [options]" + echo " " + echo " The following options are accepted:" + echo " " + echo " -d " + echo " Dry run. Go through all the motions, but don't actually" + echo " apply any of the sed expressions to file names or contents." + echo " -N " + echo " Do not proceed recursively into subdirectories; consider" + echo " only the files within the current directory. Default" + echo " behavior is to act recursively." + echo " -h " + echo " Consider hidden files and directories. Default behavior is" + echo " to ignore them." + echo " -n " + echo " Use svn mv instead of mv when renaming the file." + echo " Notice that this only applies if the filename changes." + echo " -p pattern " + echo " Specifies the filename pattern, as would be given to the" + echo " ls utility, to limit which files are affected. Default is" + echo " the to consider all files present." + echo " -r dir" + echo " The root directory for the recursive action to be performed." + echo " Default is to use the current working directory." + echo " -v [0|1|2]" + echo " verboseness level" + echo " level 0: silent (no output)" + echo " level 1: default (one line per directory; supress ls stderr)" + echo " level 2: verbose (one line per directory; show ls stderr)" + echo " " + echo " At least one of the following option-argument pairs is required:" + echo " " + echo " -f sed_expr " + echo " Specifies the sed expression that will be applied to the" + echo " filenames of the files touched by the script. This expression" + echo " must be a search-and-replace pattern." + echo " -c sed_expr " + echo " Specifies the sed expression that will be applied to the" + echo " contents of the files touched by the script. This expression" + echo " should be a search-and-replace pattern." + echo " -s sed_script" + echo " Specifies an arbitrary sed script that will be applied to the" + echo " file contents of the files touched by the script." + echo " " + echo " Note: -c and -s options are mutually exclusive." + echo " " + + # Exit with non-zero exit status + exit 1 +} + + + + +perform_sed() +{ + # Variables set by getopts. + local exist_dir="$1" + + #echo "exist_dir: $exist_dir" + + # The suffix used to create temporary files + local temp_file_suffix="sed_temp" + + # Check that exist_dir actually exists and is a directory + if [ ! -d "${exist_dir}" ]; then + echo "${script_name}: ${exist_dir} does not seem to be a valid directory." + exit 1 + fi + + # Check that the filename sed expression, if given, begins with an 's'. + if [ -n "$filename_sed_expr" ]; then + + # If it's a valid search-and-replace expression, this should return an 's'. + filename_sed_char=${filename_sed_expr%%/*} + + if [ "$filename_sed_char" != "s" ]; then + echo "${script_name}: sed expression given with -f must be search-and-replace." + exit 1 + fi + fi + + # Check that the sed script, if given, exists. + if [ -n "$contents_sed_script" ]; then + + if [ ! -f ${contents_sed_script} ]; then + echo "${script_name}: ${contents_sed_script} is not a regular file or does not exist." + exit 1 + fi + fi + + # Assume that the sed expression is a search-and-replace. Extract the patterns + # to match on. (Arbitrary sed expressions should be applied through a sed script.) + if [ "$filename_sed_expr" != "" ]; then + filename_sed_match=${filename_sed_expr#s/} + filename_sed_match=${filename_sed_match%%/*} + fi + + + # Get the list of source files in the directory given. Supress stderr if + # level 0 or 1 verbosity was requested. + #if [ "$verbose_level" != "2" ]; then + # old_filepaths=$(ls -d -b ${exist_dir}/${filename_pattern} 2> /dev/null) + #else + # old_filepaths="$(ls -d -b ${exist_dir}/${filename_pattern})" + #fi + + #echo $old_filepaths + #echo "$exist_dir/$filename_pattern" + + #for old_filepath in $old_filepaths; do + #echo "exist_dir: $exist_dir" + + # Find all files that match the pattern in the current directory. + find "${exist_dir}" -maxdepth 1 -name "${filename_pattern}" -print | while read old_filepath + do + #echo "old_filepath: $old_filepath" + + # Skip the current directory. + if [ "${old_filepath}" == "${exist_dir}" ]; then + continue + fi + + # Skip any non-regular files. + if [ ! -f "$old_filepath" ]; then + + # And say we are doing so if verboseness was requested. + if [ "$verbose_level" = "2" ]; then + echo "${script_name}: Ignoring $old_filepath" + fi + continue + fi + + # Strip exist_dir from filename. + old_filename=${old_filepath##*/} + + # Strip the filename from old_filepath to leave the directory path. + old_dirpath=${old_filepath%/*} + + # Create a new filename from the old one. If a filename sed expression was given, + # it will be applied now. + if [ "$filename_sed_expr" != "" ]; then + new_filename=$(echo "${old_filename}" | sed "${filename_sed_expr}") + else + new_filename="${old_filename}" + fi + + #echo "new_filename: $new_filename" + + # Create the filepath to the new file location. + new_filepath="${old_dirpath}/${new_filename}" + #echo "new_filepath: $new_filepath" + + # Grep for the filename pattern within the filename of the current file. + if [ "$filename_sed_expr" != "" ]; then + grep_filename=$(echo "${old_filename}" | grep "${filename_sed_match}") + fi + + + # If we are not performing a dry run, proceed. + if [ -z "$dry_run_flag" ]; then + + # Save the old file permissions so we can re-apply them to the + # new file if its contents change (ie: if it's not just a 'mv', + # which inherently preserves file permissions). + old_perms=$(stat -c %a "${old_filepath}") + + # If the old and new filepaths are different, then we start off by + # renaming the file. (Otherwise, if the old and new filepaths are + # identical, then we don't need to do anything to the file.) If + # the user requested that we use svn mv, then do that, otherwise we + # use regular mv. + if [ "${old_filepath}" != "${new_filepath}" ]; then + + if [ -n "$use_svn_mv_flag" ]; then + + svn mv "${old_filepath}" "${new_filepath}" + else + + mv -f "${old_filepath}" "${new_filepath}" + fi + fi + #else + + # A dry run still needs the act upon the "new" file, so if the + # filepaths are different, simply set the new filepath to the + # old one. (We won't need the previous value of new_filepath + # anymore.) + #if [ "${old_filepath}" != "${new_filepath}" ]; then + # new_filepath="${old_filepath}" + #fi + fi + + # Handle the cases that might change the contents of the file. + if [ "$contents_sed_expr" != "" ] || + [ "$contents_sed_script" != "" ]; then + + # Execute the sed command based on whether the sed action was given + # as a command line expression or a script residing in a file. + if [ "$contents_sed_script" != "" ]; then + + # Perform the action, saving the result to a temporary file. + cat "${new_filepath}" | sed -f ${contents_sed_script} \ + > ${new_filepath}.${temp_file_suffix} + + elif [ "$contents_sed_expr" != "" ]; then + + # Perform the action, saving the result to a temporary file. + cat "${new_filepath}" | sed -e "${contents_sed_expr}" \ + > ${new_filepath}.${temp_file_suffix} + fi + + # Check the difference. + file_diff=$(diff "${new_filepath}" "${new_filepath}.${temp_file_suffix}") + + + # If we are not performing a dry run, proceed. + if [ -z "$dry_run_flag" ]; then + + # If the file contents change. + if [ -n "$file_diff" ]; then + + # Apply the old file permissions to the new file (before we + # potentially overwrite the old file with the new one). + chmod ${old_perms} "${new_filepath}.${temp_file_suffix}" + + # Apply the file contents changes to the new filepath (which may + # or may not be the same as the old filepath). + mv -f "${new_filepath}.${temp_file_suffix}" "${new_filepath}" + + else + # Otherwise remove the new temporary file since it is identical + # to the original. + rm -f "${new_filepath}.${temp_file_suffix}" + fi + else + # Simply remove the file since we are only performing a dry run. + rm -f "${new_filepath}.${temp_file_suffix}" + fi + + fi + + # Check for dos2unix. If it's not here, we'll just substitute cat. + #type_dos2unix=$(type -path dos2unix) + #if [ -n "$type_dos2unix" ]; then + # dos2unix -q ${new_filepath} + #fi + + # Create a string that indicates what we are changing. We'll use this in + # the verbose progress echo to indicate how the file is or would be changed. + if [ -n "$grep_filename" ] && [ -n "$file_diff" ]; then + which_matches="filename/contents" + file_touched="yes" + elif [ -n "$grep_filename" ] && [ -z "$file_diff" ]; then + which_matches="filename " + file_touched="yes" + elif [ -z "$grep_filename" ] && [ -n "$file_diff" ]; then + which_matches=" contents" + file_touched="yes" + else + which_matches="" + file_touched="no" + fi + + # Be verbose, if requested, about which file we're looking at. + if [ "$verbose_level" != "0" ]; then + + # But we only need to output a line if the file was touched. + if [ "$file_touched" != "no" ]; then + + # Construct a relative filepath by stripping the initial root + # directory so that the output does not span as many columns on + # the terminal. + rel_old_filepath=${old_filepath#${initial_root_dir}/} + + # Add a "dry run" condition to the output if we're doing a dry-run + # so that the user knows we didn't really change anything. + if [ -z "$dry_run_flag" ]; then + echo "$script_name: Changing [${which_matches}] of ${rel_old_filepath}" + else + echo "$script_name: Changing (dry run) [${which_matches}] of ${rel_old_filepath}" + fi + fi + fi + + done + + # Exit peacefully. + return 0 +} + + + + +recursive_sed() +{ + # Local variable declarations + local item sub_items curr_dir this_dir + + + # Extract our argument + curr_dir="$1" + + + # Call our function to perform the sed operations on the files in the + # directory given. + perform_sed "${curr_dir}" + + + # If we were asked to act recursively, then continue processing + # curr_dir's contents. + if [ "$recursive_flag" = "1" ]; then + + # Get a listing of items in the directory according to the hidden + # files/directories flag. + if [ -n "$hidden_files_dirs_flag" ]; then + + # Get a listing of the directories in curr_dir (including hidden + # files and directories). + sub_items=$(ls -a "$curr_dir") + + else + + # Get a listing of the directories in curr_dir. + sub_items=$(ls "$curr_dir") + fi + + #echo "sub_items: $sub_items" + + # Descend into the contents of curr_dir, calling recursive_sed on + # any items that are directories. + find "${curr_dir}" -maxdepth 1 -name "*" -print | while read item + do + + #echo "conisdering item: $item" + + # Skip the current directory. + if [ "${item}" == "${curr_dir}" ]; then + continue + fi + + # If item is a directory, descend into it. + if [ -d "$item" ]; then + + #echo "item is dir: $item" + + recursive_sed "$item" + fi + done + + fi + + + # Return peacefully + return 0 +} + + + + +main() +{ + # Variables set by getopts. + dry_run_flag="" + hidden_files_dirs_flag="" + use_svn_mv_flag="" + filename_pattern="" + root_dir="" + initial_root_dir="" + verbose_level="" + filename_sed_expr="" + contents_sed_expr="" + contents_sed_script="" + + recursive_flag="1" + + + # Get the script name + script_name=${0##*/} + + + # Local variable declarations. + local item sub_items this_dir + + + # Process our command line options. + while getopts ":c:df:hp:r:s:nNv:" opt; do + case $opt in + d ) dry_run_flag="1" ;; + h ) hidden_files_dirs_flag="1" ;; + n ) use_svn_mv_flag="1" ;; + N ) recursive_flag="0" ;; + v ) verbose_level="$OPTARG" ;; + p ) filename_pattern="$OPTARG" ;; + r ) root_dir="$OPTARG" ;; + f ) filename_sed_expr="$OPTARG" ;; + c ) contents_sed_expr="$OPTARG" ;; + s ) contents_sed_script="$OPTARG" ;; + \? ) print_usage + esac + done + shift $(($OPTIND - 1)) + + + # Make sure we've parsed all command line arguments by now. + if [ $# != "0" ]; then + echo "${script_name}: Unparsed command line arguments! Try running with no arguments for help." + exit 1 + fi + + + # Make sure we received at least one of the required options. + if [ -z "$filename_sed_expr" ] && + [ -z "$contents_sed_expr" ] && + [ -z "$contents_sed_script" ]; then + print_usage + fi + + + # Make sure that both a file contents sed expression and sed script were + # not given. + if [ "$contents_sed_expr" != "" ] && + [ "$contents_sed_script" != "" ] ; then + echo "${script_name}: The -c and -s options may not be used at the same time." + exit 1 + fi + + + # Make sure that verboseness level is valid. + if [ "$verbose_level" != "0" ] && + [ "$verbose_level" != "1" ] && + [ "$verbose_level" != "2" ]; then + verbose_level="1" + fi + + # Prepare the filename pattern arguments to perform_sed(). + if [ "$filename_pattern" = "" ] ; then + filename_pattern='*' + fi + + # Prepare the directory arguments to perform_sed(). + if [ "$root_dir" != "" ] ; then + + # Strip / from end of directory paths, if there is one. + root_dir=${root_dir%/} + else + root_dir=$PWD + fi + initial_root_dir=${root_dir} + + + #echo "root_dir: $root_dir" + + + # Begin recursing on the root directory. + recursive_sed "$root_dir" + + + # Exit peacefully + return 0 +} + + + + +# The script's main entry point, passing all parameters given. +main "$@" + diff --git a/build/templates/license.c b/build/templates/license.c index bc0abc656f..6505a70ffd 100644 --- a/build/templates/license.c +++ b/build/templates/license.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2019, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/build/templates/license.h b/build/templates/license.h index bc0abc656f..6505a70ffd 100644 --- a/build/templates/license.h +++ b/build/templates/license.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2019, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/build/templates/license.sh b/build/templates/license.sh index ad5965c79b..b9c51e2892 100644 --- a/build/templates/license.sh +++ b/build/templates/license.sh @@ -5,6 +5,7 @@ # libraries. # # Copyright (C) 2019, The University of Texas at Austin +# Copyright (C) 2018, Advanced Micro Devices, Inc. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/common.mk b/common.mk index 75e563878f..5f2d30c9bf 100644 --- a/common.mk +++ b/common.mk @@ -1,6 +1,6 @@ # # -# BLIS +# BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # @@ -118,7 +118,8 @@ get-noopt-cxxflags-for = $(strip $(CFLAGS_PRESET) \ get-refinit-cflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ -DBLIS_CNAME=$(1) \ - $(BUILD_FLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) get-refkern-cflags-for = $(strip $(call load-var-for,CROPTFLAGS,$(1)) \ @@ -126,46 +127,72 @@ get-refkern-cflags-for = $(strip $(call load-var-for,CROPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ $(COMPSIMDFLAGS) \ -DBLIS_CNAME=$(1) \ - $(BUILD_FLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) get-config-cflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ - $(BUILD_FLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) get-frame-cflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ - $(BUILD_FLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) get-kernel-cflags-for = $(strip $(call load-var-for,CKOPTFLAGS,$(1)) \ $(call load-var-for,CKVECFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ - $(BUILD_FLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) # When compiling sandboxes, we use flags similar to those of general framework # source. This ensures that the same code can be linked and run across various -# sub-configurations. (If we switch to using refkern/kernel flags, we should -# prevent enabling sandboxes for umbrella families by verifying that -# config_list == config_name if --enable-sandbox is given.) +# sub-configurations. +get-addon-c99flags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ + $(call get-noopt-cflags-for,$(1)) \ + $(CADDONINCFLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ + ) +get-addon-cxxflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ + $(call get-noopt-cxxflags-for,$(1)) \ + $(CADDONINCFLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ + ) + +# When compiling sandboxes, we use flags similar to those of general framework +# source. This ensures that the same code can be linked and run across various +# sub-configurations. (NOTE: If we ever switch to using refkernel or kernel +# flags, we should prevent enabling sandboxes for umbrella families by verifying +# that config_list == config_name if --enable-sandbox is given. THIS ALSO +# APPLIES TO ADDONS ABOVE.) get-sandbox-c99flags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ - $(CSBOXINCFLAGS) \ - $(BUILD_FLAGS) \ + $(CSANDINCFLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) get-sandbox-cxxflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cxxflags-for,$(1)) \ - $(CSBOXINCFLAGS) \ - $(BUILD_FLAGS) \ + $(CSANDINCFLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) # Define a separate function that will return appropriate flags for use by # applications that want to use the same basic flags as those used when BLIS -# was compiled. (This is the same as get-frame-cflags-for(), except that it -# omits the BUILD_FLAGS, which are exclusively for use when BLIS is being -# compiled.) +# was compiled. (NOTE: This is the same as the $(get-frame-cflags-for ...) +# function, except that it omits two variables that contain flags exclusively +# for use when BLIS is being compiled/built: BUILD_CPPFLAGS, which contains a +# cpp macro that confirms that BLIS is being built; and BUILD_SYMFLAGS, which +# contains symbol export flags that are only needed when a shared library is +# being compiled/linked.) get-user-cflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ ) @@ -178,6 +205,8 @@ get-refkern-text-for = "('$(1)' CFLAGS for ref. kernels)" get-config-text-for = "('$(1)' CFLAGS for config code)" get-frame-text-for = "('$(1)' CFLAGS for framework code)" get-kernel-text-for = "('$(1)' CFLAGS for kernels)" +get-addon-c99text-for = "('$(1)' CFLAGS for addons)" +get-addon-cxxtext-for = "('$(1)' CXXFLAGS for addons)" get-sandbox-c99text-for = "('$(1)' CFLAGS for sandboxes)" get-sandbox-cxxtext-for = "('$(1)' CXXFLAGS for sandboxes)" @@ -192,6 +221,9 @@ get-sandbox-cxxtext-for = "('$(1)' CXXFLAGS for sandboxes)" files-that-contain = $(strip $(foreach f, $(1), $(if $(findstring $(2),$(f)),$(f),))) files-that-dont-contain = $(strip $(foreach f, $(1), $(if $(findstring $(2),$(f)),,$(f)))) +# Define a function that removes duplicate strings *without* using the sort +# function. +rm-dups = $(if $1,$(firstword $1) $(call rm-dups,$(filter-out $(firstword $1),$1))) # @@ -277,6 +309,7 @@ CONFIG_DIR := config FRAME_DIR := frame REFKERN_DIR := ref_kernels KERNELS_DIR := kernels +ADDON_DIR := addon SANDBOX_DIR := sandbox OBJ_DIR := obj LIB_DIR := lib @@ -284,16 +317,22 @@ INCLUDE_DIR := include BLASTEST_DIR := blastest TESTSUITE_DIR := testsuite +VEND_DIR := vendor +VEND_CPP_DIR := $(VEND_DIR)/cpp +VEND_TESTCPP_DIR := $(VEND_DIR)/testcpp + # The filename suffix for reference kernels. REFNM := ref # Source suffixes. CONFIG_SRC_SUFS := c - KERNELS_SRC_SUFS := c s S - FRAME_SRC_SUFS := c +ADDON_C99_SUFS := c +ADDON_CXX_SUFS := cc cpp cxx +ADDON_SRC_SUFS := $(ADDON_C99_SUFS) $(ADDON_CXX_SUFS) + SANDBOX_C99_SUFS := c SANDBOX_CXX_SUFS := cc cpp cxx SANDBOX_SRC_SUFS := $(SANDBOX_C99_SUFS) $(SANDBOX_CXX_SUFS) @@ -301,15 +340,21 @@ SANDBOX_SRC_SUFS := $(SANDBOX_C99_SUFS) $(SANDBOX_CXX_SUFS) # Header suffixes. FRAME_HDR_SUFS := h +ADDON_H99_SUFS := h +ADDON_HXX_SUFS := hh hpp hxx +ADDON_HDR_SUFS := $(ADDON_H99_SUFS) $(ADDON_HXX_SUFS) + SANDBOX_H99_SUFS := h SANDBOX_HXX_SUFS := hh hpp hxx SANDBOX_HDR_SUFS := $(SANDBOX_H99_SUFS) $(SANDBOX_HXX_SUFS) # Combine all header suffixes and remove duplicates via sort(). ALL_HDR_SUFS := $(sort $(FRAME_HDR_SUFS) \ + $(ADDON_HDR_SUFS) \ $(SANDBOX_HDR_SUFS) ) ALL_H99_SUFS := $(sort $(FRAME_HDR_SUFS) \ + $(ADDON_HDR_SUFS) \ $(SANDBOX_H99_SUFS) ) # The names of scripts that check output from the BLAS test drivers and @@ -336,13 +381,19 @@ SHELL := bash # Construct paths to the four primary directories of source code: # the config directory, general framework code, reference kernel code, -# and optimized kernel code. +# and optimized kernel code. Also process paths for addon and sandbox +# directories. CONFIG_PATH := $(DIST_PATH)/$(CONFIG_DIR) FRAME_PATH := $(DIST_PATH)/$(FRAME_DIR) REFKERN_PATH := $(DIST_PATH)/$(REFKERN_DIR) KERNELS_PATH := $(DIST_PATH)/$(KERNELS_DIR) +ADDON_PATH := $(DIST_PATH)/$(ADDON_DIR) SANDBOX_PATH := $(DIST_PATH)/$(SANDBOX_DIR) +# Construct paths to some optional C++ template headers contributed by AMD. +VEND_CPP_PATH := $(DIST_PATH)/$(VEND_CPP_DIR) +VEND_TESTCPP_PATH := $(DIST_PATH)/$(VEND_TESTCPP_DIR) + # Construct paths to the makefile fragments for the four primary directories # of source code: the config directory, general framework code, reference # kernel code, and optimized kernel code. @@ -350,6 +401,7 @@ CONFIG_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(CONFIG_DIR) FRAME_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(FRAME_DIR) REFKERN_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(REFKERN_DIR) KERNELS_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(KERNELS_DIR) +ADDON_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(ADDON_DIR) SANDBOX_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(SANDBOX_DIR) @@ -477,7 +529,8 @@ LIBMEMKIND := -lmemkind # Default linker flags. # NOTE: -lpthread is needed unconditionally because BLIS uses pthread_once() -# to initialize itself in a thread-safe manner. +# to initialize itself in a thread-safe manner. The one exception to this +# rule: if --disable-system is given at configure-time, LIBPTHREAD is empty. LDFLAGS := $(LDFLAGS_PRESET) $(LIBM) $(LIBPTHREAD) # Add libmemkind to the link-time flags, if it was enabled at configure-time. @@ -500,7 +553,11 @@ endif ifeq ($(OS_NAME),Darwin) # OS X shared library link flags. SOFLAGS := -dynamiclib -SOFLAGS += -Wl,-install_name,$(LIBBLIS_SONAME) +ifeq ($(MK_ENABLE_RPATH),yes) +SOFLAGS += -Wl,-install_name,@rpath/$(LIBBLIS_SONAME) +else +SOFLAGS += -Wl,-install_name,$(libdir)/$(LIBBLIS_SONAME) +endif else SOFLAGS := -shared ifeq ($(IS_WIN),yes) @@ -527,9 +584,24 @@ LIBBLIS_L := $(LIBBLIS_SO) LIBBLIS_LINK := $(LIBBLIS_SO_PATH) ifeq ($(IS_WIN),no) # For Linux and OS X: set rpath property of shared object. -LDFLAGS += -Wl,-rpath,$(BASE_LIB_PATH) +ifeq ($(OS_NAME),Darwin) +# rpath for test_libblis.x +LDFLAGS += -Wl,-rpath,@executable_path/$(BASE_LIB_PATH) +# rpath for BLAS tests +LDFLAGS += -Wl,-rpath,@executable_path/../../../$(BASE_LIB_PATH) +else +# rpath for test_libblis.x +LDFLAGS += -Wl,-rpath,'$$ORIGIN/$(BASE_LIB_PATH)' +# rpath for BLAS tests +LDFLAGS += -Wl,-rpath,'$$ORIGIN/../../../$(BASE_LIB_PATH)' +endif endif endif +# On windows, use the shared library even if static is created. +ifeq ($(IS_WIN),yes) +LIBBLIS_L := $(LIBBLIS_SO) +LIBBLIS_LINK := $(LIBBLIS_SO_PATH) +endif endif @@ -603,12 +675,12 @@ endif # Disable tautological comparision warnings in clang. ifeq ($(CC_VENDOR),clang) -CWARNFLAGS += -Wno-tautological-compare +CWARNFLAGS += -Wno-tautological-compare -Wno-pass-failed endif $(foreach c, $(CONFIG_LIST_FAM), $(eval $(call append-var-for,CWARNFLAGS,$(c)))) -# --- Shared library (position-independent code) flags --- +# --- Position-independent code flags (shared libraries only) --- # Emit position-independent code for dynamic linking. ifeq ($(IS_WIN),yes) @@ -620,6 +692,71 @@ CPICFLAGS := -fPIC endif $(foreach c, $(CONFIG_LIST_FAM), $(eval $(call append-var-for,CPICFLAGS,$(c)))) +# --- Symbol exporting flags (shared libraries only) --- + +# NOTE: These flags are only applied when building BLIS and not used by +# applications that import BLIS compilation flags via the +# $(get-user-cflags-for ...) function. + +# Determine default export behavior / visibility of symbols for gcc. +ifeq ($(CC_VENDOR),gcc) +ifeq ($(IS_WIN),yes) +ifeq ($(EXPORT_SHARED),all) +BUILD_SYMFLAGS := -Wl,--export-all-symbols, -Wl,--enable-auto-import +else # ifeq ($(EXPORT_SHARED),public) +BUILD_SYMFLAGS := -Wl,--exclude-all-symbols +endif +else # ifeq ($(IS_WIN),no) +ifeq ($(EXPORT_SHARED),all) +# Export all symbols by default. +BUILD_SYMFLAGS := -fvisibility=default +else # ifeq ($(EXPORT_SHARED),public) +# Hide all symbols by default and export only those that have been annotated +# as needing to be exported. +BUILD_SYMFLAGS := -fvisibility=hidden +endif +endif +endif + +# Determine default export behavior / visibility of symbols for icc. +# NOTE: The Windows branches have been omitted since we currently make no +# effort to support Windows builds via icc (only gcc/clang via AppVeyor). +ifeq ($(CC_VENDOR),icc) +ifeq ($(EXPORT_SHARED),all) +# Export all symbols by default. +BUILD_SYMFLAGS := -fvisibility=default +else # ifeq ($(EXPORT_SHARED),public) +# Hide all symbols by default and export only those that have been annotated +# as needing to be exported. +BUILD_SYMFLAGS := -fvisibility=hidden +endif +endif + +# Determine default export behavior / visibility of symbols for clang. +ifeq ($(CC_VENDOR),clang) +ifeq ($(IS_WIN),yes) +ifeq ($(EXPORT_SHARED),all) +# NOTE: clang on Windows does not appear to support exporting all symbols +# by default, and therefore we ignore the value of EXPORT_SHARED. +BUILD_SYMFLAGS := +else # ifeq ($(EXPORT_SHARED),public) +# NOTE: The default behavior of clang on Windows is to hide all symbols +# and only export functions and other declarations that have beenannotated +# as needing to be exported. +BUILD_SYMFLAGS := +endif +else # ifeq ($(IS_WIN),no) +ifeq ($(EXPORT_SHARED),all) +# Export all symbols by default. +BUILD_SYMFLAGS := -fvisibility=default +else # ifeq ($(EXPORT_SHARED),public) +# Hide all symbols by default and export only those that have been annotated +# as needing to be exported. +BUILD_SYMFLAGS := -fvisibility=hidden +endif +endif +endif + # --- Language flags --- # Enable C99. @@ -638,6 +775,10 @@ $(foreach c, $(CONFIG_LIST_FAM), $(eval $(call append-var-for,CPPROCFLAGS,$(c))) # --- Threading flags --- +# NOTE: We don't have to explicitly omit -pthread when --disable-system is given +# since that option forces --enable-threading=none, and thus -pthread never gets +# added to begin with. + ifeq ($(CC_VENDOR),gcc) ifeq ($(THREADING_MODEL),auto) THREADING_MODEL := openmp @@ -683,8 +824,18 @@ endif # --- #pragma omp simd flags (used for reference kernels only) --- ifeq ($(PRAGMA_OMP_SIMD),yes) +ifeq ($(CC_VENDOR),gcc) +COMPSIMDFLAGS := -fopenmp-simd +else +ifeq ($(CC_VENDOR),clang) COMPSIMDFLAGS := -fopenmp-simd else +ifeq ($(CC_VENDOR),icc) +COMPSIMDFLAGS := -qopenmp-simd +endif +endif +endif +else # ifeq ($(PRAGMA_OMP_SIMD),no) COMPSIMDFLAGS := endif @@ -720,9 +871,6 @@ endif # --- LDFLAGS cleanup ---------------------------------------------------------- # -# Remove duplicate flags/options in LDFLAGS (such as -lpthread) by sorting. -LDFLAGS := $(sort $(LDFLAGS)) - # @@ -742,6 +890,7 @@ MK_CONFIG_SRC := MK_KERNELS_SRC := MK_REFKERN_SRC := MK_FRAME_SRC := +MK_ADDON_SRC := MK_SANDBOX_SRC := # -- config -- @@ -792,6 +941,24 @@ PARENT_PATH := $(OBJ_DIR)/$(CONFIG_NAME) -include $(addsuffix /$(FRAGMENT_MK), $(REFKERN_FRAG_PATH)) -include $(addsuffix /$(FRAGMENT_MK), $(FRAME_FRAG_PATH)) +# -- addon -- + +# Construct paths to each addon. +# NOTE: If $(ADDON_LIST) is empty (because no addon was enabled at configure- +# time) then $(ADDON_PATHS) will also be empty, which will cause no fragments +# to be included. +ADDON_PATHS := $(addprefix $(ADDON_FRAG_PATH)/, $(ADDON_LIST)) + +# This variable is used by the include statements as they recursively include +# one another. For the 'addons' directory, we initialize it to that directory +# in preparation to include the fragments in the configuration sub-directory. +PARENT_SRC_PATH := $(ADDON_PATH) +PARENT_PATH := $(ADDON_FRAG_PATH) + +# Recursively include the makefile fragments in each of the addons sub- +# directories. +-include $(addsuffix /$(FRAGMENT_MK), $(ADDON_PATHS)) + # -- sandbox -- # Construct paths to each sandbox. (At present, there can be only one.) @@ -809,6 +976,8 @@ PARENT_PATH := $(SANDBOX_FRAG_PATH) # Recursively include the makefile fragments in the sandbox sub-directory. -include $(addsuffix /$(FRAGMENT_MK), $(SANDBOX_PATHS)) +# -- post-processing -- + # Create a list of the makefile fragments using the variable into which each # of the above include statements accumulated their directory paths. MAKEFILE_FRAGMENTS := $(addsuffix /$(FRAGMENT_MK), $(FRAGMENT_DIR_PATHS)) @@ -827,14 +996,14 @@ endif # # Define a function that will expand all of the directory paths given in $(1) -# to actual filepaths using the list of suffixes provided $(2). +# to actual filepaths using the list of suffixes provided in $(2). get-filepaths = $(strip $(foreach path, $(1), \ $(foreach suf, $(2), \ $(wildcard $(path)/*.$(suf)) \ ) ) ) # Define a function that will expand all of the directory paths given in $(1) -# to actual filepaths using the list of suffixes provided $(2), taking only +# to actual filepaths using the list of suffixes provided in $(2), taking only # the first expansion from each directory with at least one file matching # the current suffix. Finally, strip the filenames from all resulting files, # returning only the directory paths. @@ -844,20 +1013,29 @@ get-dirpaths = $(dir $(foreach path, $(1), \ $(wildcard $(path)/*.$(suf)) \ ) ) ) ) -# We'll use two directory lists. The first is a list of all of the directories -# in which makefile fragments were generated (plus the current directory). The -# second is the subset of the first that begins with the sandbox root path. +# We'll use three directory lists. The first is a list of all of the directories +# in which makefile fragments were generated, plus the current directory. (The +# current directory is needed so we include bli_config.h and bli_addon.h in the +# processing of header files.) The second and third are subsets of the first +# that begins with the addon and sandbox root paths, respectively. ALLFRAG_DIR_PATHS := . $(FRAGMENT_DIR_PATHS) +ADDON_DIR_PATHS := $(filter $(ADDON_PATH)/%,$(ALLFRAG_DIR_PATHS)) SANDBOX_DIR_PATHS := $(filter $(SANDBOX_PATH)/%,$(ALLFRAG_DIR_PATHS)) ALL_H99_FILES := $(call get-filepaths,$(ALLFRAG_DIR_PATHS),$(ALL_H99_SUFS)) -FRAME_H99_FILES := $(filter-out $(SANDBOX_PATH)/%,$(ALL_H99_FILES)) +FRAME_H99_FILES := $(filter-out $(ADDON_PATH)/%, \ + $(filter-out $(SANDBOX_PATH)/%, \ + $(ALL_H99_FILES) \ + ) ) -ALL_H99_DIRPATHS := $(call get-dirpaths,$(ALLFRAG_DIR_PATHS),$(ALL_H99_SUFS)) +ALL_H99_DIRPATHS := $(call get-dirpaths,$(ALLFRAG_DIR_PATHS),$(ALL_H99_SUFS)) -SANDBOX_H99_FILES := $(call get-filepaths,$(SANDBOX_DIR_PATHS),$(SANDBOX_H99_SUFS)) -SANDBOX_HXX_FILES := $(call get-filepaths,$(SANDBOX_DIR_PATHS),$(SANDBOX_HXX_SUFS)) +ADDON_H99_FILES := $(call get-filepaths,$(ADDON_DIR_PATHS),$(ADDON_H99_SUFS)) +ADDON_HXX_FILES := $(call get-filepaths,$(ADDON_DIR_PATHS),$(ADDON_HXX_SUFS)) +ADDON_HDR_DIRPATHS := $(call get-dirpaths,$(ADDON_DIR_PATHS),$(ALL_HDR_SUFS)) +SANDBOX_H99_FILES := $(call get-filepaths,$(SANDBOX_DIR_PATHS),$(SANDBOX_H99_SUFS)) +SANDBOX_HXX_FILES := $(call get-filepaths,$(SANDBOX_DIR_PATHS),$(SANDBOX_HXX_SUFS)) SANDBOX_HDR_DIRPATHS := $(call get-dirpaths,$(SANDBOX_DIR_PATHS),$(ALL_HDR_SUFS)) @@ -896,9 +1074,11 @@ BLIS_H_FLAT := $(BASE_INC_PATH)/$(BLIS_H) # # Isolate the path to cblas.h by filtering the file from the list of framework -# header files. +# header files, and then strip the filename to obtain the directory in which +# cblas.h resides. CBLAS_H := cblas.h CBLAS_H_SRC_PATH := $(filter %/$(CBLAS_H), $(FRAME_H99_FILES)) +CBLAS_H_DIRPATH := $(dir $(CBLAS_H_SRC_PATH)) # Construct the path to what will be the intermediate flattened/monolithic # cblas.h file. @@ -910,8 +1090,8 @@ CBLAS_H_FLAT := $(BASE_INC_PATH)/$(CBLAS_H) # # Obtain a list of header files #included inside of the bli_cntx_ref.c file. -# Paths to these files will be needed when compiling with the monolithic -# header. +# Due to the way that bli_cntx_ref.c uses headers and macros, paths to these +# files will be needed when compiling bli_cntx_ref.c with the monolithic header. ifeq ($(strip $(SHARE_PATH)),.) REF_KER_SRC := $(DIST_PATH)/$(REFKERN_DIR)/bli_cntx_ref.c REF_KER_HEADERS := $(shell $(GREP) "\#include" $(REF_KER_SRC) | sed -e "s/\#include [\"<]\([a-zA-Z0-9\_\.\/\-]*\)[\">].*/\1/g" | $(GREP) -v $(BLIS_H)) @@ -919,12 +1099,14 @@ endif # Match each header found above with the path to that header, and then strip # leading, trailing, and internal whitespace. -REF_KER_H_PATHS := $(strip $(foreach header, $(REF_KER_HEADERS), \ - $(dir $(filter %/$(header), \ - $(FRAME_H99_FILES))))) +REF_KER_H_PATHS := $(call rm-dups,$(strip \ + $(foreach header, $(REF_KER_HEADERS), \ + $(dir $(filter %/$(header), \ + $(FRAME_H99_FILES)))))) # Add -I to each header path so we can specify our include search paths to the -# C compiler. Then add frame/include since it's needed for bli_oapi_w[o]_cntx.h. +# C compiler. Then add frame/include since it's needed when compiling source +# files that #include bli_oapi_ba.h or bli_oapi_ex.h. REF_KER_I_PATHS := $(strip $(patsubst %, -I%, $(REF_KER_H_PATHS))) REF_KER_I_PATHS += -I$(DIST_PATH)/frame/include @@ -933,17 +1115,29 @@ REF_KER_I_PATHS += -I$(DIST_PATH)/frame/include # now #include the monolithic/flattened blis.h instead. CINCFLAGS := -I$(BASE_INC_PATH) $(REF_KER_I_PATHS) +# If CBLAS is enabled, we also include the path to the cblas.h directory so +# that the compiler will be able to find cblas.h as the CBLAS source code is +# being compiled. +ifeq ($(MK_ENABLE_CBLAS),yes) +CINCFLAGS += -I$(CBLAS_H_DIRPATH) +endif + +# Obtain a list of header paths in the configured addons. Then add -I to each +# header path. +CADDONINCFLAGS := $(strip $(patsubst %, -I%, $(ADDON_HDR_DIRPATHS))) + # Obtain a list of header paths in the configured sandbox. Then add -I to each # header path. -CSBOXINCFLAGS := $(strip $(patsubst %, -I%, $(SANDBOX_HDR_DIRPATHS))) +CSANDINCFLAGS := $(strip $(patsubst %, -I%, $(SANDBOX_HDR_DIRPATHS))) # # --- BLIS configuration header definitions ------------------------------------ # -# This file was created by configure, but we need to define it here so we can -# remove it as part of the clean targets. +# These files were created by configure, but we need to define them here so we +# can remove them as part of the clean targets. +BLIS_ADDON_H := ./bli_addon.h BLIS_CONFIG_H := ./bli_config.h @@ -958,11 +1152,10 @@ VERS_DEF := -DBLIS_VERSION_STRING=\"$(VERSION)\" # Define a C preprocessor flag that is *only* defined when BLIS is being # compiled. (In other words, an application that #includes blis.h will not # get this cpp macro.) -BUILD_FLAGS := -DBLIS_IS_BUILDING_LIBRARY +BUILD_CPPFLAGS := -DBLIS_IS_BUILDING_LIBRARY # end of ifndef COMMON_MK_INCLUDED conditional block endif - diff --git a/config/a64fx/bli_a64fx_sector_cache.h b/config/a64fx/bli_a64fx_sector_cache.h new file mode 100644 index 0000000000..a81d04caca --- /dev/null +++ b/config/a64fx/bli_a64fx_sector_cache.h @@ -0,0 +1,117 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + // A64FX: set up cache sizes + // + // Reference: A64FX (TM) specification Fujitsu HPC Extension + // Link: https://github.com/fujitsu/A64FX/blob/master/doc/A64FX_Specification_HPC_Extension_v1_EN.pdf + // + // 63:15 | 14:12 | 11 | 10:08 | 07 | 06:04 | 03 | 02:00 | + // RES0 | l1_sec3_max | RES0 | l1_sec2_max | RES0 | l1_sec1_max | RES0 | l1_sec0_max | + // + // the bits set number of maximum sectors from 0-7 + // 000 - 0 + // 001 - 1 + // 010 - 2 + // 011 - 3 + // 100 - 4 + // 101 - 5 + // 110 - 6 + // 111 - 7 + // + // For L1 we want to maximize the number of sectors for B + // Configuration 1: 1 sector for C (sector 3) + // 1 sector for A (sector 1) + // 6 sectors for B (sector 2) + // 0 sectors for the rest (sector 0) + // + // 16b bitfield conf. 1: 0b0 001 0 110 0 001 0 000 + // + // Configuration 2: 1 sector for C (sector 3) + // 1 sector for A (sector 1) + // 5 sectors for B (sector 2) + // 1 sectors for the rest (sector 0) + // + // 16b bitfield conf. 2: 0b0 001 0 101 0 001 0 001 + // + // accessing the control register: + // + // MRS , S3_3_C11_C8_2 + // MSR S3_3_C11_C8_2, + // + // TODO: First tests showed no change in performance, a deeper investigation + // is necessary +#define A64FX_SETUP_SECTOR_CACHE_SIZES(config_bitfield)\ +{\ + uint64_t sector_cache_config = config_bitfield;\ + __asm__ volatile(\ + "msr s3_3_c11_c8_2,%[sector_cache_config]"\ + :\ + : [sector_cache_config] "r" (sector_cache_config)\ + :\ + );\ +} + +#define A64FX_SETUP_SECTOR_CACHE_SIZES_L2(config_bitfield)\ +{\ + uint64_t sector_cache_config = config_bitfield;\ + __asm__ volatile(\ + "msr s3_3_c15_c8_2,%[sector_cache_config]"\ + :\ + : [sector_cache_config] "r" (sector_cache_config)\ + :\ + );\ +} + + +#define A64FX_SET_CACHE_SECTOR(areg, tag, sparereg)\ +" mov "#sparereg", "#tag" \n\t"\ +" lsl "#sparereg", "#sparereg", 56 \n\t"\ +" orr "#areg", "#areg", "#sparereg" \n\t" + +#define A64FX_READ_SECTOR_CACHE_SIZES(output_uint64)\ +__asm__ volatile(\ + "mrs %["#output_uint64"],s3_3_c11_c8_2"\ + : [output_uint64] "=r" (output_uint64)\ + : \ + :\ + ); + +#define A64FX_SCC(sec0,sec1,sec2,sec3)\ + (uint64_t)((sec0 & 0x7LU) | ((sec1 & 0x7LU) << 4) | ((sec2 & 0x7LU) << 8) | ((sec3 & 0x7LU) << 12)) + +#define A64FX_SCC_L2(sec02,sec13)\ + (uint64_t)((sec02 & 0x1FLU) | ((sec13 & 0x1FLU) << 8)) + diff --git a/config/a64fx/bli_cntx_init_a64fx.c b/config/a64fx/bli_cntx_init_a64fx.c new file mode 100644 index 0000000000..5132b2824c --- /dev/null +++ b/config/a64fx/bli_cntx_init_a64fx.c @@ -0,0 +1,154 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "bli_a64fx_sector_cache.h" + +void bli_cntx_init_a64fx( cntx_t* cntx ) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + + // Set default kernel blocksizes and functions. + bli_cntx_init_a64fx_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 4, + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_armsve_asm_2vx10_unindexed, FALSE, + cntx + ); + + // Set SVE-512 packing routine. + bli_cntx_set_packm_kers + ( + 2, + BLIS_PACKM_10XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_10xk, + // 12xk is not used and disabled for GCC 8-9 compatibility. + // BLIS_PACKM_12XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_int_12xk, + BLIS_PACKM_16XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_16xk, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, 16, 8 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 10, 10, 10, 10 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 256, 128, 192, 96 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 2048, 2048, 1536, 1536 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 23040, 26880, 11520, 11760 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 5, + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + cntx + ); + +#if 0 + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], -1, 65, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], -1, 65, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], -1, 65, -1, -1 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 4, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], -1, 10, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], -1, 16, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], -1, 120, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], -1, 256, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], -1, 4080, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); +#endif + + // Set A64FX cache sector sizes for each PE/CMG + // SC Fugaku might disable users' setting cache sizes. +#if !defined(CACHE_SECTOR_SIZE_READONLY) +#pragma omp parallel + { + A64FX_SETUP_SECTOR_CACHE_SIZES(A64FX_SCC(0,1,3,0)) + A64FX_SETUP_SECTOR_CACHE_SIZES_L2(A64FX_SCC_L2(9,28)) + } +#endif + +} + diff --git a/frame/1m/packm/bli_packm_cxk_3mis.h b/config/a64fx/bli_family_a64fx.h similarity index 76% rename from frame/1m/packm/bli_packm_cxk_3mis.h rename to config/a64fx/bli_family_a64fx.h index 358cdcee4e..f2837459d1 100644 --- a/frame/1m/packm/bli_packm_cxk_3mis.h +++ b/config/a64fx/bli_family_a64fx.h @@ -32,22 +32,25 @@ */ +//#ifndef BLIS_FAMILY_H +//#define BLIS_FAMILY_H -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - conj_t conja, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* kappa, \ - ctype* a, inc_t inca, inc_t lda, \ - ctype* p, inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_cxk_3mis ) + +// -- MEMORY ALLOCATION -------------------------------------------------------- + +#define BLIS_SIMD_ALIGN_SIZE 256 +#define BLIS_SIMD_MAX_NUM_REGISTERS 32 + +// SVE-specific configs. +#define N_L1_SVE_DEFAULT 64 +#define W_L1_SVE_DEFAULT 4 +#define C_L1_SVE_DEFAULT 256 +#define N_L2_SVE_DEFAULT 2048 +#define W_L2_SVE_DEFAULT 16 +#define C_L2_SVE_DEFAULT 256 +#define N_L3_SVE_DEFAULT 8192 +#define W_L3_SVE_DEFAULT 16 +#define C_L3_SVE_DEFAULT 256 + +//#endif diff --git a/config/a64fx/make_defs.mk b/config/a64fx/make_defs.mk new file mode 100644 index 0000000000..d6871fac31 --- /dev/null +++ b/config/a64fx/make_defs.mk @@ -0,0 +1,82 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := a64fx +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := -D_GNU_SOURCE -D_A64FX +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O3 -ftree-vectorize -march=armv8-a+sve +endif + +# Flags specific to optimized kernels. +CKOPTFLAGS := $(COPTFLAGS) +CKVECFLAGS := + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +CRVECFLAGS := $(CKVECFLAGS) +endif +endif + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config/amd64/bli_family_amd64.h b/config/amd64/bli_family_amd64.h index 278c228182..4791cceeb5 100644 --- a/config/amd64/bli_family_amd64.h +++ b/config/amd64/bli_family_amd64.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,15 +32,8 @@ */ -//#ifndef BLIS_FAMILY_H -//#define BLIS_FAMILY_H +#ifndef BLIS_FAMILY_AMD64_H +#define BLIS_FAMILY_AMD64_H - -// -- MEMORY ALLOCATION -------------------------------------------------------- - -#define BLIS_SIMD_ALIGN_SIZE 16 - - - -//#endif +#endif diff --git a/config/amd64/make_defs.mk b/config/amd64/make_defs.mk index 70c0b692b4..ebb7a569fc 100644 --- a/config/amd64/make_defs.mk +++ b/config/amd64/make_defs.mk @@ -1,10 +1,10 @@ # # -# BLIS +# BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -57,28 +57,11 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif -# Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) -ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS := -mfpmath=sse -mavx -mfma -march=bdver2 -else -ifeq ($(CC_VENDOR),clang) -CKVECFLAGS := -mfpmath=sse -mavx -mfma -march=bdver2 -else -$(error gcc or clang are required for this configuration.) -endif -endif - -# Flags specific to reference kernels. -CROPTFLAGS := $(CKOPTFLAGS) -ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -else -CRVECFLAGS := $(CKVECFLAGS) -endif +# Setting for reference and optimized kernels are taken from individual +# subconfiguration makefile fragments in this family. # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/amd64_legacy/bli_family_amd64_legacy.h b/config/amd64_legacy/bli_family_amd64_legacy.h new file mode 100644 index 0000000000..c4f84885f7 --- /dev/null +++ b/config/amd64_legacy/bli_family_amd64_legacy.h @@ -0,0 +1,42 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_FAMILY_AMD64_LEGACY_H +#define BLIS_FAMILY_AMD64_LEGACY_H + +// Placeholder for bundle configuration. + +#endif + diff --git a/windows/build/config.mk.in b/config/amd64_legacy/make_defs.mk similarity index 63% rename from windows/build/config.mk.in rename to config/amd64_legacy/make_defs.mk index 5b4ad8a278..37ccbdae22 100644 --- a/windows/build/config.mk.in +++ b/config/amd64_legacy/make_defs.mk @@ -1,52 +1,70 @@ -# -# -# BLIS -# An object-based framework for developing high-performance BLAS-like -# libraries. -# -# Copyright (C) 2014, The University of Texas at Austin -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# - Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# - Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# - Neither the name(s) of the copyright holder(s) nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -# - -# -# --- Configuration variable definitions --------------------------------------- -# -# Environment-related variables: -# REVISION - The code's revision number. -# PWD - The path to current working directory. -# ARCH_STR - A string to identify the requested build architecture. -# BUILD_STR - A string to identify the requested build type. -# CCOMPILER_STR - A string to identify the requested C compiler. -# -# Target-related variables: -# FLAMEC_OBJS - List of paths to flamec object files. -# LAPACK2FLAMEC_OBJS - List of paths to lapack2flamec object files. -# -# Note: these variables are not present in the .in template file. Instead, they -# are appended to the contents of the .in file by a build script and output to -# a separate file (by the same name, without the .in extension). -# +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := amd64_legacy +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O2 +endif + +# Setting for reference and optimized kernels are taken from individual +# subconfiguration makefile fragments in this family. + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config/arm32/make_defs.mk b/config/arm32/make_defs.mk index b592851e52..e6818a19d7 100644 --- a/config/arm32/make_defs.mk +++ b/config/arm32/make_defs.mk @@ -61,7 +61,7 @@ COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -march=armv7-a else @@ -70,7 +70,15 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else CRVECFLAGS := $(CKVECFLAGS) +endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/arm64/bli_family_arm64.h b/config/arm64/bli_family_arm64.h index 278c228182..b242d70492 100644 --- a/config/arm64/bli_family_arm64.h +++ b/config/arm64/bli_family_arm64.h @@ -39,7 +39,18 @@ // -- MEMORY ALLOCATION -------------------------------------------------------- #define BLIS_SIMD_ALIGN_SIZE 16 - +#define BLIS_SIMD_MAX_NUM_REGISTERS 32 + +// SVE-specific configs. +#define N_L1_SVE_DEFAULT 64 +#define W_L1_SVE_DEFAULT 4 +#define C_L1_SVE_DEFAULT 256 +#define N_L2_SVE_DEFAULT 2048 +#define W_L2_SVE_DEFAULT 16 +#define C_L2_SVE_DEFAULT 256 +#define N_L3_SVE_DEFAULT 8192 +#define W_L3_SVE_DEFAULT 16 +#define C_L3_SVE_DEFAULT 256 //#endif diff --git a/config/arm64/make_defs.mk b/config/arm64/make_defs.mk index ac1cd69739..fc1a062e68 100644 --- a/config/arm64/make_defs.mk +++ b/config/arm64/make_defs.mk @@ -57,20 +57,32 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -march=armv8-a else -$(error gcc is required for this configuration.) +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -march=armv8-a +else +$(error gcc or clang is required for this configuration.) +endif endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else CRVECFLAGS := $(CKVECFLAGS) +endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/armsve/bli_cntx_init_armsve.c b/config/armsve/bli_cntx_init_armsve.c new file mode 100644 index 0000000000..ad0e682196 --- /dev/null +++ b/config/armsve/bli_cntx_init_armsve.c @@ -0,0 +1,169 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include + +#ifndef HWCAP_SVE +#define HWCAP_SVE (1 << 22) +#endif + +void bli_cntx_init_armsve( cntx_t* cntx ) +{ + if (!(getauxval( AT_HWCAP ) & HWCAP_SVE)) + return; + + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; +#if 0 + blksz_t thresh[ BLIS_NUM_THRESH ]; +#endif + + // Set default kernel blocksizes and functions. + bli_cntx_init_armsve_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Block size. + dim_t m_r_s, n_r_s, k_c_s, m_c_s, n_c_s; + dim_t m_r_d, n_r_d, k_c_d, m_c_d, n_c_d; + dim_t m_r_c, n_r_c, k_c_c, m_c_c, n_c_c; + dim_t m_r_z, n_r_z, k_c_z, m_c_z, n_c_z; + bli_s_blksz_armsve(&m_r_s, &n_r_s, &k_c_s, &m_c_s, &n_c_s); + bli_d_blksz_armsve(&m_r_d, &n_r_d, &k_c_d, &m_c_d, &n_c_d); + bli_c_blksz_armsve(&m_r_c, &n_r_c, &k_c_c, &m_c_c, &n_c_c); + bli_z_blksz_armsve(&m_r_z, &n_r_z, &k_c_z, &m_c_z, &n_c_z); + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 4, + // These are vector-length agnostic kernels. Yet knowing mr is required at runtime. + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_armsve_asm_2vx10_unindexed, FALSE, + cntx + ); + + // Set VL-specific packing routines if applicable. + if (m_r_d==16) + bli_cntx_set_packm_kers + ( + 2, + BLIS_PACKM_10XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_10xk, + BLIS_PACKM_16XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_16xk, + cntx + ); + else if (m_r_d==8) + bli_cntx_set_packm_kers + ( + 1, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_armsve256_int_8xk, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], m_r_s, m_r_d, m_r_c, m_r_z ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], n_r_s, n_r_d, n_r_c, n_r_z ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], m_c_s, m_c_d, m_c_c, m_c_z ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], k_c_s, k_c_d, k_c_c, k_c_z ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], n_c_s, n_c_d, n_c_c, n_c_z ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 5, + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + cntx + ); + +#if 0 + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], -1, 101, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], -1, 101, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], -1, 101, -1, -1 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 4, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], -1, n_r_d, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], -1, m_r_d, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], -1, 120, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], -1, 256, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], -1, 2048, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); +#endif +} + diff --git a/config/armsve/bli_family_armsve.h b/config/armsve/bli_family_armsve.h new file mode 100644 index 0000000000..f2837459d1 --- /dev/null +++ b/config/armsve/bli_family_armsve.h @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +//#ifndef BLIS_FAMILY_H +//#define BLIS_FAMILY_H + + +// -- MEMORY ALLOCATION -------------------------------------------------------- + +#define BLIS_SIMD_ALIGN_SIZE 256 +#define BLIS_SIMD_MAX_NUM_REGISTERS 32 + +// SVE-specific configs. +#define N_L1_SVE_DEFAULT 64 +#define W_L1_SVE_DEFAULT 4 +#define C_L1_SVE_DEFAULT 256 +#define N_L2_SVE_DEFAULT 2048 +#define W_L2_SVE_DEFAULT 16 +#define C_L2_SVE_DEFAULT 256 +#define N_L3_SVE_DEFAULT 8192 +#define W_L3_SVE_DEFAULT 16 +#define C_L3_SVE_DEFAULT 256 + +//#endif + diff --git a/config/armsve/make_defs.mk b/config/armsve/make_defs.mk new file mode 100644 index 0000000000..d3495efbb8 --- /dev/null +++ b/config/armsve/make_defs.mk @@ -0,0 +1,82 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := armsve +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := -D_GNU_SOURCE +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O3 -ftree-vectorize -march=armv8-a+sve +endif + +# Flags specific to optimized kernels. +CKOPTFLAGS := $(COPTFLAGS) +CKVECFLAGS := + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +CRVECFLAGS := $(CKVECFLAGS) +endif +endif + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config/bgq/make_defs.mk b/config/bgq/make_defs.mk index a577a9a32c..0cbbf439d5 100644 --- a/config/bgq/make_defs.mk +++ b/config/bgq/make_defs.mk @@ -68,18 +68,26 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),ibm) CKVECFLAGS := -qarch=qp -qtune=qp -qsimd=auto -qhot=level=1 -qprefetch -qunroll=yes -qnoipa endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else CRVECFLAGS := $(CKVECFLAGS) +endif +endif # Override the default value for LDFLAGS. ifeq ($(CC_VENDOR),ibm) diff --git a/config/bulldozer/make_defs.mk b/config/bulldozer/make_defs.mk index dec89a4c3e..1f80f2ab65 100644 --- a/config/bulldozer/make_defs.mk +++ b/config/bulldozer/make_defs.mk @@ -57,11 +57,11 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -mfpmath=sse -mavx -mfma4 -march=bdver1 -mno-tbm -mno-xop -mno-lwp else @@ -75,10 +75,14 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast else CRVECFLAGS := $(CKVECFLAGS) endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/cortexa15/bli_cntx_init_cortexa15.c b/config/cortexa15/bli_cntx_init_cortexa15.c index 81568d59c4..7c6134ff01 100644 --- a/config/cortexa15/bli_cntx_init_cortexa15.c +++ b/config/cortexa15/bli_cntx_init_cortexa15.c @@ -55,11 +55,19 @@ void bli_cntx_init_cortexa15( cntx_t* cntx ) // Initialize level-3 blocksize objects with architecture-specific values. // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 4, 4, 0, 0 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 4, 4, 0, 0 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 336, 176, 0, 0 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 528, 368, 0, 0 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4096, 4096, 0, 0 ); +#if 1 + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 4, 4, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 4, 4, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 336, 176, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 528, 368, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4096, 4096, -1, -1 ); +#else + bli_blksz_init_easy( &blkszs[ BLIS_MR ], -1, 4, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], -1, 4, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], -1, 176, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], -1, 368, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], -1, 4096, -1, -1 ); +#endif // Update the context with the current architecture's register and cache // blocksizes (and multiples) for native execution. diff --git a/config/cortexa15/make_defs.mk b/config/cortexa15/make_defs.mk index ee4d301f4b..abbee599de 100644 --- a/config/cortexa15/make_defs.mk +++ b/config/cortexa15/make_defs.mk @@ -61,16 +61,24 @@ COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS := -march=armv7-a +CKVECFLAGS := -mcpu=cortex-a15 else $(error gcc is required for this configuration.) endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else CRVECFLAGS := $(CKVECFLAGS) +endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/cortexa53/make_defs.mk b/config/cortexa53/make_defs.mk index 9f723bcde3..b5b2220a67 100644 --- a/config/cortexa53/make_defs.mk +++ b/config/cortexa53/make_defs.mk @@ -57,20 +57,32 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 -ftree-vectorize -mtune=cortex-a53 +COPTFLAGS := -O2 -mcpu=cortex-a53 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 -ftree-vectorize ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS := -march=armv8-a+fp+simd -mcpu=cortex-a53 +CKVECFLAGS := -mcpu=cortex-a53 else -$(error gcc is required for this configuration.) +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -mcpu=cortex-a53 +else +$(error gcc or clang is required for this configuration.) +endif endif # Flags specific to reference kernels. -CROPTFLAGS := $(CKOPTFLAGS) +CROPTFLAGS := $(CKOPTFLAGS) -O3 +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else CRVECFLAGS := $(CKVECFLAGS) +endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/cortexa57/make_defs.mk b/config/cortexa57/make_defs.mk index 23bcf51e6e..83565b8a79 100644 --- a/config/cortexa57/make_defs.mk +++ b/config/cortexa57/make_defs.mk @@ -57,20 +57,32 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 -ftree-vectorize -mtune=cortex-a57 +COPTFLAGS := -O2 -mcpu=cortex-a57 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 -ftree-vectorize ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS := -march=armv8-a+fp+simd -mcpu=cortex-a57 +CKVECFLAGS := -mcpu=cortex-a57 else -$(error gcc is required for this configuration.) +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -mcpu=cortex-a57 +else +$(error gcc or clang is required for this configuration.) +endif endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else CRVECFLAGS := $(CKVECFLAGS) +endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/cortexa9/make_defs.mk b/config/cortexa9/make_defs.mk index 2adc40e307..ea9dc29ac6 100644 --- a/config/cortexa9/make_defs.mk +++ b/config/cortexa9/make_defs.mk @@ -61,16 +61,24 @@ COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS := -march=armv7-a +CKVECFLAGS := -mcpu=cortex-a9 else $(error gcc is required for this configuration.) endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else CRVECFLAGS := $(CKVECFLAGS) +endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/excavator/make_defs.mk b/config/excavator/make_defs.mk index deb85c79bd..6e73e60584 100644 --- a/config/excavator/make_defs.mk +++ b/config/excavator/make_defs.mk @@ -57,11 +57,11 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -mfpmath=sse -mavx -mfma -march=bdver4 -mno-fma4 -mno-tbm -mno-xop -mno-lwp else @@ -75,10 +75,14 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast else CRVECFLAGS := $(CKVECFLAGS) endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/firestorm/bli_cntx_init_firestorm.c b/config/firestorm/bli_cntx_init_firestorm.c new file mode 100644 index 0000000000..a15ce03448 --- /dev/null +++ b/config/firestorm/bli_cntx_init_firestorm.c @@ -0,0 +1,144 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_cntx_init_firestorm( cntx_t* cntx ) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + + // Set default kernel blocksizes and functions. + bli_cntx_init_firestorm_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 2, + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armv8a_asm_8x12, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armv8a_asm_6x8, FALSE, + cntx + ); + + // Update the context with optimized packm kernels. + bli_cntx_set_packm_kers + ( + 4, + BLIS_PACKM_8XK_KER, BLIS_FLOAT, bli_spackm_armv8a_int_8xk, + BLIS_PACKM_12XK_KER, BLIS_FLOAT, bli_spackm_armv8a_int_12xk, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_armv8a_int_6xk, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_armv8a_int_8xk, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 8, 6, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 120, 252, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 640, 3072, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 3072, 8192, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 5, + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + cntx + ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], -1, 99, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], -1, 99, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], -1, 99, -1, -1 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 8, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_armv8a_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_armv8a_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_armv8a_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_armv8a_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_armv8a_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_armv8a_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_armv8a_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_armv8a_asm_6x8n, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], -1, 6, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], -1, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], -1, 240, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], -1, 1024, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], -1, 3072, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); +} + diff --git a/config/firestorm/bli_family_firestorm.h b/config/firestorm/bli_family_firestorm.h new file mode 100644 index 0000000000..4a60ed2f2b --- /dev/null +++ b/config/firestorm/bli_family_firestorm.h @@ -0,0 +1,76 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +//#ifndef BLIS_FAMILY_H +//#define BLIS_FAMILY_H + + +// -- MEMORY ALLOCATION -------------------------------------------------------- + +#define BLIS_SIMD_ALIGN_SIZE 16 + + +#if 0 +// -- LEVEL-3 MICRO-KERNEL CONSTANTS ------------------------------------------- + +#define BLIS_SGEMM_UKERNEL bli_sgemm_opt_8x12 +#define BLIS_DEFAULT_MR_S 8 +#define BLIS_DEFAULT_NR_S 12 +#define BLIS_DEFAULT_MC_S 120 //1536 //336 //416 // 1280 //160 // 160 // 160 //2048 //336 +#define BLIS_DEFAULT_KC_S 640 //1536 //336 //704 //1280 //672 //528 // 856 //2048 //528 +#define BLIS_DEFAULT_NC_S 3072 + +#define BLIS_DGEMM_UKERNEL bli_dgemm_opt_6x8 +#define BLIS_DEFAULT_MR_D 6 +#define BLIS_DEFAULT_NR_D 8 +#define BLIS_DEFAULT_MC_D 120 //1536 //160 //80 //176 +#define BLIS_DEFAULT_KC_D 240 //1536 //304 //336 //368 +#define BLIS_DEFAULT_NC_D 3072 + +#define BLIS_DEFAULT_MR_C 8 +#define BLIS_DEFAULT_NR_C 4 +#define BLIS_DEFAULT_MC_C 64 +#define BLIS_DEFAULT_KC_C 128 +#define BLIS_DEFAULT_NC_C 4096 + +#define BLIS_DEFAULT_MR_Z 8 +#define BLIS_DEFAULT_NR_Z 4 +#define BLIS_DEFAULT_MC_Z 64 +#define BLIS_DEFAULT_KC_Z 128 +#define BLIS_DEFAULT_NC_Z 4096 +#endif + + +//#endif + diff --git a/config/firestorm/make_defs.mk b/config/firestorm/make_defs.mk new file mode 100644 index 0000000000..dc4286e6a8 --- /dev/null +++ b/config/firestorm/make_defs.mk @@ -0,0 +1,82 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := firestorm +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := -D_GNU_SOURCE +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O2 -march=armv8-a +endif + +# Flags specific to optimized kernels. +CKOPTFLAGS := $(COPTFLAGS) -O3 -ftree-vectorize +CKVECFLAGS := -march=armv8-a + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +CRVECFLAGS := $(CKVECFLAGS) +endif +endif + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config/generic/make_defs.mk b/config/generic/make_defs.mk index 3388291da0..ee77b6cf0e 100644 --- a/config/generic/make_defs.mk +++ b/config/generic/make_defs.mk @@ -57,11 +57,11 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := else @@ -79,10 +79,14 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast else CRVECFLAGS := $(CKVECFLAGS) endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/haswell/bli_cntx_init_haswell.c b/config/haswell/bli_cntx_init_haswell.c index 0682e6933f..f2dc900ead 100644 --- a/config/haswell/bli_cntx_init_haswell.c +++ b/config/haswell/bli_cntx_init_haswell.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,9 +35,12 @@ #include "blis.h" +//GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + void bli_cntx_init_haswell( cntx_t* cntx ) { blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; // Set default kernel blocksizes and functions. bli_cntx_init_haswell_ref( cntx ); @@ -63,12 +67,31 @@ void bli_cntx_init_haswell( cntx_t* cntx ) // gemmtrsm_l BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, + // gemmtrsm_u BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, cntx ); +#if 1 + // Update the context with optimized packm kernels. + bli_cntx_set_packm_kers + ( + 8, + BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, + BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, + BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, + BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_4xk, + cntx + ); +#endif + + // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( 4, @@ -85,9 +108,11 @@ void bli_cntx_init_haswell( cntx_t* cntx ) bli_cntx_set_l1v_kers ( 10, + // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + // axpyv #if 0 BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int, @@ -99,9 +124,11 @@ void bli_cntx_init_haswell( cntx_t* cntx ) // dotv BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int, BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int, + // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + // scalv #if 0 BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int, @@ -118,12 +145,18 @@ void bli_cntx_init_haswell( cntx_t* cntx ) #if 1 bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + //bli_blksz_init_easy( &blkszs[ BLIS_MC ], 1008, 1008, 1008, 1008 ); + //bli_blksz_init_easy( &blkszs[ BLIS_MC ], 168, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 168, 72, 75, 192 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); #else bli_blksz_init_easy( &blkszs[ BLIS_MR ], 16, 8, 8, 4 ); bli_blksz_init_easy( &blkszs[ BLIS_NR ], 6, 6, 3, 3 ); -#endif - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); + //bli_blksz_init_easy( &blkszs[ BLIS_MC ], 1024, 1024, 1024, 1024 ); + //bli_blksz_init_easy( &blkszs[ BLIS_MC ], 112, 64, 56, 32 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 112, 72, 56, 44 ); bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); +#endif bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 ); bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, 8, 8 ); bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, 8, 8 ); @@ -144,5 +177,81 @@ void bli_cntx_init_haswell( cntx_t* cntx ) BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, cntx ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], 201, 201, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 201, 201, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 201, 201, -1, -1 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + +#if 0 + // Initialize the context with the sup handlers. + bli_cntx_set_l3_sup_handlers + ( + 1, + BLIS_GEMM, bli_gemmsup_ref, + cntx + ); +#endif + + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 16, + //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_haswell_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_haswell_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, -1, -1, + 9, 9, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 168, 72, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); } diff --git a/config/haswell/bli_family_haswell.h b/config/haswell/bli_family_haswell.h index dc75f01b29..58154692a7 100644 --- a/config/haswell/bli_family_haswell.h +++ b/config/haswell/bli_family_haswell.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,7 +37,6 @@ //#define BLIS_FAMILY_H - #if 0 // -- LEVEL-3 MICRO-KERNEL CONSTANTS AND DEFINITIONS --------------------------- diff --git a/config/haswell/make_defs.mk b/config/haswell/make_defs.mk index f08d5a937e..a8135c1070 100644 --- a/config/haswell/make_defs.mk +++ b/config/haswell/make_defs.mk @@ -57,19 +57,25 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +CKOPTFLAGS := $(COPTFLAGS) -O3 -fomit-frame-pointer ifeq ($(CC_VENDOR),gcc) +CKVECFLAGS := -mavx2 -mfma -mfpmath=sse -march=haswell +ifeq ($(GCC_OT_4_9_0),yes) +# If gcc is older than 4.9.0, we must use a different label for -march. CKVECFLAGS := -mavx2 -mfma -mfpmath=sse -march=core-avx2 +endif else ifeq ($(CC_VENDOR),icc) CKVECFLAGS := -xCORE-AVX2 else ifeq ($(CC_VENDOR),clang) -CKVECFLAGS := -mavx2 -mfma -mfpmath=sse -march=core-avx2 +CKVECFLAGS := -mavx2 -mfma -mfpmath=sse -march=haswell else $(error gcc, icc, or clang is required for this configuration.) endif @@ -79,10 +85,14 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) #-funsafe-math-optimizations +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast else CRVECFLAGS := $(CKVECFLAGS) endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/intel64/make_defs.mk b/config/intel64/make_defs.mk index af462fdc3f..95f21f6f9c 100644 --- a/config/intel64/make_defs.mk +++ b/config/intel64/make_defs.mk @@ -57,11 +57,11 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -mssse3 -mfpmath=sse -march=core2 else @@ -79,10 +79,14 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast else CRVECFLAGS := $(CKVECFLAGS) endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/knc/bli_family_knc.h b/config/knc/bli_family_knc.h index 6f9e03e8fa..b968b0c9a1 100644 --- a/config/knc/bli_family_knc.h +++ b/config/knc/bli_family_knc.h @@ -46,8 +46,8 @@ #define BLIS_SIMD_ALIGN_SIZE 64 -#define BLIS_SIMD_SIZE 64 -#define BLIS_SIMD_NUM_REGISTERS 32 +#define BLIS_SIMD_MAX_SIZE 64 +#define BLIS_SIMD_MAX_NUM_REGISTERS 32 #if 0 diff --git a/config/knc/make_defs.mk b/config/knc/make_defs.mk index be3c9019d8..0a1d43a645 100644 --- a/config/knc/make_defs.mk +++ b/config/knc/make_defs.mk @@ -57,11 +57,11 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),icc) CKVECFLAGS := else @@ -71,10 +71,14 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast else CRVECFLAGS := $(CKVECFLAGS) endif +endif # Override the default value for LDFLAGS. LDFLAGS := -mmic diff --git a/config/knl/bli_cntx_init_knl.c b/config/knl/bli_cntx_init_knl.c index e00b2a8dc5..6da3b7a3a9 100644 --- a/config/knl/bli_cntx_init_knl.c +++ b/config/knl/bli_cntx_init_knl.c @@ -79,9 +79,11 @@ void bli_cntx_init_knl( cntx_t* cntx ) bli_cntx_set_l1v_kers ( 10, +#if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, +#endif // axpyv #if 0 BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int, diff --git a/config/knl/bli_family_knl.h b/config/knl/bli_family_knl.h index 64994cd9dd..98d3fe8d72 100644 --- a/config/knl/bli_family_knl.h +++ b/config/knl/bli_family_knl.h @@ -52,8 +52,8 @@ #define BLIS_SIMD_ALIGN_SIZE 64 -#define BLIS_SIMD_SIZE 64 -#define BLIS_SIMD_NUM_REGISTERS 32 +#define BLIS_SIMD_MAX_SIZE 64 +#define BLIS_SIMD_MAX_NUM_REGISTERS 32 /* #ifdef BLIS_NO_HBWMALLOC diff --git a/config/knl/make_defs.mk b/config/knl/make_defs.mk index b08cf1e4d5..d4b0da4aa0 100644 --- a/config/knl/make_defs.mk +++ b/config/knl/make_defs.mk @@ -57,7 +57,7 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif ifeq ($(DEBUG_TYPE),sde) @@ -73,7 +73,7 @@ MK_ENABLE_MEMKIND := no endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -mavx512f -mavx512pf -mfpmath=sse -march=knl else @@ -99,13 +99,13 @@ endif # Note: We use AVX2 for reference kernels instead of AVX-512. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := -march=knl -mno-avx512f -mno-avx512pf -mno-avx512er -mno-avx512cd -funsafe-math-optimizations +CRVECFLAGS := -march=knl -mno-avx512f -mno-avx512pf -mno-avx512er -mno-avx512cd -funsafe-math-optimizations -ffp-contract=fast else ifeq ($(CC_VENDOR),icc) CRVECFLAGS := -xMIC-AVX512 else ifeq ($(CC_VENDOR),clang) -CRVECFLAGS := -march=knl -mno-avx512f -mno-avx512pf -mno-avx512er -mno-avx512cd +CRVECFLAGS := -march=knl -mno-avx512f -mno-avx512pf -mno-avx512er -mno-avx512cd -funsafe-math-optimizations -ffp-contract=fast else $(error gcc, icc, or clang is required for this configuration.) endif diff --git a/config/old/haswellbb/bli_cntx_init_haswell.c b/config/old/haswellbb/bli_cntx_init_haswell.c new file mode 100644 index 0000000000..9e1d03503a --- /dev/null +++ b/config/old/haswellbb/bli_cntx_init_haswell.c @@ -0,0 +1,276 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// Instantiate prototypes for packm kernels. +PACKM_KER_PROT( float, s, packm_6xk_bb4_haswell_ref ) +PACKM_KER_PROT( double, d, packm_6xk_bb2_haswell_ref ) + +// Instantiate prototypes for level-3 kernels. +GEMM_UKR_PROT( float, s, gemmbb_haswell_ref ) +GEMMTRSM_UKR_PROT( float, s, gemmtrsmbb_l_haswell_ref ) +GEMMTRSM_UKR_PROT( float, s, gemmtrsmbb_u_haswell_ref ) +TRSM_UKR_PROT( float, s, trsmbb_l_haswell_ref ) +TRSM_UKR_PROT( float, s, trsmbb_u_haswell_ref ) + +GEMM_UKR_PROT( double, d, gemmbb_haswell_ref ) +GEMMTRSM_UKR_PROT( double, d, gemmtrsmbb_l_haswell_ref ) +GEMMTRSM_UKR_PROT( double, d, gemmtrsmbb_u_haswell_ref ) +TRSM_UKR_PROT( double, d, trsmbb_l_haswell_ref ) +TRSM_UKR_PROT( double, d, trsmbb_u_haswell_ref ) + +GEMM_UKR_PROT( scomplex, c, gemmbb_haswell_ref ) +GEMMTRSM_UKR_PROT( scomplex, c, gemmtrsmbb_l_haswell_ref ) +GEMMTRSM_UKR_PROT( scomplex, c, gemmtrsmbb_u_haswell_ref ) +TRSM_UKR_PROT( scomplex, c, trsmbb_l_haswell_ref ) +TRSM_UKR_PROT( scomplex, c, trsmbb_u_haswell_ref ) + +GEMM_UKR_PROT( dcomplex, z, gemmbb_haswell_ref ) +GEMMTRSM_UKR_PROT( dcomplex, z, gemmtrsmbb_l_haswell_ref ) +GEMMTRSM_UKR_PROT( dcomplex, z, gemmtrsmbb_u_haswell_ref ) +TRSM_UKR_PROT( dcomplex, z, trsmbb_l_haswell_ref ) +TRSM_UKR_PROT( dcomplex, z, trsmbb_u_haswell_ref ) + +void bli_cntx_init_haswell( cntx_t* cntx ) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + + // Set default kernel blocksizes and functions. + bli_cntx_init_haswell_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( +#if 0 + 8, + // gemm + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, + // gemmtrsm_l + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, + // gemmtrsm_u + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, +#else + 12, + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemmbb_haswell_ref, FALSE, + BLIS_TRSM_L_UKR, BLIS_FLOAT, bli_strsmbb_l_haswell_ref, FALSE, + BLIS_TRSM_U_UKR, BLIS_FLOAT, bli_strsmbb_u_haswell_ref, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemmbb_haswell_ref, FALSE, + BLIS_TRSM_L_UKR, BLIS_DOUBLE, bli_dtrsmbb_l_haswell_ref, FALSE, + BLIS_TRSM_U_UKR, BLIS_DOUBLE, bli_dtrsmbb_u_haswell_ref, FALSE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemmbb_haswell_ref, FALSE, + BLIS_TRSM_L_UKR, BLIS_SCOMPLEX, bli_ctrsmbb_l_haswell_ref, FALSE, + BLIS_TRSM_U_UKR, BLIS_SCOMPLEX, bli_ctrsmbb_u_haswell_ref, FALSE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemmbb_haswell_ref, FALSE, + BLIS_TRSM_L_UKR, BLIS_DCOMPLEX, bli_ztrsmbb_l_haswell_ref, FALSE, + BLIS_TRSM_U_UKR, BLIS_DCOMPLEX, bli_ztrsmbb_u_haswell_ref, FALSE, +#endif + cntx + ); + + // Update the context with customized virtual [gemm]trsm micro-kernels. + bli_cntx_set_l3_vir_ukrs + ( + 8, + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsmbb_l_haswell_ref, + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsmbb_u_haswell_ref, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsmbb_l_haswell_ref, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsmbb_u_haswell_ref, + BLIS_GEMMTRSM_L_UKR, BLIS_SCOMPLEX, bli_cgemmtrsmbb_l_haswell_ref, + BLIS_GEMMTRSM_U_UKR, BLIS_SCOMPLEX, bli_cgemmtrsmbb_u_haswell_ref, + BLIS_GEMMTRSM_L_UKR, BLIS_DCOMPLEX, bli_zgemmtrsmbb_l_haswell_ref, + BLIS_GEMMTRSM_U_UKR, BLIS_DCOMPLEX, bli_zgemmtrsmbb_u_haswell_ref, + cntx + ); + + // Update the context with optimized packm kernels. + bli_cntx_set_packm_kers + ( + 2, + BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_6xk_bb4_haswell_ref, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_6xk_bb2_haswell_ref, + cntx + ); + + // Update the context with optimized level-1f kernels. + bli_cntx_set_l1f_kers + ( + 4, + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + cntx + ); + + // Update the context with optimized level-1v kernels. + bli_cntx_set_l1v_kers + ( + 10, +#if 1 + // amaxv + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, +#endif + // axpyv +#if 0 + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int, +#else + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, +#endif + // dotv + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int, + // dotxv + BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + // scalv +#if 0 + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int, +#else + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, +#endif + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // s d c z +#if 0 + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 168, 72, 75, 192 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 ); +#else + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 24, 12, 12, 6 ); + bli_blksz_init ( &blkszs[ BLIS_NR ], 6, 6, 6, 6, + 24, 12, 6, 6 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 4080, 2076 ); +#endif + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, 8, 8 ); + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, 8, 8 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 7, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + // level-1f + BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, + BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, + cntx + ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], -1, 1, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], -1, 1, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], -1, 1, -1, -1 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 8, + //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], -1, 6, -1, -1, + -1, 9, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], -1, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], -1, 72, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], -1, 256, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], -1, 4080, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); +} + diff --git a/config/old/haswellbb/bli_family_haswell.h b/config/old/haswellbb/bli_family_haswell.h new file mode 100644 index 0000000000..06dfdfcfcc --- /dev/null +++ b/config/old/haswellbb/bli_family_haswell.h @@ -0,0 +1,170 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +//#ifndef BLIS_FAMILY_H +//#define BLIS_FAMILY_H + +#define BLIS_POOL_ADDR_ALIGN_SIZE_A 4096 +#define BLIS_POOL_ADDR_ALIGN_SIZE_B 4096 + +#define BLIS_POOL_ADDR_OFFSET_SIZE_A 32 +#define BLIS_POOL_ADDR_OFFSET_SIZE_B 64 + +// Disable right-side hemm, symm, and trmm[3] to accommodate the broadcasting of +// elements within the packed matrix B. +#define BLIS_DISABLE_HEMM_RIGHT +#define BLIS_DISABLE_SYMM_RIGHT +#define BLIS_DISABLE_TRMM_RIGHT +#define BLIS_DISABLE_TRMM3_RIGHT + + +#if 0 +// -- LEVEL-3 MICRO-KERNEL CONSTANTS AND DEFINITIONS --------------------------- + +// -- sgemm micro-kernel -- + +#if 0 +#define BLIS_SGEMM_UKERNEL bli_sgemm_asm_4x24 +#define BLIS_DEFAULT_MC_S 256 +#define BLIS_DEFAULT_KC_S 256 +#define BLIS_DEFAULT_NC_S 4080 +#define BLIS_DEFAULT_MR_S 4 +#define BLIS_DEFAULT_NR_S 24 + +#define BLIS_SGEMM_UKERNEL_PREFERS_CONTIG_ROWS +#endif + +#if 1 +#define BLIS_SGEMM_UKERNEL bli_sgemm_asm_6x16 +#define BLIS_DEFAULT_MC_S 144 +#define BLIS_DEFAULT_KC_S 256 +#define BLIS_DEFAULT_NC_S 4080 +#define BLIS_DEFAULT_MR_S 6 +#define BLIS_DEFAULT_NR_S 16 + +#define BLIS_SGEMM_UKERNEL_PREFERS_CONTIG_ROWS +#endif + +#if 0 +#define BLIS_SGEMM_UKERNEL bli_sgemm_asm_16x6 +#define BLIS_DEFAULT_MC_S 144 +#define BLIS_DEFAULT_KC_S 256 +#define BLIS_DEFAULT_NC_S 4080 +#define BLIS_DEFAULT_MR_S 16 +#define BLIS_DEFAULT_NR_S 6 +#endif + +// -- dgemm micro-kernel -- + +#if 0 +#define BLIS_DGEMM_UKERNEL bli_dgemm_asm_4x12 +#define BLIS_DEFAULT_MC_D 152 +#define BLIS_DEFAULT_KC_D 160 +#define BLIS_DEFAULT_NC_D 4080 +#define BLIS_DEFAULT_MR_D 4 +#define BLIS_DEFAULT_NR_D 12 + +#define BLIS_DGEMM_UKERNEL_PREFERS_CONTIG_ROWS +#endif + +#if 1 +#define BLIS_DGEMM_UKERNEL bli_dgemm_asm_6x8 +#define BLIS_DEFAULT_MC_D 72 +#define BLIS_DEFAULT_KC_D 256 +#define BLIS_DEFAULT_NC_D 4080 +#define BLIS_DEFAULT_MR_D 6 +#define BLIS_DEFAULT_NR_D 8 + +#define BLIS_DGEMM_UKERNEL_PREFERS_CONTIG_ROWS +#endif + +#if 0 +#define BLIS_DGEMM_UKERNEL bli_dgemm_asm_8x6 +#define BLIS_DEFAULT_MC_D 72 +#define BLIS_DEFAULT_KC_D 256 +#define BLIS_DEFAULT_NC_D 4080 +#define BLIS_DEFAULT_MR_D 8 +#define BLIS_DEFAULT_NR_D 6 +#endif + +// -- cgemm micro-kernel -- + +#if 1 +#define BLIS_CGEMM_UKERNEL bli_cgemm_asm_3x8 +#define BLIS_DEFAULT_MC_C 144 +#define BLIS_DEFAULT_KC_C 256 +#define BLIS_DEFAULT_NC_C 4080 +#define BLIS_DEFAULT_MR_C 3 +#define BLIS_DEFAULT_NR_C 8 + +#define BLIS_CGEMM_UKERNEL_PREFERS_CONTIG_ROWS +#endif + +#if 0 +#define BLIS_CGEMM_UKERNEL bli_cgemm_asm_8x3 +#define BLIS_DEFAULT_MC_C 144 +#define BLIS_DEFAULT_KC_C 256 +#define BLIS_DEFAULT_NC_C 4080 +#define BLIS_DEFAULT_MR_C 8 +#define BLIS_DEFAULT_NR_C 3 +#endif + +// -- zgemm micro-kernel -- + +#if 1 +#define BLIS_ZGEMM_UKERNEL bli_zgemm_asm_3x4 +#define BLIS_DEFAULT_MC_Z 72 +#define BLIS_DEFAULT_KC_Z 256 +#define BLIS_DEFAULT_NC_Z 4080 +#define BLIS_DEFAULT_MR_Z 3 +#define BLIS_DEFAULT_NR_Z 4 + +#define BLIS_ZGEMM_UKERNEL_PREFERS_CONTIG_ROWS +#endif + +#if 0 +#define BLIS_ZGEMM_UKERNEL bli_zgemm_asm_4x3 +#define BLIS_DEFAULT_MC_Z 72 +#define BLIS_DEFAULT_KC_Z 256 +#define BLIS_DEFAULT_NC_Z 4080 +#define BLIS_DEFAULT_MR_Z 4 +#define BLIS_DEFAULT_NR_Z 3 +#endif + +#endif + + +//#endif + diff --git a/config/old/haswellbb/make_defs.mk b/config/old/haswellbb/make_defs.mk new file mode 100644 index 0000000000..6752dde295 --- /dev/null +++ b/config/old/haswellbb/make_defs.mk @@ -0,0 +1,98 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := haswell +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O3 +endif + +# Flags specific to optimized kernels. +CKOPTFLAGS := $(COPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CKVECFLAGS := -mavx2 -mfma -mfpmath=sse -march=haswell +ifeq ($(GCC_OT_4_9_0),yes) +# If gcc is older than 4.9.0, we must use a different label for -march. +CKVECFLAGS := -mavx2 -mfma -mfpmath=sse -march=core-avx2 +endif +else +ifeq ($(CC_VENDOR),icc) +CKVECFLAGS := -xCORE-AVX2 +else +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -mavx2 -mfma -mfpmath=sse -march=haswell +else +$(error gcc, icc, or clang is required for this configuration.) +endif +endif +endif + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +CRVECFLAGS := $(CKVECFLAGS) +endif +endif + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config/penryn/make_defs.mk b/config/penryn/make_defs.mk index 41d2d939fc..a3474e9ce7 100644 --- a/config/penryn/make_defs.mk +++ b/config/penryn/make_defs.mk @@ -57,11 +57,11 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -mssse3 -mfpmath=sse -march=core2 else @@ -79,10 +79,14 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast else CRVECFLAGS := $(CKVECFLAGS) endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/piledriver/make_defs.mk b/config/piledriver/make_defs.mk index bb23fbecea..ab42872fb3 100644 --- a/config/piledriver/make_defs.mk +++ b/config/piledriver/make_defs.mk @@ -57,11 +57,11 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -mfpmath=sse -mavx -mfma -march=bdver2 -mno-fma4 -mno-tbm -mno-xop -mno-lwp else @@ -75,10 +75,14 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast else CRVECFLAGS := $(CKVECFLAGS) endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/power10/bli_cntx_init_power10.c b/config/power10/bli_cntx_init_power10.c new file mode 100644 index 0000000000..14c940f995 --- /dev/null +++ b/config/power10/bli_cntx_init_power10.c @@ -0,0 +1,144 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// Instantiate prototypes for packm kernels. +PACKM_KER_PROT( float, s, packm_6xk_bb4_power10_ref ) +PACKM_KER_PROT( double, d, packm_6xk_bb2_power10_ref ) + +// Instantiate prototypes for level-3 kernels. +GEMM_UKR_PROT( float, s, gemmbb_power10_ref ) +GEMMTRSM_UKR_PROT( float, s, gemmtrsmbb_l_power10_ref ) +GEMMTRSM_UKR_PROT( float, s, gemmtrsmbb_u_power10_ref ) +TRSM_UKR_PROT( float, s, trsmbb_l_power10_ref ) +TRSM_UKR_PROT( float, s, trsmbb_u_power10_ref ) + +GEMM_UKR_PROT( double, d, gemmbb_power10_ref ) +GEMMTRSM_UKR_PROT( double, d, gemmtrsmbb_l_power10_ref ) +GEMMTRSM_UKR_PROT( double, d, gemmtrsmbb_u_power10_ref ) +TRSM_UKR_PROT( double, d, trsmbb_l_power10_ref ) +TRSM_UKR_PROT( double, d, trsmbb_u_power10_ref ) + +GEMM_UKR_PROT( scomplex, c, gemmbb_power10_ref ) +GEMMTRSM_UKR_PROT( scomplex, c, gemmtrsmbb_l_power10_ref ) +GEMMTRSM_UKR_PROT( scomplex, c, gemmtrsmbb_u_power10_ref ) +TRSM_UKR_PROT( scomplex, c, trsmbb_l_power10_ref ) +TRSM_UKR_PROT( scomplex, c, trsmbb_u_power10_ref ) + +GEMM_UKR_PROT( dcomplex, z, gemmbb_power10_ref ) +GEMMTRSM_UKR_PROT( dcomplex, z, gemmtrsmbb_l_power10_ref ) +GEMMTRSM_UKR_PROT( dcomplex, z, gemmtrsmbb_u_power10_ref ) +TRSM_UKR_PROT( dcomplex, z, trsmbb_l_power10_ref ) +TRSM_UKR_PROT( dcomplex, z, trsmbb_u_power10_ref ) + +void bli_cntx_init_power10( cntx_t* cntx ) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + + // Set default kernel blocksizes and functions. + bli_cntx_init_power10_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 12, + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_power10_mma_8x16, TRUE, + + BLIS_TRSM_L_UKR, BLIS_FLOAT, bli_strsmbb_l_power10_ref, FALSE, + BLIS_TRSM_U_UKR, BLIS_FLOAT, bli_strsmbb_u_power10_ref, FALSE, + + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_power10_mma_8x8, TRUE, + + BLIS_TRSM_L_UKR, BLIS_DOUBLE, bli_dtrsmbb_l_power10_ref, FALSE, + BLIS_TRSM_U_UKR, BLIS_DOUBLE, bli_dtrsmbb_u_power10_ref, FALSE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemmbb_power10_ref, FALSE, + BLIS_TRSM_L_UKR, BLIS_SCOMPLEX, bli_ctrsmbb_l_power10_ref, FALSE, + BLIS_TRSM_U_UKR, BLIS_SCOMPLEX, bli_ctrsmbb_u_power10_ref, FALSE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemmbb_power10_ref, FALSE, + BLIS_TRSM_L_UKR, BLIS_DCOMPLEX, bli_ztrsmbb_l_power10_ref, FALSE, + BLIS_TRSM_U_UKR, BLIS_DCOMPLEX, bli_ztrsmbb_u_power10_ref, FALSE, + cntx + ); + + // Update the context with customized virtual [gemm]trsm micro-kernels. + bli_cntx_set_l3_vir_ukrs + ( + 8, + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsmbb_l_power10_ref, + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsmbb_u_power10_ref, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsmbb_l_power10_ref, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsmbb_u_power10_ref, + BLIS_GEMMTRSM_L_UKR, BLIS_SCOMPLEX, bli_cgemmtrsmbb_l_power10_ref, + BLIS_GEMMTRSM_U_UKR, BLIS_SCOMPLEX, bli_cgemmtrsmbb_u_power10_ref, + BLIS_GEMMTRSM_L_UKR, BLIS_DCOMPLEX, bli_zgemmtrsmbb_l_power10_ref, + BLIS_GEMMTRSM_U_UKR, BLIS_DCOMPLEX, bli_zgemmtrsmbb_u_power10_ref, + cntx + ); + + // Update the context with optimized packm kernels. + bli_cntx_set_packm_kers + ( + 2, + BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_6xk_bb4_power10_ref, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_6xk_bb2_power10_ref, + cntx + ); + + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 8, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 832, 320, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 1026, 960, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4096, 4096, -1, -1 ); + + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + cntx + ); + +} diff --git a/config/power10/bli_family_power10.h b/config/power10/bli_family_power10.h new file mode 100644 index 0000000000..4327738930 --- /dev/null +++ b/config/power10/bli_family_power10.h @@ -0,0 +1,39 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#define BLIS_POOL_ADDR_ALIGN_SIZE_A 4096 +#define BLIS_POOL_ADDR_ALIGN_SIZE_B 4096 + +#define BLIS_POOL_ADDR_OFFSET_SIZE_A 192 +#define BLIS_POOL_ADDR_OFFSET_SIZE_B 152 diff --git a/config/power10/make_defs.mk b/config/power10/make_defs.mk new file mode 100644 index 0000000000..2c3f7cd7b9 --- /dev/null +++ b/config/power10/make_defs.mk @@ -0,0 +1,83 @@ + +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2019, The University of Texas at Austin +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := power10 +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O2 +endif + +# Flags specific to optimized kernels. +CKOPTFLAGS := $(COPTFLAGS) -O3 +ifeq ($(CC_VENDOR),gcc) +CKVECFLAGS := -mcpu=power10 -mtune=power10 +else +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -mcpu=power10 -mtune=power10 +else +$(info $(CC_VENDOR)) +$(error gcc, clang is required for this configuration.) +endif +endif + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +CRVECFLAGS := $(CKVECFLAGS) + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/power7/make_defs.mk b/config/power7/make_defs.mk index 18f111bf68..f80774e48b 100644 --- a/config/power7/make_defs.mk +++ b/config/power7/make_defs.mk @@ -57,11 +57,11 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 -mtune=power7 +COPTFLAGS := -O2 -mtune=power7 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -mvsx else @@ -70,7 +70,15 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else CRVECFLAGS := $(CKVECFLAGS) +endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/power9/bli_cntx_init_power9.c b/config/power9/bli_cntx_init_power9.c index 410569611c..4370ce26c1 100644 --- a/config/power9/bli_cntx_init_power9.c +++ b/config/power9/bli_cntx_init_power9.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, The University of Texas at Austin Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,6 +34,35 @@ #include "blis.h" +// Instantiate prototypes for packm kernels. +PACKM_KER_PROT( float, s, packm_6xk_bb4_power9_ref ) +PACKM_KER_PROT( double, d, packm_6xk_bb2_power9_ref ) + +// Instantiate prototypes for level-3 kernels. +GEMM_UKR_PROT( float, s, gemmbb_power9_ref ) +GEMMTRSM_UKR_PROT( float, s, gemmtrsmbb_l_power9_ref ) +GEMMTRSM_UKR_PROT( float, s, gemmtrsmbb_u_power9_ref ) +TRSM_UKR_PROT( float, s, trsmbb_l_power9_ref ) +TRSM_UKR_PROT( float, s, trsmbb_u_power9_ref ) + +GEMM_UKR_PROT( double, d, gemmbb_power9_ref ) +GEMMTRSM_UKR_PROT( double, d, gemmtrsmbb_l_power9_ref ) +GEMMTRSM_UKR_PROT( double, d, gemmtrsmbb_u_power9_ref ) +TRSM_UKR_PROT( double, d, trsmbb_l_power9_ref ) +TRSM_UKR_PROT( double, d, trsmbb_u_power9_ref ) + +GEMM_UKR_PROT( scomplex, c, gemmbb_power9_ref ) +GEMMTRSM_UKR_PROT( scomplex, c, gemmtrsmbb_l_power9_ref ) +GEMMTRSM_UKR_PROT( scomplex, c, gemmtrsmbb_u_power9_ref ) +TRSM_UKR_PROT( scomplex, c, trsmbb_l_power9_ref ) +TRSM_UKR_PROT( scomplex, c, trsmbb_u_power9_ref ) + +GEMM_UKR_PROT( dcomplex, z, gemmbb_power9_ref ) +GEMMTRSM_UKR_PROT( dcomplex, z, gemmtrsmbb_l_power9_ref ) +GEMMTRSM_UKR_PROT( dcomplex, z, gemmtrsmbb_u_power9_ref ) +TRSM_UKR_PROT( dcomplex, z, trsmbb_l_power9_ref ) +TRSM_UKR_PROT( dcomplex, z, trsmbb_u_power9_ref ) + void bli_cntx_init_power9( cntx_t* cntx ) { blksz_t blkszs[ BLIS_NUM_BLKSZS ]; @@ -45,26 +74,65 @@ void bli_cntx_init_power9( cntx_t* cntx ) // Update the context with optimized native gemm micro-kernels and // their storage preferences. -// bli_cntx_set_l3_nat_ukrs -// ( -// 1, -// BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_power7_int_8x4, FALSE, -// cntx -// ); -/* - // Initialize level-3 blocksize objects with architecture-specific values. - // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 0, 8, 0, 0 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 0, 4, 0, 0 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 0, 64, 0, 0 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 0, 256, 0, 0 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 0, 4096, 0, 0 ); + bli_cntx_set_l3_nat_ukrs + ( + 12, + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemmbb_power9_ref, FALSE, + BLIS_TRSM_L_UKR, BLIS_FLOAT, bli_strsmbb_l_power9_ref, FALSE, + BLIS_TRSM_U_UKR, BLIS_FLOAT, bli_strsmbb_u_power9_ref, FALSE, + + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_power9_asm_12x6, FALSE, + + BLIS_TRSM_L_UKR, BLIS_DOUBLE, bli_dtrsmbb_l_power9_ref, FALSE, + BLIS_TRSM_U_UKR, BLIS_DOUBLE, bli_dtrsmbb_u_power9_ref, FALSE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemmbb_power9_ref, FALSE, + BLIS_TRSM_L_UKR, BLIS_SCOMPLEX, bli_ctrsmbb_l_power9_ref, FALSE, + BLIS_TRSM_U_UKR, BLIS_SCOMPLEX, bli_ctrsmbb_u_power9_ref, FALSE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemmbb_power9_ref, FALSE, + BLIS_TRSM_L_UKR, BLIS_DCOMPLEX, bli_ztrsmbb_l_power9_ref, FALSE, + BLIS_TRSM_U_UKR, BLIS_DCOMPLEX, bli_ztrsmbb_u_power9_ref, FALSE, + cntx + ); + + // Update the context with customized virtual [gemm]trsm micro-kernels. + bli_cntx_set_l3_vir_ukrs + ( + 8, + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsmbb_l_power9_ref, + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsmbb_u_power9_ref, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsmbb_l_power9_ref, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsmbb_u_power9_ref, + BLIS_GEMMTRSM_L_UKR, BLIS_SCOMPLEX, bli_cgemmtrsmbb_l_power9_ref, + BLIS_GEMMTRSM_U_UKR, BLIS_SCOMPLEX, bli_cgemmtrsmbb_u_power9_ref, + BLIS_GEMMTRSM_L_UKR, BLIS_DCOMPLEX, bli_zgemmtrsmbb_l_power9_ref, + BLIS_GEMMTRSM_U_UKR, BLIS_DCOMPLEX, bli_zgemmtrsmbb_u_power9_ref, + cntx + ); + + // Update the context with optimized packm kernels. + bli_cntx_set_packm_kers + ( + 2, + BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_6xk_bb4_power9_ref, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_6xk_bb2_power9_ref, + cntx + ); + + + bli_blksz_init_easy( &blkszs[ BLIS_MR ], -1, 12, -1, -1 ); + bli_blksz_init ( &blkszs[ BLIS_NR ], -1, 6, -1, -1, + -1, 12, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], -1, 576, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], -1, 1408, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], -1, 8190, -1, -1 ); + // Update the context with the current architecture's register and cache // blocksizes (and multiples) for native execution. bli_cntx_set_blkszs ( BLIS_NAT, 5, + // level-3 BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, @@ -72,6 +140,5 @@ void bli_cntx_init_power9( cntx_t* cntx ) BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, cntx ); -*/ -} +} diff --git a/config/power9/bli_family_power9.h b/config/power9/bli_family_power9.h index 202f1f854d..12b16444f4 100644 --- a/config/power9/bli_family_power9.h +++ b/config/power9/bli_family_power9.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, The University of Texas at Austin Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,34 +32,15 @@ */ -//#ifndef BLIS_FAMILY_H -//#define BLIS_FAMILY_H +#define BLIS_POOL_ADDR_ALIGN_SIZE_A 4096 +#define BLIS_POOL_ADDR_ALIGN_SIZE_B 4096 -//#define BLIS_SIMD_NUM_REGISTERS 32 -//#define BLIS_SIMD_SIZE 64 -// -//#ifdef BLIS_NO_HBWMALLOC -// #include -// #define BLIS_MALLOC_POOL malloc -// #define BLIS_FREE_POOL free -//#else -// #include -// #define BLIS_MALLOC_POOL hbw_malloc -// #define BLIS_FREE_POOL hbw_free -//#endif - - -#if 0 -// -- LEVEL-3 MICRO-KERNEL CONSTANTS ------------------------------------------- - -#define BLIS_DGEMM_UKERNEL bli_dgemm_opt_8x4 -#define BLIS_DEFAULT_MR_D 8 -#define BLIS_DEFAULT_NR_D 4 -#define BLIS_DEFAULT_MC_D 64 -#define BLIS_DEFAULT_KC_D 256 -#define BLIS_DEFAULT_NC_D 4096 -#endif - - -//#endif +#define BLIS_POOL_ADDR_OFFSET_SIZE_A 192 +#define BLIS_POOL_ADDR_OFFSET_SIZE_B 152 +// Disable right-side hemm, symm, and trmm[3] to accommodate the broadcasting of +// elements within the packed matrix B. +#define BLIS_DISABLE_HEMM_RIGHT +#define BLIS_DISABLE_SYMM_RIGHT +#define BLIS_DISABLE_TRMM_RIGHT +#define BLIS_DISABLE_TRMM3_RIGHT diff --git a/config/power9/make_defs.mk b/config/power9/make_defs.mk index 3d66f60795..85fa592d84 100644 --- a/config/power9/make_defs.mk +++ b/config/power9/make_defs.mk @@ -1,10 +1,11 @@ + # # # BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2019, The University of Texas at Austin # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -45,8 +46,8 @@ THIS_CONFIG := power9 # NOTE: The build system will append these variables with various # general-purpose/configuration-agnostic flags in common.mk. You # may specify additional flags here as needed. -CPPROCFLAGS := -CMISCFLAGS := -mcpu=power9 +CPPROCFLAGS := +CMISCFLAGS := CPICFLAGS := CWARNFLAGS := @@ -57,15 +58,20 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 -funroll-loops +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS := +CKVECFLAGS := -mcpu=power9 -mtune=power9 -DXLC=0 +else +ifeq ($(CC_VENDOR),IBM) +CKVECFLAGS := -qarch=pwr9 -qtune=pwr9 -DXLC=1 else -$(error gcc is required for this configuration.) +$(info $(CC_VENDOR)) +$(error gcc/xlc is required for this configuration.) +endif endif # Flags specific to reference kernels. diff --git a/config/sandybridge/make_defs.mk b/config/sandybridge/make_defs.mk index ba18e4f328..d3ceb34837 100644 --- a/config/sandybridge/make_defs.mk +++ b/config/sandybridge/make_defs.mk @@ -57,19 +57,23 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) +CKVECFLAGS := -mavx -mfpmath=sse -march=sandybridge +ifeq ($(GCC_OT_4_9_0),yes) +# If gcc is older than 4.9.0, we must use a different label for -march. CKVECFLAGS := -mavx -mfpmath=sse -march=corei7-avx +endif else ifeq ($(CC_VENDOR),icc) CKVECFLAGS := -xAVX else ifeq ($(CC_VENDOR),clang) -CKVECFLAGS := -mavx -mfpmath=sse -march=corei7-avx +CKVECFLAGS := -mavx -mfpmath=sse -march=sandybridge else $(error gcc, icc, or clang is required for this configuration.) endif @@ -79,10 +83,14 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast else CRVECFLAGS := $(CKVECFLAGS) endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/skx/bli_cntx_init_skx.c b/config/skx/bli_cntx_init_skx.c index f030ca54b4..f18503a7a7 100644 --- a/config/skx/bli_cntx_init_skx.c +++ b/config/skx/bli_cntx_init_skx.c @@ -71,9 +71,11 @@ void bli_cntx_init_skx( cntx_t* cntx ) bli_cntx_set_l1v_kers ( 10, +#if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, +#endif // axpyv #if 0 BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int, @@ -104,8 +106,8 @@ void bli_cntx_init_skx( cntx_t* cntx ) bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, -1, -1 ); bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 14, -1, -1 ); bli_blksz_init_easy( &blkszs[ BLIS_MC ], 480, 240, -1, -1 ); - bli_blksz_init ( &blkszs[ BLIS_KC ], 384, 384, -1, -1, - 480, 480, -1, -1 ); + bli_blksz_init ( &blkszs[ BLIS_KC ], 384, 256, -1, -1, + 480, 320, -1, -1 ); bli_blksz_init_easy( &blkszs[ BLIS_NC ], 3072, 3752, -1, -1 ); bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); diff --git a/config/skx/bli_family_skx.h b/config/skx/bli_family_skx.h index ac9478f8ba..d698f12b4d 100644 --- a/config/skx/bli_family_skx.h +++ b/config/skx/bli_family_skx.h @@ -47,8 +47,8 @@ #define BLIS_SIMD_ALIGN_SIZE 64 -#define BLIS_SIMD_SIZE 64 -#define BLIS_SIMD_NUM_REGISTERS 32 +#define BLIS_SIMD_MAX_SIZE 64 +#define BLIS_SIMD_MAX_NUM_REGISTERS 32 //#include diff --git a/config/skx/make_defs.mk b/config/skx/make_defs.mk index 27bea5ef55..00ae94a364 100644 --- a/config/skx/make_defs.mk +++ b/config/skx/make_defs.mk @@ -57,11 +57,13 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +CKOPTFLAGS := $(COPTFLAGS) -O3 -fomit-frame-pointer ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse -march=skylake-avx512 else @@ -69,7 +71,15 @@ ifeq ($(CC_VENDOR),icc) CKVECFLAGS := -xCORE-AVX512 else ifeq ($(CC_VENDOR),clang) +# NOTE: We have to use -march=haswell on Windows because apparently AVX512 +# uses an alternate calling convention where xmm registers are not callee-saved +# on the stack. When this is mixed with framework code compiled for general +# x86_64 mode then chaos ensues (e.g. #514). +ifeq ($(IS_WIN),yes) +CKVECFLAGS := -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse -march=haswell +else CKVECFLAGS := -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse -march=skylake-avx512 +endif else $(error gcc, icc, or clang is required for this configuration.) endif @@ -89,13 +99,21 @@ endif # to overcome the AVX-512 frequency drop". (Issue #187) CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := -march=skylake-avx512 -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd -funsafe-math-optimizations +CRVECFLAGS := -march=skylake-avx512 -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd -funsafe-math-optimizations -ffp-contract=fast else ifeq ($(CC_VENDOR),icc) CRVECFLAGS := -xCORE-AVX2 else ifeq ($(CC_VENDOR),clang) -CRVECFLAGS := -march=skylake-avx512 -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd +# NOTE: We have to use -march=haswell on Windows because apparently AVX512 +# uses an alternate calling convention where xmm registers are not callee-saved +# on the stack. When this is mixed with framework code compiled for general +# x86_64 mode then chaos ensues (e.g. #514). +ifeq ($(IS_WIN),yes) +CRVECFLAGS := -march=haswell -funsafe-math-optimizations -ffp-contract=fast +else +CRVECFLAGS := -march=skylake-avx512 -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd -funsafe-math-optimizations -ffp-contract=fast +endif else $(error gcc, icc, or clang is required for this configuration.) endif diff --git a/config/steamroller/make_defs.mk b/config/steamroller/make_defs.mk index a5b6707041..5220c3540b 100644 --- a/config/steamroller/make_defs.mk +++ b/config/steamroller/make_defs.mk @@ -57,11 +57,11 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -mfpmath=sse -mavx -mfma -march=bdver3 -mno-fma4 -mno-tbm -mno-xop -mno-lwp else @@ -75,10 +75,14 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast else CRVECFLAGS := $(CKVECFLAGS) endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/template/kernels/1/bli_axpyv_template_noopt_var1.c b/config/template/kernels/1/bli_axpyv_template_noopt_var1.c index a061a60104..d1918466f7 100644 --- a/config/template/kernels/1/bli_axpyv_template_noopt_var1.c +++ b/config/template/kernels/1/bli_axpyv_template_noopt_var1.c @@ -105,7 +105,7 @@ void bli_zaxpyv_template_noopt dcomplex* xp; dcomplex* yp; - bool_t use_ref = FALSE; + bool use_ref = FALSE; dim_t n_pre = 0; dim_t n_iter; diff --git a/config/template/kernels/1/bli_dotv_template_noopt_var1.c b/config/template/kernels/1/bli_dotv_template_noopt_var1.c index f153c9516d..3761d2e764 100644 --- a/config/template/kernels/1/bli_dotv_template_noopt_var1.c +++ b/config/template/kernels/1/bli_dotv_template_noopt_var1.c @@ -112,7 +112,7 @@ void bli_zdotv_template_noopt dcomplex* yp; dcomplex dotxy; - bool_t use_ref = FALSE; + bool use_ref = FALSE; dim_t n_pre = 0; dim_t n_iter; diff --git a/config/template/kernels/1f/bli_axpy2v_template_noopt_var1.c b/config/template/kernels/1f/bli_axpy2v_template_noopt_var1.c index b66e6d85ed..7080abce06 100644 --- a/config/template/kernels/1f/bli_axpy2v_template_noopt_var1.c +++ b/config/template/kernels/1f/bli_axpy2v_template_noopt_var1.c @@ -116,7 +116,7 @@ void bli_zaxpy2v_template_noopt dcomplex* yp; dcomplex* zp; - bool_t use_ref = FALSE; + bool use_ref = FALSE; dim_t n_pre = 0; dim_t n_iter; diff --git a/config/template/kernels/1f/bli_axpyf_template_noopt_var1.c b/config/template/kernels/1f/bli_axpyf_template_noopt_var1.c index 38fabc0973..a0afedfcaf 100644 --- a/config/template/kernels/1f/bli_axpyf_template_noopt_var1.c +++ b/config/template/kernels/1f/bli_axpyf_template_noopt_var1.c @@ -126,7 +126,7 @@ void bli_zaxpyf_template_noopt dcomplex alpha_x[ bli_zaxpyf_fusefac ]; - bool_t use_ref = FALSE; + bool use_ref = FALSE; dim_t m_pre = 0; dim_t m_iter; diff --git a/config/template/kernels/1f/bli_dotaxpyv_template_noopt_var1.c b/config/template/kernels/1f/bli_dotaxpyv_template_noopt_var1.c index 5225c1953f..275c399982 100644 --- a/config/template/kernels/1f/bli_dotaxpyv_template_noopt_var1.c +++ b/config/template/kernels/1f/bli_dotaxpyv_template_noopt_var1.c @@ -123,7 +123,7 @@ void bli_zdotaxpyv_template_noopt dcomplex* zp; dcomplex dotxy; - bool_t use_ref = FALSE; + bool use_ref = FALSE; dim_t n_pre = 0; dim_t n_iter; diff --git a/config/template/kernels/1f/bli_dotxaxpyf_template_noopt_var1.c b/config/template/kernels/1f/bli_dotxaxpyf_template_noopt_var1.c index caf7351a45..6754d86ce8 100644 --- a/config/template/kernels/1f/bli_dotxaxpyf_template_noopt_var1.c +++ b/config/template/kernels/1f/bli_dotxaxpyf_template_noopt_var1.c @@ -145,7 +145,7 @@ void bli_zdotxaxpyf_template_noopt dcomplex At_w[ bli_zdotxaxpyf_fusefac ]; dcomplex alpha_x[ bli_zdotxaxpyf_fusefac ]; - bool_t use_ref = FALSE; + bool use_ref = FALSE; dim_t m_pre = 0; dim_t m_iter; diff --git a/config/template/kernels/1f/bli_dotxf_template_noopt_var1.c b/config/template/kernels/1f/bli_dotxf_template_noopt_var1.c index 4a14a7066b..430fb277db 100644 --- a/config/template/kernels/1f/bli_dotxf_template_noopt_var1.c +++ b/config/template/kernels/1f/bli_dotxf_template_noopt_var1.c @@ -129,7 +129,7 @@ void bli_zdotxf_template_noopt dcomplex Atx[ bli_zdotxf_fusefac ]; - bool_t use_ref = FALSE; + bool use_ref = FALSE; dim_t m_pre = 0; dim_t m_iter; diff --git a/config/template/kernels/3/bli_gemm_template_noopt_mxn.c b/config/template/kernels/3/bli_gemm_template_noopt_mxn.c index b7a13f3b69..06f25a0e9e 100644 --- a/config/template/kernels/3/bli_gemm_template_noopt_mxn.c +++ b/config/template/kernels/3/bli_gemm_template_noopt_mxn.c @@ -37,6 +37,8 @@ void bli_zgemm_template_noopt ( + dim_t m, + dim_t n, dim_t k, dcomplex* restrict alpha, dcomplex* restrict a1, @@ -88,8 +90,7 @@ void bli_zgemm_template_noopt dim_t l, j, i; - dcomplex ab[ bli_zmr * - bli_znr ]; + dcomplex ab[ mr * nr ]; dcomplex* abij; dcomplex ai, bj; @@ -137,16 +138,16 @@ void bli_zgemm_template_noopt if ( bli_zeq0( *beta ) ) { /* c11 := ab */ - bli_zcopys_mxn( mr, - nr, + bli_zcopys_mxn( m, + n, ab, rs_ab, cs_ab, c11, rs_c, cs_c ); } else { /* c11 := beta * c11 + ab */ - bli_zxpbys_mxn( mr, - nr, + bli_zxpbys_mxn( m, + n, ab, rs_ab, cs_ab, beta, c11, rs_c, cs_c ); diff --git a/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c b/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c index da0cd3110f..87c21f7edf 100644 --- a/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c +++ b/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c @@ -74,6 +74,8 @@ void bli_zgemmtrsm_l_template_noopt */ const num_t dt = BLIS_DCOMPLEX; + const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); const inc_t rs_b = packnr; @@ -84,6 +86,8 @@ void bli_zgemmtrsm_l_template_noopt /* b11 = alpha * b11 - a10 * b01; */ bli_zgemm_template_noopt ( + mr, + nr, k, minus_one, a10, diff --git a/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c b/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c index 09b3af9cee..0b4544ae1d 100644 --- a/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c +++ b/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c @@ -74,6 +74,8 @@ void bli_zgemmtrsm_u_template_noopt */ const num_t dt = BLIS_DCOMPLEX; + const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); const inc_t rs_b = packnr; @@ -84,10 +86,12 @@ void bli_zgemmtrsm_u_template_noopt /* b11 = alpha * b11 - a12 * b21; */ bli_zgemm_template_noopt ( + mr, + nr, k, minus_one, - a12, - b21, + a10, + b01, alpha, b11, rs_b, cs_b, data diff --git a/config/template/make_defs.mk b/config/template/make_defs.mk index 35edf71a1d..7b5b532a34 100644 --- a/config/template/make_defs.mk +++ b/config/template/make_defs.mk @@ -57,11 +57,11 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 CKVECFLAGS := # Flags specific to reference kernels. diff --git a/config/thunderx2/make_defs.mk b/config/thunderx2/make_defs.mk index 3227fe242b..b43fea87c5 100644 --- a/config/thunderx2/make_defs.mk +++ b/config/thunderx2/make_defs.mk @@ -57,20 +57,32 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 -ftree-vectorize -mtune=thunderx2t99 +COPTFLAGS := -O2 -mcpu=thunderx2t99 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 -ftree-vectorize ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS := -march=armv8.1-a+fp+simd -mcpu=thunderx2t99 +CKVECFLAGS := -mcpu=thunderx2t99 else -$(error gcc is required for this configuration.) +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -mcpu=thunderx2t99 +else +$(error gcc or clang is required for this configuration.) +endif endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else CRVECFLAGS := $(CKVECFLAGS) +endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/x86_64/make_defs.mk b/config/x86_64/make_defs.mk index 4d038ff04b..6a05a1f8f9 100644 --- a/config/x86_64/make_defs.mk +++ b/config/x86_64/make_defs.mk @@ -57,11 +57,11 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 endif # Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -mssse3 -mfpmath=sse -march=core2 else @@ -79,10 +79,14 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast else CRVECFLAGS := $(CKVECFLAGS) endif +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/x86_64_no_skx/make_defs.mk b/config/x86_64_no_skx/make_defs.mk index e1e4bf4649..d58e5182ba 100644 --- a/config/x86_64_no_skx/make_defs.mk +++ b/config/x86_64_no_skx/make_defs.mk @@ -79,10 +79,16 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast else CRVECFLAGS := $(CKVECFLAGS) endif +endif + + # Store all of the variables here to new variables containing the # configuration name. diff --git a/windows/build/bli_kernel.h b/config/x86_64_no_zen2/bli_family_x86_64_no_zen2.h similarity index 96% rename from windows/build/bli_kernel.h rename to config/x86_64_no_zen2/bli_family_x86_64_no_zen2.h index daca58e454..21b44db870 100644 --- a/windows/build/bli_kernel.h +++ b/config/x86_64_no_zen2/bli_family_x86_64_no_zen2.h @@ -32,12 +32,10 @@ */ -#ifndef BLIS_KERNEL_H -#define BLIS_KERNEL_H +//#ifndef BLIS_FAMILY_H +//#define BLIS_FAMILY_H - - -#endif +//#endif diff --git a/config/x86_64_no_zen2/make_defs.mk b/config/x86_64_no_zen2/make_defs.mk new file mode 100644 index 0000000000..c70d1245fa --- /dev/null +++ b/config/x86_64_no_zen2/make_defs.mk @@ -0,0 +1,96 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := x86_64_no_zen2 +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O3 +endif + +# Flags specific to optimized kernels. +CKOPTFLAGS := $(COPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CKVECFLAGS := -mssse3 -mfpmath=sse -march=core2 +else +ifeq ($(CC_VENDOR),icc) +CKVECFLAGS := -xSSE3 +else +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -mssse3 -mfpmath=sse -march=core2 +else +$(error gcc, icc, or clang is required for this configuration.) +endif +endif +endif + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +CRVECFLAGS := $(CKVECFLAGS) +endif +endif + + + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/frame/3/herk/bli_herk.h b/config/x86_64_no_zen3/bli_family_x86_64_no_zen3.h similarity index 96% rename from frame/3/herk/bli_herk.h rename to config/x86_64_no_zen3/bli_family_x86_64_no_zen3.h index c437289688..21b44db870 100644 --- a/frame/3/herk/bli_herk.h +++ b/config/x86_64_no_zen3/bli_family_x86_64_no_zen3.h @@ -32,7 +32,10 @@ */ -#include "bli_herk_front.h" +//#ifndef BLIS_FAMILY_H +//#define BLIS_FAMILY_H -#include "bli_herk_var.h" + + +//#endif diff --git a/config/x86_64_no_zen3/make_defs.mk b/config/x86_64_no_zen3/make_defs.mk new file mode 100644 index 0000000000..17ba13c099 --- /dev/null +++ b/config/x86_64_no_zen3/make_defs.mk @@ -0,0 +1,96 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := x86_64_no_zen3 +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O3 +endif + +# Flags specific to optimized kernels. +CKOPTFLAGS := $(COPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CKVECFLAGS := -mssse3 -mfpmath=sse -march=core2 +else +ifeq ($(CC_VENDOR),icc) +CKVECFLAGS := -xSSE3 +else +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -mssse3 -mfpmath=sse -march=core2 +else +$(error gcc, icc, or clang is required for this configuration.) +endif +endif +endif + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +CRVECFLAGS := $(CKVECFLAGS) +endif +endif + + + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config/zen/amd_config.mk b/config/zen/amd_config.mk new file mode 100644 index 0000000000..def1cadbae --- /dev/null +++ b/config/zen/amd_config.mk @@ -0,0 +1,83 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2019, Advanced Micro Devices, Inc. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +# All the common flags for AMD architectures will be added here + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O2 -fomit-frame-pointer +endif + +# Flags specific to optimized kernels. +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +CKOPTFLAGS := $(COPTFLAGS) -O3 +ifeq ($(CC_VENDOR),gcc) +CKVECFLAGS := -mavx2 -mfpmath=sse -mfma +else +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -mavx2 -mfpmath=sse -mfma +ifeq ($(strip $(shell clang -v |& head -1 | grep -c 'AOCC.LLVM')),1) +CKVECFLAGS += -mllvm -disable-licm-vrp +endif +else +$(error gcc or clang are required for this configuration.) +endif +endif + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +CRVECFLAGS := $(CKVECFLAGS) +endif +endif + diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 41de6fd4e3..1b16cd06fc 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,9 +35,12 @@ #include "blis.h" +//GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + void bli_cntx_init_zen( cntx_t* cntx ) { blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; // Set default kernel blocksizes and functions. bli_cntx_init_zen_ref( cntx ); @@ -48,40 +52,66 @@ void bli_cntx_init_zen( cntx_t* cntx ) bli_cntx_set_l3_nat_ukrs ( 8, + // gemm BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, + // gemmtrsm_l BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, + // gemmtrsm_u BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, + + cntx + ); + +#if 1 + // Update the context with optimized packm kernels. + bli_cntx_set_packm_kers + ( + 8, + BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, + BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, + BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, + BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_4xk, cntx ); +#endif // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( 4, + // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, + // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + cntx ); // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 10, + 16, + // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + // axpyv #if 0 BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int, @@ -90,12 +120,21 @@ void bli_cntx_init_zen( cntx_t* cntx ) BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, #endif + +#if 1 + // copyv + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, + BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, +#endif + // dotv BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int, BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int, + // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + // scalv #if 0 BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int, @@ -104,6 +143,17 @@ void bli_cntx_init_zen( cntx_t* cntx ) BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, #endif + +#if 1 + // setv + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + + // swapv + BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, + BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, +#endif + cntx ); @@ -111,15 +161,37 @@ void bli_cntx_init_zen( cntx_t* cntx ) // s d c z bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + +/* + Multi Instance performance improvement of DGEMM when binded to a CCX + In Multi instance each thread runs a sequential DGEMM. + + a) If BLIS is run in a multi-instance mode with + CPU freq 2.6/2.2 Ghz + DDR4 clock frequency 2400Mhz + mc = 240, kc = 512, and nc = 2040 + has better performance on EPYC server, over the default block sizes. + + b) If BLIS is run in Single Instance mode + mc = 510, kc = 1024 and nc = 4080 +*/ + #ifdef BLIS_ENABLE_ZEN_BLOCK_SIZES // Zen optmized level 3 cache block sizes - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 510, 144, 72 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 1024, 256, 256 ); + #if BLIS_ENABLE_SINGLE_INSTANCE_BLOCK_SIZES + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 1020, 510, 510, 255 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 1024, 1024, 1024, 1024 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 4080, 3056 ); + #else + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 240, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 2040, 2040, 1528 ); + #endif #else bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 4080, 3056 ); #endif - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 ); bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); @@ -139,5 +211,115 @@ void bli_cntx_init_zen( cntx_t* cntx ) BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, cntx ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 512, 256, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 440, 220, -1, -1 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Initialize the context with the sup handlers. + bli_cntx_set_l3_sup_handlers + ( + 1, + BLIS_GEMM, bli_gemmsup_ref, + //BLIS_GEMMT, bli_gemmtsup_ref, + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 16, + //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_haswell_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_haswell_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, +#if 0 + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, +#endif + +#if 0 + // NOTE: This set of kernels is likely broken and therefore disabled. + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, +#endif + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, -1, -1, + 9, 9, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, -1, -1 ); +#if 0 + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, + 9, 9, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); +#endif + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); } diff --git a/config/zen/bli_family_zen.h b/config/zen/bli_family_zen.h index c872a21eb6..da03bd7e42 100644 --- a/config/zen/bli_family_zen.h +++ b/config/zen/bli_family_zen.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2016, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,9 +33,6 @@ */ -//#ifndef BLIS_FAMILY_H -//#define BLIS_FAMILY_H - // By default, it is effective to parallelize the outer loops. // Setting these macros to 1 will force JR and IR inner loops // to be not paralleized. @@ -43,14 +40,40 @@ #define BLIS_THREAD_MAX_JR 1 #define BLIS_ENABLE_ZEN_BLOCK_SIZES -//#define BLIS_ENABLE_SMALL_MATRIX + +// Vanilla BLIS disables AMD's small matrix handling by default. +#if 0 +#define BLIS_ENABLE_SMALL_MATRIX +#define BLIS_ENABLE_SMALL_MATRIX_TRSM // This will select the threshold below which small matrix code will be called. #define BLIS_SMALL_MATRIX_THRES 700 #define BLIS_SMALL_M_RECT_MATRIX_THRES 160 #define BLIS_SMALL_K_RECT_MATRIX_THRES 128 +#define BLIS_SMALL_MATRIX_THRES_TRSM 32768 //128(128+128) => m*(m+n) +#define BLIS_SMALL_MATRIX_A_THRES_TRSM 128 +#define BLIS_SMALL_MATRIX_A_THRES_M_GEMMT 96 +#define BLIS_SMALL_MATRIX_A_THRES_N_GEMMT 128 + +//This macro will enable BLIS DGEMM to choose block sizes for a single instance mode +#define BLIS_ENABLE_SINGLE_INSTANCE_BLOCK_SIZES 0 + +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES 250 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES 90 + +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO 22 +#endif -//#endif +#if 0 +// Allow the sup implementation to combine some small edge case iterations in +// the 2nd loop of the panel-block algorithm (MR) and/or the 2nd loop of the +// block-panel algorithm (NR) with the last full iteration that precedes it. +// NOTE: These cpp macros need to be explicitly set to an integer since they +// are used at compile-time to create unconditional branches or dead code +// regions. +#define BLIS_ENABLE_SUP_MR_EXT 1 +#define BLIS_ENABLE_SUP_NR_EXT 0 +#endif diff --git a/config/zen/make_defs.mk b/config/zen/make_defs.mk index 0397f60b7c..8bdafd5ca2 100644 --- a/config/zen/make_defs.mk +++ b/config/zen/make_defs.mk @@ -1,10 +1,10 @@ # # -# BLIS +# BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -57,32 +57,35 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O3 +COPTFLAGS := -O2 -fomit-frame-pointer endif -# Flags specific to optimized kernels. -CKOPTFLAGS := $(COPTFLAGS) +# Flags specific to optimized and reference kernels. +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +CKOPTFLAGS := $(COPTFLAGS) -O3 +CROPTFLAGS := $(CKOPTFLAGS) +CKVECFLAGS := -mavx2 -mfma -mfpmath=sse +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast ifeq ($(CC_VENDOR),gcc) -# gcc 6.0 (clang 4.0) or later: -#CKVECFLAGS := -mavx2 -mfpmath=sse -mfma -march=znver1 -# gcc 4.9 (clang 3.5) or later: -# possibly add zen-specific instructions: -mclzero -madx -mrdseed -mmwaitx -msha -mxsavec -mxsaves -mclflushopt -mpopcnt -CKVECFLAGS := -mavx2 -mfpmath=sse -mfma -march=bdver4 -mno-fma4 -mno-tbm -mno-xop -mno-lwp + ifeq ($(GCC_OT_6_1_0),yes) # gcc versions older than 6.1. + CVECFLAGS_VER := -march=bdver4 -mno-fma4 -mno-tbm -mno-xop -mno-lwp + else + CVECFLAGS_VER := -march=znver1 -mno-avx256-split-unaligned-store + endif else ifeq ($(CC_VENDOR),clang) -CKVECFLAGS := -mavx2 -mfpmath=sse -mfma -march=bdver4 -mno-fma4 -mno-tbm -mno-xop -mno-lwp + CVECFLAGS_VER := -march=znver1 else -$(error gcc or clang are required for this configuration.) +ifeq ($(CC_VENDOR),aocc) + CVECFLAGS_VER := -march=znver1 -mllvm -disable-licm-vrp +else + $(error gcc, clang, or aocc is required for this configuration.) endif endif - -# Flags specific to reference kernels. -CROPTFLAGS := $(CKOPTFLAGS) -ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -else -CRVECFLAGS := $(CKVECFLAGS) endif +CKVECFLAGS += $(CVECFLAGS_VER) +CRVECFLAGS += $(CVECFLAGS_VER) # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/zen/make_defs.mk.old b/config/zen/make_defs.mk.old new file mode 100644 index 0000000000..44c2ad18d6 --- /dev/null +++ b/config/zen/make_defs.mk.old @@ -0,0 +1,84 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2019, Advanced Micro Devices, Inc. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +# FLAGS that are specific to the 'zen' architecture are added here. +# FLAGS that are common for all the AMD architectures are present in +# amd_config.mk. + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := zen +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# Include the file containing common flags for all AMD architectures. +AMD_CONFIG_FILE := amd_config.mk +AMD_CONFIG_PATH := $(BASE_SHARE_PATH)/config/zen +-include $(AMD_CONFIG_PATH)/$(AMD_CONFIG_FILE) + +ifeq ($(CC_VENDOR),gcc) +# If gcc is older than 6.1.0, we must use -march=bdver4 and then remove the +# Bulldozer instruction sets that were omitted from Zen. +# Additionally, if gcc is 4.9 (clang 3.5?) or newer, we may want to add +# Zen-specific instructions back into the mix: +# -mclzero -madx -mrdseed -mmwaitx -msha -mxsavec -mxsaves -mclflushopt -mpopcnt +ifeq ($(GCC_OT_6_1_0),yes) +CRVECFLAGS += -march=bdver4 -mno-fma4 -mno-tbm -mno-xop -mno-lwp +CKVECFLAGS += -march=bdver4 -mno-fma4 -mno-tbm -mno-xop -mno-lwp +else +# If gcc is at least 6.1.0, then we can specify the microarchitecture using +# the preferred option. +CRVECFLAGS += -march=znver1 +CKVECFLAGS += -march=znver1 +endif +else +ifeq ($(CC_VENDOR),clang) +# I couldn't find which versions of clang added support for -march=znver1, +# so we don't even bother attempting the differentiation that appears in the +# gcc branch above. +CRVECFLAGS += -march=znver1 +CKVECFLAGS += -march=znver1 +else +$(error gcc or clang are required for this configuration.) +endif +endif + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config/zen/old/bli_kernel.h b/config/zen/old/bli_kernel.h index 68b9e88e05..cd324fd9a7 100644 --- a/config/zen/old/bli_kernel.h +++ b/config/zen/old/bli_kernel.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c new file mode 100644 index 0000000000..ba728602bb --- /dev/null +++ b/config/zen2/bli_cntx_init_zen2.c @@ -0,0 +1,287 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_cntx_init_zen2( cntx_t* cntx ) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + + // Set default kernel blocksizes and functions. + bli_cntx_init_zen2_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 8, + + // gemm + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, + + // gemmtrsm_l + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, + + // gemmtrsm_u + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, + + cntx + ); + +#if 1 + // Update the context with optimized packm kernels. + bli_cntx_set_packm_kers + ( + 8, + BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, + BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, + BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, + BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_4xk, + cntx + ); +#endif + + // Update the context with optimized level-1f kernels. + bli_cntx_set_l1f_kers + ( + 4, + + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, + + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + + cntx + ); + + // Update the context with optimized level-1v kernels. + bli_cntx_set_l1v_kers + ( + 16, + + // amaxv + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + + // axpyv + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, + + // dotv + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, + + // dotxv + BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + + // scalv + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + + //swap + BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, + BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + + //copy + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, + BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + + //set + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); +#if AOCL_BLIS_MULTIINSTANCE + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 240, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 2040, 4080, 4080 ); +#else + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 ); +#endif + + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 7, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + // level-1f + BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, + BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, + cntx + ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z +#if 1 + bli_blksz_init_easy( &thresh[ BLIS_MT ], 500, 249, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 500, 249, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 500, 249, -1, -1 ); +#else + bli_blksz_init_easy( &thresh[ BLIS_MT ], 100000, 100000, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 100000, 100000, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 100000, 100000, -1, -1 ); +#endif + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + +#if 0 + // Initialize the context with the sup handlers. + bli_cntx_set_l3_sup_handlers + ( + 1, + BLIS_GEMM, bli_gemmsup_ref, + cntx + ); +#endif + + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 16, + //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_haswell_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_haswell_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, +#if 0 + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, +#endif + +#if 0 + // NOTE: This set of kernels is likely broken and therefore disabled. + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, +#endif + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, -1, -1, + 9, 9, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 168, 72, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); +} + diff --git a/config/zen2/bli_family_zen2.h b/config/zen2/bli_family_zen2.h new file mode 100644 index 0000000000..d7adddf3c8 --- /dev/null +++ b/config/zen2/bli_family_zen2.h @@ -0,0 +1,86 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// By default, it is effective to parallelize the outer loops. +// Setting these macros to 1 will force JR and IR inner loops +// to be not paralleized. +#define BLIS_THREAD_MAX_IR 1 +#define BLIS_THREAD_MAX_JR 1 + +// Vanilla BLIS disables AMD's small matrix handling by default. +#if 0 +#define BLIS_ENABLE_SMALL_MATRIX +#define BLIS_ENABLE_SMALL_MATRIX_TRSM + +// This will select the threshold below which small matrix code will be called. +#define BLIS_SMALL_MATRIX_THRES 700 +#define BLIS_SMALL_M_RECT_MATRIX_THRES 160 +#define BLIS_SMALL_K_RECT_MATRIX_THRES 128 + +#define BLIS_SMALL_MATRIX_THRES_TRSM 32768 //128(128+128) => m*(m+n) +#define BLIS_SMALL_MATRIX_A_THRES_TRSM 128 +#define BLIS_SMALL_MATRIX_A_THRES_M_GEMMT 96 +#define BLIS_SMALL_MATRIX_A_THRES_N_GEMMT 128 + +#define BLIS_ENABLE_SMALL_MATRIX_ROME +#define BLIS_SMALL_MATRIX_THRES_ROME 400 + +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME 80 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M 40 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M 1000 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N 10 + +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME 150 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M 5 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N 130 + +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME 120 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M 10 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N 1200 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M 30 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N 280 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N 100 + +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME 110 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N 30 + +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME 120 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME_COL_PANEL_N 50 + +// When running HPL with pure MPI without DGEMM threading (Single-threaded +// BLIS), defining this macro as 1 yields better performance. +#define AOCL_BLIS_MULTIINSTANCE 0 +#endif + diff --git a/config/zen2/make_defs.mk b/config/zen2/make_defs.mk new file mode 100644 index 0000000000..c14b8cba09 --- /dev/null +++ b/config/zen2/make_defs.mk @@ -0,0 +1,105 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := zen2 +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O2 -fomit-frame-pointer +endif + +# Flags specific to optimized and reference kernels. +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +CKOPTFLAGS := $(COPTFLAGS) -O3 +CROPTFLAGS := $(CKOPTFLAGS) +CKVECFLAGS := -mavx2 -mfma -mfpmath=sse +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +ifeq ($(CC_VENDOR),gcc) + ifeq ($(GCC_OT_6_1_0),yes) # gcc versions older than 6.1. + CVECFLAGS_VER := -march=bdver4 -mno-fma4 -mno-tbm -mno-xop -mno-lwp + else + ifeq ($(GCC_OT_9_1_0),yes) # gcc versions 6.1 or newer, but older than 9.1. + CVECFLAGS_VER := -march=znver1 -mno-avx256-split-unaligned-store + else # gcc versions 9.1 or newer. + CVECFLAGS_VER := -march=znver2 + endif + endif +else +ifeq ($(CC_VENDOR),clang) + ifeq ($(CLANG_OT_9_0_0),yes) # clang versions older than 9.0. + CVECFLAGS_VER := -march=znver1 + else # clang versions 9.0 or newer. + CVECFLAGS_VER := -march=znver2 + endif +else +ifeq ($(CC_VENDOR),aocc) + ifeq ($(AOCC_OT_2_0_0),yes) # aocc versions older than 2.0. + CVECFLAGS_VER := -march=znver1 -mllvm -disable-licm-vrp + else # aocc versions 2.0 or newer. + CVECFLAGS_VER := -march=znver2 + endif +else + $(error gcc, clang, or aocc is required for this configuration.) +endif +endif +endif +CKVECFLAGS += $(CVECFLAGS_VER) +CRVECFLAGS += $(CVECFLAGS_VER) + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config/zen2/make_defs.mk.old b/config/zen2/make_defs.mk.old new file mode 100644 index 0000000000..9f0370376c --- /dev/null +++ b/config/zen2/make_defs.mk.old @@ -0,0 +1,94 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2019, Advanced Micro Devices, Inc. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +# FLAGS that are specific to the 'zen2' architecture are added here. +# FLAGS that are common for all the AMD architectures are present in +# config/zen/amd_config.mk. + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := zen2 +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# Include file containing common flags for all AMD architectures. +AMD_CONFIG_FILE := amd_config.mk +AMD_CONFIG_PATH := $(BASE_SHARE_PATH)/config/zen +-include $(AMD_CONFIG_PATH)/$(AMD_CONFIG_FILE) + +ifeq ($(CC_VENDOR),gcc) + ifeq ($(GCC_OT_9_1_0),yes) + ifeq ($(GCC_OT_6_1_0),yes) + # If gcc is older than 6.1.0, we must use -march=bdver4 and then remove the + # Bulldozer instruction sets that were omitted from Zen. + CRVECFLAGS += -march=bdver4 -mno-fma4 -mno-tbm -mno-xop -mno-lwp + CKVECFLAGS += -march=bdver4 -mno-fma4 -mno-tbm -mno-xop -mno-lwp + else + # If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 + # as the fallback option. + CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store + CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store + endif + else + # If gcc is at least 9.1.0, then we can specify the microarchitecture using + # the preferred option. + CRVECFLAGS += -march=znver2 + CKVECFLAGS += -march=znver2 + endif + else + ifeq ($(CC_VENDOR),clang) + ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) + CKVECFLAGS += -march=znver2 + else + #if compiling with clang + VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) + CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) + #clang 9.0 or later: + ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) + CKVECFLAGS += -march=znver2 + else + CKVECFLAGS += -march=znver1 + endif # ge 9 + endif # AOCC 2 + endif # Clang +endif # gcc + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c new file mode 100644 index 0000000000..0336ddc953 --- /dev/null +++ b/config/zen3/bli_cntx_init_zen3.c @@ -0,0 +1,301 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_cntx_init_zen3( cntx_t* cntx ) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + + // Set default kernel blocksizes and functions. + bli_cntx_init_zen3_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 8, + + // gemm + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, + + // gemmtrsm_l + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, + + // gemmtrsm_u + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, + + cntx + ); + +#if 0 + // AMD: This will be enabled in other PRs. + // packm kernels + bli_cntx_set_packm_kers + ( + 2, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_8xk_gen_zen, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_6xk_gen_zen, + cntx + ); +#else + // Update the context with optimized packm kernels. + bli_cntx_set_packm_kers + ( + 8, + BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, + BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, + BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, + BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_4xk, + cntx + ); +#endif + + // Update the context with optimized level-1f kernels. + bli_cntx_set_l1f_kers + ( + 4, + + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, + + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + + cntx + ); + + // Update the context with optimized level-1v kernels. + bli_cntx_set_l1v_kers + ( + 16, + + // amaxv + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + + // axpyv + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, + + // dotv + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, + + // dotxv + BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + + // scalv + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + + //swap + BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, + BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + + //copy + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, + BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + + //set + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // + // These are reference block sizes and may be overridden based on + // number of threads used at runtime. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 ); + + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 7, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + // level-1f + BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, + BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, + cntx + ); + +// ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 200, 256, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 240, 220, -1, -1 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + +#if 0 + // Initialize the context with the sup handlers. + bli_cntx_set_l3_sup_handlers + ( + 2, + BLIS_GEMM, bli_gemmsup_ref, + BLIS_GEMMT, bli_gemmtsup_ref, + cntx + ); +#endif + +#if 0 + // AMD: This should be enabled in the PR which has added these kernels + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 28, + //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + cntx + ); +#else + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 16, + //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_haswell_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_haswell_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, + cntx + ); + +#endif + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, + 9, 9, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); +} + diff --git a/config/zen3/bli_family_zen3.h b/config/zen3/bli_family_zen3.h new file mode 100644 index 0000000000..661313ca94 --- /dev/null +++ b/config/zen3/bli_family_zen3.h @@ -0,0 +1,95 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLI_FAMILY_ZEN3_ +#define BLI_FAMILY_ZEN3_ + +// By default, it is effective to parallelize the outer loops. +// Setting these macros to 1 will force JR and IR inner loops +// to be not paralleized. +// + +#define BLIS_THREAD_MAX_IR 1 +#define BLIS_THREAD_MAX_JR 1 + + +// To enable framework optimizations for zen3 platform +// All zen3 specific code should be included in this macro +#define BLIS_CONFIG_ZEN3 + +// To enable framework optimizations for zen3 platform +// All zen3 specific code should be included in this macro +#define BLIS_CONFIG_ZEN3 + +#define BLIS_ENABLE_SMALL_MATRIX +#define BLIS_ENABLE_SMALL_MATRIX_TRSM + + +// This will select the threshold below which small matrix code will be called. +#define BLIS_SMALL_MATRIX_THRES 700 +#define BLIS_SMALL_M_RECT_MATRIX_THRES 160 +#define BLIS_SMALL_K_RECT_MATRIX_THRES 128 + +#define BLIS_SMALL_MATRIX_THRES_TRSM 32768 //128(128+128) => m*(m+n) +#define BLIS_SMALL_MATRIX_A_THRES_TRSM 128 + +#define BLIS_SMALL_MATRIX_A_THRES_M_GEMMT 96 +#define BLIS_SMALL_MATRIX_A_THRES_N_GEMMT 128 + +#define BLIS_ENABLE_SMALL_MATRIX_ROME +#define BLIS_SMALL_MATRIX_THRES_ROME 400 + +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME 80 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M 40 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M 1000 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N 10 + +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME 150 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M 5 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N 130 + +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME 120 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M 10 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N 1200 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M 30 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N 280 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N 100 + +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME 110 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N 30 + +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME 120 +#define D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME_COL_PANEL_N 50 + +#endif diff --git a/config/zen3/make_defs.mk b/config/zen3/make_defs.mk new file mode 100644 index 0000000000..5c68855db6 --- /dev/null +++ b/config/zen3/make_defs.mk @@ -0,0 +1,113 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := zen3 +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O3 +endif + +# Flags specific to optimized and reference kernels. +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer +CROPTFLAGS := $(CKOPTFLAGS) +CKVECFLAGS := -mavx2 -mfma -mfpmath=sse +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +ifeq ($(CC_VENDOR),gcc) + ifeq ($(GCC_OT_9_1_0),yes) # gcc versions older than 9.1. + CVECFLAGS_VER := -march=znver1 -mno-avx256-split-unaligned-store + else + ifeq ($(GCC_OT_10_1_0),yes) # gcc versions 9.1 or newer, but older than 10.1. + CVECFLAGS_VER := -march=znver2 + else # gcc versions 10.1 or newer. + CVECFLAGS_VER := -march=znver3 + endif + endif +else +ifeq ($(CC_VENDOR),clang) + ifeq ($(CLANG_OT_9_0_0),yes) # clang versions older than 9.0. + CVECFLAGS_VER := -march=znver1 + else + ifeq ($(CLANG_OT_12_0_0),yes) # clang versions 9.0 or newer, but older than 12.0. + CVECFLAGS_VER := -march=znver2 + else # clang versions 12.0 or newer. + CVECFLAGS_VER := -march=znver3 + endif + endif +else +ifeq ($(CC_VENDOR),aocc) + ifeq ($(AOCC_OT_2_0_0),yes) # aocc versions older than 2.0. + CVECFLAGS_VER := -march=znver1 + else + ifeq ($(AOCC_OT_3_0_0),yes) # aocc versions 2.0 or newer, but older than 3.0. + CVECFLAGS_VER := -march=znver2 + else # aocc versions 3.0 or newer. + CVECFLAGS_VER := -march=znver3 + endif + endif +else + $(error gcc, clang, or aocc is required for this configuration.) +endif +endif +endif +CKVECFLAGS += $(CVECFLAGS_VER) +CRVECFLAGS += $(CVECFLAGS_VER) + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config/zen3/make_defs.mk.old b/config/zen3/make_defs.mk.old new file mode 100644 index 0000000000..e0794ab0c7 --- /dev/null +++ b/config/zen3/make_defs.mk.old @@ -0,0 +1,137 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +# FLAGS that are specific to the 'zen3' architecture are added here. +# FLAGS that are common for all the AMD architectures are present in +# config/zen/amd_config.mk. + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := zen3 +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +#frame pointers are needed to execution tracing +ifeq ($(ETRACE_ENABLE),1) +COPTFLAGS := -O3 +else +COPTFLAGS := -O3 -fomit-frame-pointer +endif +endif + + +# +# --- Enable ETRACE across the library if enabled ETRACE_ENABLE=[0,1] ----------------------- +# + +ifeq ($(ETRACE_ENABLE),1) +CDBGFLAGS += -pg -finstrument-functions -DAOCL_DTL_AUTO_TRACE_ENABLE +LDFLAGS += -ldl +endif + +# Flags specific to optimized kernels. +CKOPTFLAGS := $(COPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) +#gcc or clang version must be atleast 4.0 +# gcc 9.0 or later: +ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) +CKVECFLAGS += -march=znver2 +else +# If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 +# as the fallback option. +CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store +CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store +endif +else +ifeq ($(CC_VENDOR),clang) + +# AOCC clang has various formats for the version line + +# AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) +# AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) +# AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) +# AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) +# AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) + +# For our prupose we just want to know if it version 2x or 3x + +# for version 3x we will enable znver3 +ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) +CKVECFLAGS += -march=znver3 +else +# for version 2x we will enable znver2 +ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) +CKVECFLAGS += -march=znver2 +else +#if compiling with clang +VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) +CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) +#clang 9.0 or later: +ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) +CKVECFLAGS += -march=znver2 +else +CKVECFLAGS += -march=znver1 +endif # ge 9 +endif # aocc 2 +endif # aocc 3 +endif # clang +endif # gcc + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +CRVECFLAGS := $(CKVECFLAGS) + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config_registry b/config_registry index 46cb689bd7..f8248a3c74 100644 --- a/config_registry +++ b/config_registry @@ -4,19 +4,21 @@ # Please refer to the BLIS wiki on configurations for information on the # syntax and semantics of this file [1]. # -# [1] https://github.com/flame/blis/wiki/ConfigurationHowTo +# [1] https://github.com/flame/blis/blob/master/docs/ConfigurationHowTo.md # # Processor families. -x86_64_no_skx: intel64_no_skx amd64 -x86_64: intel64 amd64 -intel64: skx knl haswell sandybridge penryn generic -intel64_no_skx: haswell sandybridge penryn generic -amd64: zen excavator steamroller piledriver bulldozer generic -# NOTE: ARM families will remain disabled until runtime hardware detection -# logic is added to BLIS. -#arm64: cortexa57 generic -#arm32: cortexa15 cortexa9 generic +x86_64_no_skx: intel64_no_skx amd64_legacy # gcc without SKL-X also doesn't support Zen 1 +x86_64_no_zen2: intel64_no_knl zen amd64_legacy +x86_64_no_zen3: intel64 zen zen2 amd64_legacy +x86_64: intel64 amd64 amd64_legacy +intel64: skx knl haswell sandybridge penryn generic +intel64_no_knl: skx haswell sandybridge penryn generic +intel64_no_skx: haswell sandybridge penryn generic +amd64_legacy: excavator steamroller piledriver bulldozer generic +amd64: zen3 zen2 zen generic +arm64: armsve firestorm thunderx2 cortexa57 cortexa53 generic +arm32: cortexa15 cortexa9 generic # Intel architectures. skx: skx/skx/haswell/zen @@ -26,6 +28,8 @@ sandybridge: sandybridge penryn: penryn # AMD architectures. +zen3: zen3/zen3/zen2/zen/haswell +zen2: zen2/zen2/zen/haswell zen: zen/zen/haswell excavator: excavator/piledriver steamroller: steamroller/piledriver @@ -33,6 +37,9 @@ piledriver: piledriver bulldozer: bulldozer # ARM architectures. +armsve: armsve/armsve +a64fx: a64fx/armsve +firestorm: firestorm/armv8a thunderx2: thunderx2/armv8a cortexa57: cortexa57/armv8a cortexa53: cortexa53/armv8a @@ -40,7 +47,8 @@ cortexa15: cortexa15/armv7a cortexa9: cortexa9/armv7a # IBM architectures. -power9: power9/generic +power10: power10 +power9: power9 bgq: bgq # Generic architectures. diff --git a/configure b/configure index 755cf61fcb..f64aac7055 100755 --- a/configure +++ b/configure @@ -1,11 +1,11 @@ #!/usr/bin/env bash # -# BLIS +# BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2018, Advanced Micro Devices, Inc. +# Copyright (C) 2020-2022, Advanced Micro Devices, Inc. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -51,8 +51,6 @@ print_usage() #echo " " #echo " BLIS ${version}" echo " " - echo " Field G. Van Zee" - echo " " echo " Configure BLIS's build system for compilation using a specified" echo " configuration directory." echo " " @@ -72,30 +70,37 @@ print_usage() echo " " echo " -p PREFIX, --prefix=PREFIX" echo " " - echo " The path to which make will install all build products." - echo " If given, this option implies the following options:" - echo " --libdir=PREFIX/lib" - echo " --incdir=PREFIX/include" + echo " The common installation prefix for all files. If given," + echo " this option effectively implies:" + echo " --libdir=EXECPREFIX/lib" + echo " --includedir=PREFIX/include" echo " --sharedir=PREFIX/share" - echo " If not given, PREFIX defaults to \$(HOME)/blis. If PREFIX" + echo " where EXECPREFIX defaults to PREFIX. If this option is" + echo " not given, PREFIX defaults to '${prefix_def}'. If PREFIX" echo " refers to a directory that does not exist, it will be" echo " created." echo " " + echo " --exec-prefix=EXECPREFIX" + echo " " + echo " The installation prefix for libraries. Specifically, if" + echo " given, this option effectively implies:" + echo " --libdir=EXECPREFIX/lib" + echo " If not given, EXECPREFIX defaults to PREFIX, which may be" + echo " modified by the --prefix option. If EXECPREFIX refers to" + echo " a directory that does not exist, it will be created." + echo " " echo " --libdir=LIBDIR" echo " " - echo " The path to which make will install libraries. If given," - echo " LIBDIR will override the corresponding directory implied" - echo " by --prefix; if not not given, LIBDIR defaults to" - echo " PREFIX/lib. If LIBDIR refers to a directory that does" - echo " not exist, it will be created." + echo " The path to which make will install libraries. If not" + echo " given, LIBDIR defaults to PREFIX/lib. If LIBDIR refers to" + echo " a directory that does not exist, it will be created." echo " " echo " --includedir=INCDIR" echo " " echo " The path to which make will install development header" - echo " files. If given, INCDIR will override the corresponding" - echo " directory implied by --prefix; if not given, INCDIR" - echo " defaults to PREFIX/include. If INCDIR refers to a" - echo " directory that does not exist, it will be created." + echo " files. If not given, INCDIR defaults to PREFIX/include." + echo " If INCDIR refers to a directory that does not exist, it" + echo " will be created." echo " " echo " --sharedir=SHAREDIR" echo " " @@ -104,18 +109,9 @@ print_usage() echo " and LDFLAGS). These files allow certain BLIS makefiles," echo " such as those in the examples or testsuite directories, to" echo " operate on an installed copy of BLIS rather than a local" - echo " (and possibly uninstalled) copy. If given, SHAREDIR will" - echo " override the corresponding directory implied by --prefix;" - echo " if not given, SHAREDIR defaults to PREFIX/share. If" - echo " SHAREDIR refers to a directory that does not exist, it" - echo " will be created." - echo " " - echo " -d DEBUG, --enable-debug[=DEBUG]" - echo " " - echo " Enable debugging symbols in the library. If argument" - echo " DEBUG is given as 'opt', then optimization flags are" - echo " kept in the framework, otherwise optimization is" - echo " turned off." + echo " (and possibly uninstalled) copy. If not given, SHAREDIR" + echo " defaults to PREFIX/share. If SHAREDIR refers to a" + echo " directory that does not exist, it will be created." echo " " echo " --enable-verbose-make, --disable-verbose-make" echo " " @@ -129,6 +125,13 @@ print_usage() echo " even if the command plus command line arguments exceeds" echo " the operating system limit (ARG_MAX)." echo " " + echo " -d DEBUG, --enable-debug[=DEBUG]" + echo " " + echo " Enable debugging symbols in the library. If argument" + echo " DEBUG is given as 'opt', then optimization flags are" + echo " kept in the framework, otherwise optimization is" + echo " turned off." + echo " " echo " --disable-static, --enable-static" echo " " echo " Disable (enabled by default) building BLIS as a static" @@ -141,6 +144,29 @@ print_usage() echo " library. If the shared library build is disabled, the" echo " static library build must remain enabled." echo " " + echo " --enable-rpath, --disable-rpath" + echo " " + echo " Enable (disabled by default) setting an install_name for" + echo " dynamic libraries on macOS which starts with @rpath rather" + echo " than the absolute install path." + echo " " + echo " -e SYMBOLS, --export-shared[=SYMBOLS]" + echo " " + echo " Specify the subset of library symbols that are exported" + echo " within a shared library. Valid values for SYMBOLS are:" + echo " 'public' (the default) and 'all'. By default, only" + echo " functions and variables that belong to public APIs are" + echo " exported in shared libraries. However, the user may" + echo " instead export all symbols in BLIS, even those that were" + echo " intended for internal use only. Note that the public APIs" + echo " encompass all functions that almost any user would ever" + echo " want to call, including the BLAS/CBLAS compatibility APIs" + echo " as well as the basic and expert interfaces to the typed" + echo " and object APIs that are unique to BLIS. Also note that" + echo " changing this option to 'all' will have no effect in some" + echo " environments, such as when compiling with clang on" + echo " Windows." + echo " " echo " -t MODEL, --enable-threading[=MODEL], --disable-threading" echo " " echo " Enable threading in the library, using threading model" @@ -148,6 +174,18 @@ print_usage() echo " --disable-threading is specified, threading will be" echo " disabled. The default is 'no'." echo " " + echo " --enable-system, --disable-system" + echo " " + echo " Enable conventional operating system support, such as" + echo " pthreads for thread-safety. The default state is enabled." + echo " However, in rare circumstances you may wish to configure" + echo " BLIS for use with a minimal or nonexistent operating" + echo " system (e.g. hardware simulators). In these situations," + echo " --disable-system may be used to jettison all compile-time" + echo " and link-time dependencies outside of the standard C" + echo " library. When disabled, this option also forces the use" + echo " of --disable-threading." + echo " " echo " --disable-pba-pools, --enable-pba-pools" echo " --disable-sba-pools, --enable-sba-pools" echo " " @@ -222,6 +260,40 @@ print_usage() echo " only be enabled when mixed domain/precision support is" echo " enabled." echo " " + echo " --disable-sup-handling, --enable-sup-handling" + echo " " + echo " Disable (enabled by default) handling of small/skinny" + echo " matrix problems via separate code branches. When disabled," + echo " these small/skinny level-3 operations will be performed by" + echo " the conventional implementation, which is optimized for" + echo " medium and large problems. Note that what qualifies as" + echo " \"small\" depends on thresholds that may vary by sub-" + echo " configuration." + echo " " + echo " --enable-amd-frame-tweaks, --disable-amd-frame-tweaks" + echo " " + echo " Enable building with certain framework files that have" + echo " been customized by AMD for Zen-based microarchitectures." + echo " The default counterparts of these files must be portable," + echo " and so these customized files may provide some (typically" + echo " modest) performance improvement for some select operations" + echo " and/or APIs, though there may a few (tiny dimension) cases" + echo " where the improvement is more pronounced. Note that the" + echo " target configuration must be Zen-based (or 'amd64') for" + echo " this option to have any effect. (Also note that this" + echo " option is NOT to be confused with enabling AMD *kernels*," + echo " which are determined by the BLIS subconfiguration used at" + echo " runtime.) By default, these customized files are disabled." + echo " " + echo " -a NAME --enable-addon=NAME" + echo " " + echo " Enable the code provided by an addon. An addon consists" + echo " of a separate directory of code that provides additional" + echo " APIs, implementations, and/or operations that would" + echo " otherwise not be present within a build of BLIS. This" + echo " option may be used multiple times to specify the inclusion" + echo " of multiple addons. By default, no addons are enabled." + echo " " echo " -s NAME --enable-sandbox=NAME" echo " " echo " Enable a separate sandbox implementation of gemm. This" @@ -256,6 +328,20 @@ print_usage() echo " which may be ignored in select situations if the" echo " implementation has a good reason to do so." echo " " + echo " --disable-trsm-preinversion, --enable-trsm-preinversion" + echo " " + echo " Disable (enabled by default) pre-inversion of triangular" + echo " matrix diagonals when performing trsm. When pre-inversion" + echo " is enabled, diagonal elements are inverted outside of the" + echo " microkernel (e.g. during packing) so that the microkernel" + echo " can use multiply instructions. When disabled, division" + echo " instructions are used within the microkernel. Executing" + echo " these division instructions within the microkernel will" + echo " incur a performance penalty, but numerical robustness will" + echo " improve for certain cases involving denormal numbers that" + echo " would otherwise result in overflow in the pre-inverted" + echo " values." + echo " " echo " --force-version=STRING" echo " " echo " Force configure to use an arbitrary version string" @@ -270,6 +356,15 @@ print_usage() echo " a sanity check to make sure these lists are constituted" echo " as expected." echo " " + echo " --complex-return=gnu|intel" + echo " " + echo " Specify the way in which complex numbers are returned" + echo " from Fortran functions, either \"gnu\" (return in" + echo " registers) or \"intel\" (return via hidden argument)." + echo " If not specified and the environment variable FC is set," + echo " attempt to determine the return type from the compiler." + echo " Otherwise, the default is \"gnu\"." + echo " " echo " -q, --quiet Suppress informational output. By default, configure" echo " is verbose. (NOTE: -q is not yet implemented)" echo " " @@ -278,15 +373,20 @@ print_usage() echo " Environment Variables:" echo " " echo " CC Specifies the C compiler to use." - echo " RANLIB Specifies the ranlib executable to use." - echo " AR Specifies the archiver to use." + echo " CXX Specifies the C++ compiler to use (sandbox only)." + echo " FC Specifies the Fortran compiler to use (only to determine --complex-return)." + echo " AR Specifies the static library archiver to use." + echo " RANLIB Specifies the ranlib (library indexer) executable to use." + echo " PYTHON Specifies the python interpreter to use." echo " CFLAGS Specifies additional compiler flags to use (prepended)." echo " LDFLAGS Specifies additional linker flags to use (prepended)." echo " LIBPTHREAD Pthreads library to use." - echo " PYTHON Specifies the python interpreter to use." echo " " - echo " Environment variables may also be specified as command line" - echo " options, e.g.:" + echo " Environment variables are traditionally set prior to running configure:" + echo " " + echo " CC=gcc ./configure [options] haswell" + echo " " + echo " However, they may also be specified as command line options, e.g.:" echo " " echo " ./configure [options] CC=gcc haswell" echo " " @@ -336,10 +436,10 @@ assign_key_value() # # found in a blacklist. # # # Note: $2 can actually be a list of items. -# dlist=\$"$1" -# ditem=\$"$2" +# ditem=\$"$1" +# dlist=\$"$2" # -# # Acquire the contents of $list and $item and store them in list_c +# # Acquire the contents of $dlist and $ditem and store them in list_c # # and item_c, respectively. # list_c=$(eval "expr \"$dlist\" ") # item_c=$(eval "expr \"$ditem\" ") @@ -356,7 +456,7 @@ assign_key_value() # done # # # Update the argument. -# eval "$1=\"${list_c}\"" +# eval "$2=\"${list_c}\"" #} pass_config_kernel_registries() @@ -619,13 +719,21 @@ read_registry_file() if [ "${mem}" != "${mems_mem}" ]; then #clist="${config_registry[$config]}" - clist=$(query_array "config_registry" ${config}) + clisttmp=$(query_array "config_registry" ${config}) # Replace the current config with its constituent config set, # canonicalize whitespace, and then remove duplicate config # set names, if they exist. Finally, update the config registry # with the new config list. - newclist=$(echo -e "${clist}" | sed -e "s/${mem}/${mems_mem}/g") + # NOTE: WE must use substitute_words() rather than a simple sed + # expression because we need to avoid matching partial strings. + # For example, if clist above contains "foo bar barsk" and we use + # sed to substitute "bee boo" as the members of "bar", the + # result would (incorrectly) be "foo bee boo bee boosk", + # which would then get reduced, via rm_duplicate_words(), to + # "foo bee boo boosk". + #newclist=$(echo -e "${clist}" | sed -e "s/${mem}/${mems_mem}/g") + newclist=$(substitute_words "${mem}" "${mems_mem}" "${clisttmp}") newclist=$(canonicalize_ws "${newclist}") newclist=$(rm_duplicate_words "${newclist}") @@ -708,7 +816,15 @@ read_registry_file() # canonicalize whitespace, and then remove duplicate kernel # set names, if they exist. Finally, update the kernel registry # with the new kernel list. - newklist=$(echo -e "${klisttmp}" | sed -e "s/${ker}/${kers_ker}/g") + # NOTE: WE must use substitute_words() rather than a simple sed + # expression because we need to avoid matching partial strings. + # For example, if klist above contains "foo bar barsk" and we use + # sed to substitute "bee boo" as the members of "bar", the + # result would (incorrectly) be "foo bee boo bee boosk", + # which would then get reduced, via rm_duplicate_words(), to + # "foo bee boo boosk". + #newklist=$(echo -e "${klisttmp}" | sed -e "s/${ker}/${kers_ker}/g") + newklist=$(substitute_words "${ker}" "${kers_ker}" "${klisttmp}") newklist=$(canonicalize_ws "${newklist}") newklist=$(rm_duplicate_words "${newklist}") @@ -730,6 +846,26 @@ read_registry_file() done } +substitute_words() +{ + local word new_words list newlist + + word="$1" + new_words="$2" + list="$3" + + for str in ${list}; do + + if [ "${str}" == "${word}" ]; then + newlist="${newlist} ${new_words}" + else + newlist="${newlist} ${str}" + fi + done + + echo "${newlist}" +} + build_kconfig_registry() { local familyname clist config kernels kernel cur_configs newvalue @@ -764,7 +900,7 @@ build_kconfig_registry() assign_key_value "kconfig_registry" "${kernel}" "${newvalue}" done - + done } @@ -864,6 +1000,18 @@ canonicalize_ws() echo "${str}" } +rm_duplicate_words_simple() +{ + local str revstr revres res + + str="$1" + + # Remote duplicates, keeping the first occurrence. + res=$(echo "${str}" | awk '{for (i=1;i<=NF;i++) if (!a[$i]++) printf("%s%s",$i,FS)}{printf("\n")}') + + echo "${res}" +} + rm_duplicate_words() { local str revstr revres res @@ -919,47 +1067,36 @@ get_cxx_search_list() echo "${list}" } -select_tool() +get_fc_search_list() { - local search_list CC_env the_cc cc + local list - # This is the list of compilers/tools to search for, and the order in - # which to search for them. - search_list=$1 + list="gfortran ifort" - # The environment variable associated with the compiler/tool type we - # are searching (e.g. CC, CXX, PYTHON). - CC_env=$2 + echo "${list}" +} - # If CC_env contains something, add it to the beginning of our default - # search list. - if [ -n "${CC_env}" ]; then - search_list="${CC_env} ${search_list}" - fi +get_ar_search_list() +{ + local list - # Initialize our selected compiler/tool to empty. - the_cc="" + list="ar" - # Try each compiler/tool in the list and select the first one we find that - # works. - for cc in ${search_list}; do + echo "${list}" +} - # See if the current compiler/tool works and/or is present. - ${cc} --version > /dev/null 2>&1 +get_ranlib_search_list() +{ + local list - if [ "$?" == 0 ]; then - the_cc=${cc} - break - fi - done + list="ranlib" - # Return the selected compiler/tool. - echo "${the_cc}" + echo "${list}" } auto_detect() { - local cc cflags config_defines detected_config rval + local cc cflags config_defines detected_config rval cmd # Use the same compiler that was found earlier. cc="${found_cc}" @@ -976,68 +1113,100 @@ auto_detect() cflags= fi - # Locate our source files. - bli_arch_c="bli_arch.c" - bli_cpuid_c="bli_cpuid.c" - main_c="config_detect.c" + # Accumulate a list of source files we'll need to compile along with + # the top-level (root) directory in which they are located. + c_src_pairs="" + c_src_pairs="${c_src_pairs} frame:bli_arch.c" + c_src_pairs="${c_src_pairs} frame:bli_cpuid.c" + c_src_pairs="${c_src_pairs} frame:bli_env.c" + c_src_pairs="${c_src_pairs} build:config_detect.c" - bli_arch_c_filepath=$(find ${dist_path}/frame -name "${bli_arch_c}") - bli_cpuid_c_filepath=$(find ${dist_path}/frame -name "${bli_cpuid_c}") - main_c_filepath=$(find ${dist_path}/build -name "${main_c}") - - # Locate headers needed directly by the above files. - bli_arch_h="bli_arch.h" - bli_cpuid_h="bli_cpuid.h" - bli_typed_h="bli_type_defs.h" + # Accumulate a list of full filepaths to the source files listed above. + c_src_filepaths="" + for pair in ${c_src_pairs}; do - bli_arch_h_filepath=$(find ${dist_path}/frame -name "${bli_arch_h}") - bli_cpuid_h_filepath=$(find ${dist_path}/frame -name "${bli_cpuid_h}") - bli_typed_h_filepath=$(find ${dist_path}/frame -name "${bli_typed_h}") + filename=${pair#*:} + rootdir=${pair%:*} - bli_arch_h_path=${bli_arch_h_filepath%/${bli_arch_h}} - bli_cpuid_h_path=${bli_cpuid_h_filepath%/${bli_cpuid_h}} - bli_typed_h_path=${bli_typed_h_filepath%/${bli_typed_h}} + filepath=$(find ${dist_path}/${rootdir} -name "${filename}") + c_src_filepaths="${c_src_filepaths} ${filepath}" + done - # Locate other headers needed by bli_type_defs.h. - bli_pthread_h="bli_pthread.h" - bli_pthread_h_filepath=$(find ${dist_path}/frame -name "${bli_pthread_h}") - bli_pthread_h_path=${bli_pthread_h_filepath%/${bli_pthread_h}} - bli_malloc_h="bli_malloc.h" - bli_malloc_h_filepath=$(find ${dist_path}/frame -name "${bli_malloc_h}") - bli_malloc_h_path=${bli_malloc_h_filepath%/${bli_malloc_h}} + # Accumulate a list of header files we'll need to locate along with + # the top-level (root) directory in which they are located. + c_hdr_pairs="" + c_hdr_pairs="${c_hdr_pairs} frame:bli_system.h" + c_hdr_pairs="${c_hdr_pairs} frame:bli_type_defs.h" + c_hdr_pairs="${c_hdr_pairs} frame:bli_arch.h" + c_hdr_pairs="${c_hdr_pairs} frame:bli_cpuid.h" + c_hdr_pairs="${c_hdr_pairs} frame:bli_env.h" + # NOTE: These headers are needed by bli_type_defs.h. + c_hdr_pairs="${c_hdr_pairs} frame:bli_malloc.h" + c_hdr_pairs="${c_hdr_pairs} frame:bli_pthread.h" + + # Accumulate a list of full paths to the header files listed above. + # While we are at it, we include the "-I" compiler option to indicate + # adding the path to the list of directories to search when encountering + # #include directives. + c_hdr_paths="" + for pair in ${c_hdr_pairs}; do + + filename=${pair#*:} + rootdir=${pair%:*} + + filepath=$(find ${dist_path}/${rootdir} -name "${filename}") + path=${filepath%/*} + c_hdr_paths="${c_hdr_paths} -I${path}" + done # Define the executable name. autodetect_x="auto-detect.x" # Create #defines for all of the BLIS_CONFIG_ macros in bli_cpuid.c. + bli_cpuid_c_filepath=$(find ${dist_path}/frame -name "bli_cpuid.c") config_defines=$(grep BLIS_CONFIG_ ${bli_cpuid_c_filepath} \ | sed -e 's/#ifdef /-D/g') - # Set the linker flags. We need pthreads because it is needed for - # parts of bli_arch.c unrelated to bli_arch_string(), which is called - # by the main() function in ${main_c}. - if [ $is_win = no ]; then + # Set the linker flags. We typically need pthreads (or BLIS's homerolled + # equiavlent) because it is needed for parts of bli_arch.c unrelated to + # bli_arch_string(), which is called by the main() function in ${main_c}. + if [[ "$is_win" == "no" || "$cc_vendor" != "clang" ]]; then ldflags="${LIBPTHREAD--lpthread}" fi + # However, if --disable-system was given, we override the choice made above + # and do not use any pthread link flags. + if [[ "$enable_system" == "no" ]]; then + ldflags= + fi + # Compile the auto-detect program using source code inside the # framework. # NOTE: -D_GNU_SOURCE is needed to enable POSIX extensions to # pthreads (i.e., barriers). - ${cc} ${config_defines} \ + + cmd="${cc} ${config_defines} \ -DBLIS_CONFIGURETIME_CPUID \ - -I${bli_cpuid_h_path} \ - -I${bli_arch_h_path} \ - -I${bli_typed_h_path} \ - -I${bli_pthread_h_path} \ - -I${bli_malloc_h_path} \ + ${c_hdr_paths} \ -std=c99 -D_GNU_SOURCE \ ${cflags} \ - ${bli_arch_c_filepath} \ - ${bli_cpuid_c_filepath} \ + ${c_src_filepaths} \ ${ldflags} \ - ${main_c_filepath} \ - -o ${autodetect_x} + -o ${autodetect_x}" + + if [ "${debug_auto_detect}" == "no" ]; then + + # Execute the compilation command. + eval ${cmd} + + else + + # Debugging stuff. Instead of executing ${cmd}, join the lines together + # with tr and trim excess whitespace via awk. + cmd=$(echo "${cmd}" | tr '\n' ' ' | awk '{$1=$1;print}') + echo "${cmd}" + return + fi # Run the auto-detect program. detected_config=$(./${autodetect_x}) @@ -1292,13 +1461,99 @@ get_compiler_version() # isolate the version number. # The last part ({ read first rest ; echo $first ; }) is a workaround # to OS X's egrep only returning the first match. - cc_vendor=$(echo "${vendor_string}" | egrep -o 'icc|gcc|clang|emcc|pnacl|IBM' | { read first rest ; echo $first ; }) + cc_vendor=$(echo "${vendor_string}" | egrep -o 'icc|gcc|clang|emcc|pnacl|IBM|oneAPI|crosstool-NG|GCC' | { read first rest ; echo $first ; }) + + # AOCC version strings contain both "clang" and "AOCC" substrings, and + # so we have perform a follow-up check to make sure cc_vendor gets set + # correctly. + aocc_grep=$(echo "${vendor_string}" | grep 'AOCC') + if [ -n "${aocc_grep}" ]; then + cc_vendor="aocc" + fi + + # Detect armclang, which doesn't have a nice, unambiguous, one-word tag + armclang_grep=$(echo "${vendor_string}" | grep 'Arm C/C++/Fortran Compiler') + if [ -n "${armclang_grep}" ]; then + cc_vendor="armclang" + fi + + # Begin parsing cc_vendor for the version string. + + if [ "${cc_vendor}" = "GCC" ]; then + # Conda gcc sometimes has GCC (all caps) in the version string + cc_vendor="gcc" + fi + if [ "${cc_vendor}" = "crosstool-NG" ]; then + # Treat compilers built by crosstool-NG (for eg: conda) as gcc. + cc_vendor="gcc" + fi if [ "${cc_vendor}" = "icc" -o \ - "${cc_vendor}" = "gcc" -o \ - "${cc_vendor}" = "clang" ]; then + "${cc_vendor}" = "gcc" ]; then + cc_version=$(${cc} -dumpversion) + + elif [ "${cc_vendor}" = "armclang" ]; then + + # Treat armclang as regular clang. + cc_vendor="clang" + cc_version=$(echo "${vendor_string}" \ + | egrep -o 'based on LLVM [0-9]+\.[0-9]+\.?[0-9]*' \ + | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*') + + elif [ "${cc_vendor}" = "clang" ]; then + + cc_version=$(echo "${vendor_string}" \ + | egrep -o '(clang|LLVM) version [0-9]+\.[0-9]+\.?[0-9]*' \ + | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*') + + elif [ "${cc_vendor}" = "aocc" ]; then + + aocc_ver21=$(echo "${vendor_string}" | grep 'AOCC.LLVM.2') + + # Versions 2.0 and 2.1 had different version string formats from + # 2.2 and later, so we have to handle them separately. + # Examples: + # AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) + # AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) + # AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) + # AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) + # AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) + + if [ -n "${aocc_ver21}" ]; then + + # Grep for the AOCC.LLVM.x.y.z substring first, and then isolate the + # version number. Also, the string may contain multiple instances of + # the version number, so only use the first occurrence. + cc_version=$(echo "${vendor_string}" \ + | egrep -o 'AOCC.LLVM.[0-9]+\.[0-9]+\.?[0-9]*' \ + | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*' \ + | { read first rest ; echo $first ; }) + else + + # Grep for the AOCC_x.y.z substring first, and then isolate the + # version number. As of this writing, these version strings don't + # include multiple instances of the version, but we nonetheless + # take only the first occurrence as a future-oriented safety + # measure. + cc_version=$(echo "${vendor_string}" \ + | egrep -o 'AOCC_[0-9]+\.[0-9]+\.?[0-9]*' \ + | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*' \ + | { read first rest ; echo $first ; }) + fi + + elif [ "${cc_vendor}" = "oneAPI" ]; then + + # Treat Intel oneAPI's clang as clang, not icc. + cc_vendor="clang" + cc_version=$(echo "${vendor_string}" \ + | egrep -o '[0-9]+\.[0-9]+\.[0-9]+\.?[0-9]*' \ + | { read first rest ; echo ${first} ; }) + else - cc_version=$(echo "${vendor_string}" | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*' | { read first rest ; echo ${first} ; }) + + cc_version=$(echo "${vendor_string}" \ + | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*' \ + | { read first rest ; echo ${first} ; }) fi # Parse the version number into its major, minor, and revision @@ -1343,12 +1598,14 @@ check_compiler() # Specific: # # skx: icc 15.0.1+, gcc 6.0+, clang 3.9+ - # knl: icc 14.0.1+, gcc 5.0+, clang 3.5+ + # knl: icc 14.0.1+, gcc 5.0+, clang 3.9+ # haswell: any # sandybridge: any # penryn: any # # zen: gcc 6.0+[1], clang 4.0+ + # zen2: gcc 6.0+[1], clang 4.0+ + # zen3: gcc 6.0+[1], clang 4.0+ # excavator: gcc 4.9+, clang 3.5+ # steamroller: any # piledriver: any @@ -1358,6 +1615,8 @@ check_compiler() # cortexa15: any # cortexa9: any # + # armsve: clang11+, gcc10+ + # # generic: any # # Note: These compiler requirements were originally modeled after similar @@ -1373,6 +1632,8 @@ check_compiler() echo "${script_name}: checking for blacklisted configurations due to ${cc} ${cc_version}." + # Fixme: check on a64fx, neoverse, and others + # gcc if [ "x${cc_vendor}" = "xgcc" ]; then @@ -1400,6 +1661,11 @@ check_compiler() # Thus, this "blacklistcc_add" statement has been moved above. #blacklistcc_add "zen" blacklistcc_add "skx" + # gcc 5.x may support POWER9 but it is unverified. + blacklistcc_add "power9" + fi + if [ ${cc_major} -lt 10 ]; then + blacklistcc_add "armsve" fi fi @@ -1414,31 +1680,206 @@ check_compiler() blacklistcc_add "skx" fi fi + if [ ${cc_major} -eq 18 ]; then + echo "${script_name}: ${cc} ${cc_version} is known to cause erroneous results. See https://github.com/flame/blis/issues/371 for details." + blacklistcc_add "knl" + blacklistcc_add "skx" + fi + if [ ${cc_major} -ge 19 ]; then + echo "${script_name}: ${cc} ${cc_version} is known to cause erroneous results. See https://github.com/flame/blis/issues/371 for details." + echoerr_unsupportedcc + fi fi # clang if [ "x${cc_vendor}" = "xclang" ]; then - - if [ ${cc_major} -lt 3 ]; then - echoerr_unsupportedcc - fi - if [ ${cc_major} -eq 3 ]; then - if [ ${cc_minor} -lt 3 ]; then + if [ "$(echo ${vendor_string} | grep -o Apple)" = "Apple" ]; then + if [ ${cc_major} -lt 5 ]; then echoerr_unsupportedcc fi - if [ ${cc_minor} -lt 5 ]; then + # See https://en.wikipedia.org/wiki/Xcode#Toolchain_versions + if [ ${cc_major} -eq 5 ]; then + # Apple clang 5.0 is clang 3.4svn blacklistcc_add "excavator" blacklistcc_add "zen" + fi + if [ ${cc_major} -lt 7 ]; then blacklistcc_add "knl" + blacklistcc_add "skx" fi + else + if [ ${cc_major} -lt 3 ]; then + echoerr_unsupportedcc + fi + if [ ${cc_major} -eq 3 ]; then + if [ ${cc_minor} -lt 3 ]; then + echoerr_unsupportedcc + fi + if [ ${cc_minor} -lt 5 ]; then + blacklistcc_add "excavator" + blacklistcc_add "zen" + fi + if [ ${cc_minor} -lt 9 ]; then + blacklistcc_add "knl" + blacklistcc_add "skx" + fi + fi + if [ ${cc_major} -lt 4 ]; then + # See comment above regarding zen support. + #blacklistcc_add "zen" + : # explicit no-op since bash can't handle empty loop bodies. + fi + if [ ${cc_major} -lt 11 ]; then + blacklistcc_add "armsve" + fi + fi + fi +} + +check_compiler_version_ranges() +{ + local cc + + cc="${found_cc}" + + # + # We check for various compiler version ranges that may cause us + # issues in properly supporting those compiler versions within the + # BLIS build system. + # + # range: gcc < 4.9.0 (ie: 4.8.5 or older) + # variable: gcc_older_than_4_9_0 + # comments: + # These older versions of gcc may support microarchitectures such as + # sandybridge, but the '-march=' flag uses a different label syntax. + # In newer versions, '-march=sandybridge' is the preferred syntax [1]. + # However, in older versions, the syntax for the same compiler option + # is '-march=corei7-avx' [2]. + # + # [1] https://gcc.gnu.org/onlinedocs/gcc-4.9.0/gcc/i386-and-x86-64-Options.html#i386-and-x86-64-Options + # [2] https://gcc.gnu.org/onlinedocs/gcc-4.8.5/gcc/i386-and-x86-64-Options.html#i386-and-x86-64-Options + # + # range: gcc < 6.1 (ie: 5.5 or older) + # variable: gcc_older_than_6_1_0 + # comments: + # These older versions of gcc do not explicitly support the Zen (Zen1) + # microarchitecture; the newest microarchitectural value understood by + # these versions is '-march=bdver4' [3]. However, basic support for these + # older versions can be attained in a roundabout way by starting with the + # instruction sets enabled by '-march=bdver4' and then disabling the + # instruction sets that were removed in the transition from Excavator to + # Zen, namely: FMA4, TBM, XOP, and LWP. Newer versions of gcc support Zen + # via the '-march=znver1' option [4]. + # + # [3] https://gcc.gnu.org/onlinedocs/gcc-5.5.0/gcc/x86-Options.html#x86-Options + # [4] https://gcc.gnu.org/onlinedocs/gcc-6.1.0/gcc/x86-Options.html#x86-Options + # + # range: gcc < 9.1 (ie: 8.3 or older) + # variable: gcc_older_than_9_1_0 + # comments: + # These older versions of gcc do not explicitly support the Zen2 + # microarchitecture; the newest microarchitectural value understood by + # these versions is either '-march=znver1' (if !gcc_older_than_6_1_0) [5] + # or '-march=bdver4' (if gcc_older_than_6_1_0) [3]. If gcc is 6.1 or + # newer, '-march=znver1' may be used (since the instruction sets it + # enables are a subset of those enabled by '-march=znver2'); otherwise, + # '-march=bdver4' must be used in conjuction with disabling the + # instruction sets that were removed in the transition from Excavator to + # Zen, as described in the section above for gcc_older_than_6_1_0. + # Newer versions of gcc support Zen2 via the '-march=znver2' option [6]. + # + # [5] https://gcc.gnu.org/onlinedocs/gcc-8.3.0/gcc/x86-Options.html#x86-Options + # [6] https://gcc.gnu.org/onlinedocs/gcc-9.4.0/gcc/x86-Options.html#x86-Options + # + # range: gcc < 10.1 (ie: 9.4 or older) + # variable: gcc_older_than_10_1_0 + # comments: + # These older versions of gcc do not explicitly support the Zen3 + # microarchitecture; the newest microarchitectural value understood by + # these versions is '-march=znver2' (if !gcc_older_than_9_1_0) [7]. + # Newer versions of gcc support Zen3 via the '-march=znver3' option [8]. + # + # [7] https://gcc.gnu.org/onlinedocs/gcc-9.4.0/gcc/x86-Options.html#x86-Options + # [8] https://gcc.gnu.org/onlinedocs/gcc-10.3.0/gcc/x86-Options.html#x86-Options + # + + gcc_older_than_4_9_0='no' + gcc_older_than_6_1_0='no' + gcc_older_than_9_1_0='no' + gcc_older_than_10_1_0='no' + + clang_older_than_9_0_0='no' + clang_older_than_12_0_0='no' + + aocc_older_than_2_0_0='no' + aocc_older_than_3_0_0='no' + + echo "${script_name}: checking ${cc} ${cc_version} against known consequential version ranges." + + # gcc + if [ "x${cc_vendor}" = "xgcc" ]; then + + # Check for gcc < 4.9.0 (ie: 4.8.5 or older). + if [ ${cc_major} -eq 4 ]; then if [ ${cc_minor} -lt 9 ]; then - blacklistcc_add "skx" + echo "${script_name}: note: found ${cc} version older than 4.9.0." + gcc_older_than_4_9_0='yes' fi fi - if [ ${cc_major} -lt 4 ]; then - # See comment above regarding zen support. - #blacklistcc_add "zen" - : # explicit no-op since bash can't handle empty loop bodies. + + # Check for gcc < 6.1.0 (ie: 5.5 or older). + if [ ${cc_major} -lt 6 ]; then + echo "${script_name}: note: found ${cc} version older than 6.1." + gcc_older_than_6_1_0='yes' + fi + + # Check for gcc < 9.1.0 (ie: 8.3 or older). + if [ ${cc_major} -lt 9 ]; then + echo "${script_name}: note: found ${cc} version older than 9.1." + gcc_older_than_9_1_0='yes' + fi + + # Check for gcc < 10.1.0 (ie: 9.4 or older). + if [ ${cc_major} -lt 10 ]; then + echo "${script_name}: note: found ${cc} version older than 10.1." + gcc_older_than_10_1_0='yes' + fi + fi + + # icc + if [ "x${cc_vendor}" = "xicc" ]; then + : + fi + + # clang + if [ "x${cc_vendor}" = "xclang" ]; then + + # Check for clang < 9.0.0. + if [ ${cc_major} -lt 9 ]; then + echo "${script_name}: note: found ${cc} version older than 9.0." + clang_older_than_9_0_0='yes' + fi + + # Check for clang < 12.0.0. + if [ ${cc_major} -lt 12 ]; then + echo "${script_name}: note: found ${cc} version older than 12.0." + clang_older_than_12_0_0='yes' + fi + fi + + # aocc + if [ "x${cc_vendor}" = "xaocc" ]; then + + # Check for aocc < 2.0.0. + if [ ${cc_major} -lt 2 ]; then + echo "${script_name}: note: found ${cc} version older than 2.0." + aocc_older_than_2_0_0='yes' + fi + + # Check for aocc < 3.0.0. + if [ ${cc_major} -lt 3 ]; then + echo "${script_name}: note: found ${cc} version older than 3.0." + aocc_older_than_3_0_0='yes' fi fi } @@ -1496,8 +1937,8 @@ check_assembler() # # The assembler on OS X won't recognize AVX-512 without help. - if [ "$(uname -s)" == "Darwin" ]; then - cflags="-Wa,-march=knl" + if [ "${cc_vendor}" == "clang" ]; then + cflags="-march=knl" fi asm_fp=$(find ${asm_dir} -name "avx512f.s") @@ -1513,8 +1954,8 @@ check_assembler() # # The assembler on OS X won't recognize AVX-512 without help. - if [ "$(uname -s)" == "Darwin" ]; then - cflags="-Wa,-march=skylake-avx512" + if [ "${cc_vendor}" == "clang" ]; then + cflags="-march=skylake-avx512" fi asm_fp=$(find ${asm_dir} -name "avx512dq.s") @@ -1620,6 +2061,223 @@ set_default_version() fi } +select_tool_w_env() +{ + local search_list env_var env_str tool_str found_var + local _the_tool + + # Example calling sequence: + # + # select_tool_w_env "${cc_search_list}" "${CC}" "CC" "C compiler" "yes" found_cc + # + + search_list="$1" # the tool's default search list. + env_var="$2" # the value of the environment variable for this tool. + env_str="$3" # a string naming the source of env_var. + tool_str="$4" # a human-readable string identifying the tool. + is_required="$5" # is it fatal if env_var doesn't exist/work? (yes or no) + found_var="$6" # the variable into which to save the selected tool. + + # If the environment variable contains something, verify that it exists. If + # it is unset or empty, we proceed with the default search list. + if [ -n "${env_var}" ]; then + + echo "${script_name}: user specified a ${tool_str} via ${env_str} (${env_var})." + + # See if the binary specified by env_var exists. + _the_tool=$(select_tool "${env_var}" "${env_str}") + + # Copy the result into the variable specified by found_var. + eval "${found_var}=\"${_the_tool}\"" + + # If the tool specified by env_var doesn't exist, throw a tantrum. + if [ -z "${_the_tool}" ]; then + + echo "${script_name}: *** Could not find the ${tool_str} specified via ${env_str} ('${env_var}')." + + # Whether the tantrum is fatal depends on the is_required argument. + if [ "${is_required}" == "yes" ]; then + echo "${script_name}: *** A working ${tool_str} is required. Please set ${env_str}" + echo "${script_name}: *** to a ${tool_str} that exists (or unset ${env_str})." + exit 1 + else + echo "${script_name}: *** Note that a ${tool_str} will not be available." + + # Set the found_var variable to *something* so that the output + # makefile fragment contains a record that the tool wasn't found. + eval "${found_var}=\"${env_str}\"-not-found" + fi + else + # The user-specified tool was found. + echo "${script_name}: ${_the_tool} exists and appears to work." + echo "${script_name}: using '${_the_tool}' as ${tool_str}." + fi + + else + + echo "${script_name}: ${tool_str} search list is: ${search_list}." + + # Search for a working tool from the search list. + _the_tool=$(select_tool "${search_list}" "${env_str}") + + # Copy the result into the variable specified by found_var. + eval "${found_var}=\"${_the_tool}\"" + + # If we didn't find a working tool from the search list, throw a tantrum. + if [ -z "${_the_tool}" ]; then + + echo "${script_name}: *** Could not find a ${tool_str} from the search list." + + # Whether the tantrum is fatal depends on the is_required argument. + if [ "${is_required}" == "yes" ]; then + echo "${script_name}: *** A working ${tool_str} is required. Cannot continue." + exit 1 + else + echo "${script_name}: *** Note that a ${tool_str} will not be available." + + # Set the found_var variable to *something* so that the output + # makefile fragment contains a record that the tool wasn't found. + eval "${found_var}=\"${env_str}-not-found\"" + fi + else + # A tool from the search list was found. + echo "${script_name}: found '${_the_tool}'." + echo "${script_name}: using '${_the_tool}' as ${tool_str}." + fi + fi +} + +select_tool() +{ + local search_list env_str + local the_tool tool the_flags rval + + # This is the list of tools to search for, and the order in which + # to search for them. + search_list="$1" + + # This is the name of the environment variable associated with the tool. For + # example, if search_list is a list of C compilers, env_str will be "CC". + env_str="$2" + + # Initialize our selected tool to empty. + the_tool="" + + # Try each tool in the list and select the first one we find that works. + for tool in ${search_list}; do + + # Map each tool (via its canonical environment variable form) to the set + # of options we should use to check that it is working and available. + the_flags=$(get_tool_checkflags "${env_str}") + + # Check that the tool works with at least one of the flags in the_flags + # the_flags (or, if the_flags is empty, check that the tool exists). + rval=$(check_tool "${tool}" "${the_flags}") + + # If check_tool() returns 0, we're done. + if [ "${rval}" == "0" ]; then + the_tool=${tool} + break + fi + done + + # Return the selected tool. + echo "${the_tool}" +} + +get_tool_checkflags() +{ + local env_str + local allflags flaglist + + # The tool for which we will determine the flag/option to pass in + # when testing that the tool works. Notice that it's not actually + # the tool but rather its equivalent environment variable. + env_str="${1}" + + # The default list of flags to use in most circumstances. + allflags="--version -V -h" + + if [ "${os_name}" = "Linux" ]; then + + # If we are on Linux, it is very likely that all the tools will respond + # to at least one of the usual flags. + flaglist="${allflags}" + + else + + # If we are on Darwin/OSX/BSD or something else, we sometimes skip flag + # checks. (Note that when the list of flags to check is empty, we end + # up testing for the existence of the tool instead.) + if [ "${env_str}" = "AR" -o \ + "${env_str}" = "RANLIB" ]; then + + # AR, RANLIB may not respond to the normal flags on Darwin/OSX/BSD, + # so all we can really do is check for their existence. + flaglist="" + else + # Even on Darwin/OSX/BSD, we expect that CC, CXX, FC, PYTHON will + # respond to the typical flag checklist. + flaglist="${allflags}" + fi + fi + + echo "${flaglist}" +} + +check_tool() +{ + local tool the_flags + local rval opt toolpath + + # This is the name, or filepath, of the tool to check for. + tool="$1" + + # Some command line options to try to determine that the tool works. + the_flags="$2" + + # Start with the assuming that the tool doesn't work/exist. + rval=1 + + if [ -n "${the_flags}" ]; then + + # If the list of flags to check non-empty, we will iterate through the + # list in search of a flag that works. Failure to find one that works + # means the tool doesn't work (or, if the user specified the tool via + # its environment variable, failure might mean that the tool doesn't + # even exist). + + # Try each flag in the list of flags. + for opt in ${the_flags}; do + + # See if the tool responds to the current flag. + ${tool} ${opt} > /dev/null 2>&1 + + # If the tool responded to the flag with a nominal error code of + # 0, we found one that works and set rval accoringly. + if [ "$?" == 0 ]; then + rval=0 + break + fi + done + else + + # If the list of flags to check is empty, we interpret this as a + # request to instead check for the existence of the tool. + + # Use 'which' to determine if the tool exists. + toolpath="$(which ${tool} 2> /dev/null)" + + # If the tool doesn't exist, we set rval accordingly. + if [ -n "${toolpath}" ]; then + rval=0 + fi + fi + + # Return the error code. + echo "${rval}" +} + # @@ -1674,6 +2332,13 @@ main() bli_config_h_in_path="${build_dirpath}/${bli_config_h_in}" bli_config_h_out_path="${cur_dirpath}/${bli_config_h_out}" + # The names/paths for the template bli_addon.h.in and its instantiated + # counterpart. + bli_addon_h_in='bli_addon.h.in' + bli_addon_h_out='bli_addon.h' + bli_addon_h_in_path="${build_dirpath}/${bli_addon_h_in}" + bli_addon_h_out_path="${cur_dirpath}/${bli_addon_h_out}" + # Path to 'mirror-tree.sh' script. mirror_tree_sh="${build_dirpath}/mirror-tree.sh" @@ -1697,6 +2362,10 @@ main() frame_dir='frame' frame_dirpath="${dist_path}/${frame_dir}" + # The names of the addons. + addon_dir='addon' + addon_dirpath="${dist_path}/${addon_dir}" + # The name of the sandbox directory. sandbox_dir='sandbox' sandbox_dirpath="${dist_path}/${sandbox_dir}" @@ -1731,21 +2400,33 @@ main() # -- configure options -- - # The user-given install prefix and a flag indicating it was given. - #install_prefix_def="${HOME}/blis" - install_prefix_user=${HOME}/blis # default to this directory. + # Define the default prefix so that the print_usage() function can + # output it in the --help text. + prefix_def='/usr/local' + + # The installation prefix, assigned its default value, and a flag to + # track whether or not it was given by the user. + prefix=${prefix_def} prefix_flag='' - # The user-given install libdir and a flag indicating it was given. - install_libdir_user='' + # The installation exec_prefix, assigned its default value, and a flag to + # track whether or not it was given by the user. + exec_prefix='${prefix}' + exec_prefix_flag='' + + # The installation libdir, assigned its default value, and a flag to + # track whether or not it was given by the user. + libdir='${exec_prefix}/lib' libdir_flag='' - # The user-given install includedir and a flag indicating it was given. - install_incdir_user='' - incdir_flag='' + # The installation includedir, assigned its default value, and a flag to + # track whether or not it was given by the user. + includedir='${prefix}/include' + includedir_flag='' - # The user-given install sharedir and a flag indicating it was given. - install_sharedir_user='' + # The installation sharedir, assigned its default value, and a flag to + # track whether or not it was given by the user. + sharedir='${prefix}/share' sharedir_flag='' # The preset value of CFLAGS and LDFLAGS (ie: compiler and linker flags @@ -1757,8 +2438,11 @@ main() debug_type='' debug_flag='' + # The system flag. + enable_system='yes' + # The threading flag. - threading_model='no' + threading_model='off' # The method of assigning micropanels to threads in the JR and JR loops. thread_part_jrir='slab' @@ -1772,6 +2456,8 @@ main() enable_arg_max_hack='no' enable_static='yes' enable_shared='yes' + enable_rpath='no' + export_shared='public' enable_pba_pools='yes' enable_sba_pools='yes' enable_mem_tracing='no' @@ -1781,8 +2467,16 @@ main() enable_cblas='no' enable_mixed_dt='yes' enable_mixed_dt_extra_mem='yes' + enable_sup_handling='yes' + enable_amd_frame_tweaks='no' enable_memkind='' # The default memkind value is determined later on. + enable_trsm_preinversion='yes' force_version='no' + complex_return='default' + + # The addon flag and names. + addon_flag='' + addon_list='' # The sandbox flag and name. sandbox_flag='' @@ -1813,6 +2507,13 @@ main() # source distribution directory. dummy_file='_blis_dir_detect.tmp' + # -- Debugging -- + + # A global flag to help debug the compilation command for the executable + # that configure builds on-the-fly to perform hardware auto-detection. + debug_auto_detect="no" + + # -- Command line option/argument parsing ---------------------------------- @@ -1821,7 +2522,7 @@ main() # Process our command line options. unset OPTIND - while getopts ":hp:d:s:t:r:qci:b:-:" opt; do + while getopts ":hp:d:e:a:s:t:r:qci:b:-:" opt; do case $opt in -) case "$OPTARG" in @@ -1833,19 +2534,23 @@ main() ;; prefix=*) prefix_flag=1 - install_prefix_user=${OPTARG#*=} + prefix=${OPTARG#*=} + ;; + exec-prefix=*) + exec_prefix_flag=1 + exec_prefix=${OPTARG#*=} ;; libdir=*) libdir_flag=1 - install_libdir_user=${OPTARG#*=} + libdir=${OPTARG#*=} ;; includedir=*) - incdir_flag=1 - install_incdir_user=${OPTARG#*=} + includedir_flag=1 + includedir=${OPTARG#*=} ;; sharedir=*) sharedir_flag=1 - install_sharedir_user=${OPTARG#*=} + sharedir=${OPTARG#*=} ;; enable-debug) debug_flag=1 @@ -1882,15 +2587,30 @@ main() disable-shared) enable_shared='no' ;; + enable-rpath) + enable_rpath='yes' + ;; + disable-rpath) + enable_rpath='no' + ;; + export-shared=*) + export_shared=${OPTARG#*=} + ;; + enable-system) + enable_system='yes' + ;; + disable-system) + enable_system='no' + ;; enable-threading=*) threading_model=${OPTARG#*=} ;; + disable-threading) + threading_model='off' + ;; thread-part-jrir=*) thread_part_jrir=${OPTARG#*=} ;; - disable-threading) - threading_model='no' - ;; enable-pba-pools) enable_pba_pools='yes' ;; @@ -1909,12 +2629,21 @@ main() disable-mem-tracing) enable_mem_tracing='no' ;; + enable-addon=*) + addon_flag=1 + addon_name=${OPTARG#*=} + # Append the addon name to the list. + addon_list="${addon_list} ${addon_name}" + ;; + disable-addon) + addon_flag='' + ;; enable-sandbox=*) sandbox_flag=1 sandbox=${OPTARG#*=} ;; disable-sandbox) - sandbox_flag=0 + sandbox_flag='' ;; int-size=*) int_type_size=${OPTARG#*=} @@ -1946,18 +2675,39 @@ main() disable-mixed-dt-extra-mem) enable_mixed_dt_extra_mem='no' ;; + enable-sup-handling) + enable_sup_handling='yes' + ;; + disable-sup-handling) + enable_sup_handling='no' + ;; + enable-amd-frame-tweaks) + enable_amd_frame_tweaks='yes' + ;; + disable-amd-frame-tweaks) + enable_amd_frame_tweaks='no' + ;; with-memkind) enable_memkind='yes' ;; without-memkind) enable_memkind='no' ;; + enable-trsm-preinversion) + enable_trsm_preinversion='yes' + ;; + disable-trsm-preinversion) + enable_trsm_preinversion='no' + ;; force-version=*) force_version=${OPTARG#*=} ;; show-config-list) show_config_list=1 ;; + complex-return=*) + complex_return=${OPTARG#*=} + ;; *) print_usage ;; @@ -1967,12 +2717,21 @@ main() ;; p) prefix_flag=1 - install_prefix_user=$OPTARG + prefix=$OPTARG ;; d) debug_flag=1 debug_type=$OPTARG ;; + e) + export_shared=$OPTARG + ;; + a) + addon_flag=1 + addon_name=$OPTARG + # Append the addon name to the list. + addon_list="${addon_list} ${addon_name}" + ;; s) sandbox_flag=1 sandbox=$OPTARG @@ -2040,24 +2799,13 @@ main() # -- Find a python interpreter --------------------------------------------- - # Acquire the python search order. This may vary based on the os found - # above. + # Acquire the default python search order. python_search_list=$(get_python_search_list) - echo "${script_name}: python interpeter search list is: ${python_search_list}." - - # Find a working python interpreter. - found_python=$(select_tool "${python_search_list}" "${PYTHON}") - - # If we didn't find any working python interpreters, we print an error - # message. - if [ -z "${found_python}" ]; then - echo "${script_name}: *** Could not find working python interperter! Cannot continue." - exit 1 - fi - - echo "${script_name}: using '${found_python}' python interpreter." - + # Select a python interpreter from the default list, or from PYTHON if it + # refers to a valid binary. + select_tool_w_env "${python_search_list}" "${PYTHON}" "PYTHON" \ + "python interpreter" "yes" found_python # -- Check the python version ---------------------------------------------- @@ -2068,44 +2816,19 @@ main() # -- Find a C compiler ----------------------------------------------------- - # Acquire the compiler search order. This will vary based on the os found - # above. + # Acquire the default compiler search order. This will vary based on os_name. cc_search_list=$(get_cc_search_list) - echo "${script_name}: C compiler search list is: ${cc_search_list}." - - # Find a working C compiler. - found_cc=$(select_tool "${cc_search_list}" "${CC}") - - # If we didn't find any working C compilers, we print an error message. - if [ -z "${found_cc}" ]; then - echo "${script_name}: *** Could not find working C compiler! Cannot continue." - exit 1 - fi - - echo "${script_name}: using '${found_cc}' C compiler." - - - # -- Find a C++ compiler --------------------------------------------------- - - # Acquire the compiler search order. This will vary based on the os - # found above. - cxx_search_list=$(get_cxx_search_list) - - echo "${script_name}: C++ compiler search list is: ${cxx_search_list}." - - # Find a working C++ compiler. NOTE: We can reuse the select_tool() - # function since it is written in a way that is general-purpose. - found_cxx=$(select_tool "${cxx_search_list}" "${CXX}") + # Select a C compiler from the default list, or from CC if it refers to a + # valid binary. + select_tool_w_env "${cc_search_list}" "${CC}" "CC" \ + "C compiler" "yes" found_cc - # If we didn't find any working C++ compilers, we print an error message. - if [ -z "${found_cxx}" ]; then - echo "${script_name}: Could not find working C++ compiler! C++ will not be available in sandbox." - found_cxx="c++notfound" + # Also check the compiler to see if we are (cross-)compiling for Windows + if ${found_cc} -dM -E - < /dev/null 2> /dev/null | grep -q _WIN32; then + is_win=yes fi - echo "${script_name}: using '${found_cxx}' C++ compiler (for sandbox only)." - # -- Check the compiler version -------------------------------------------- @@ -2114,9 +2837,11 @@ main() # Check the compiler's version. Certain versions of certain compilers # will preclude building certain sub-configurations, which are added - # to a blacklist. + # to a blacklist. We also make note of certain version ranges that + # will be useful to know about later. get_compiler_version check_compiler + check_compiler_version_ranges # Now check the assembler's ability to assemble code. Older versions # of binutils may not be aware of certain instruction sets. Those @@ -2135,6 +2860,57 @@ main() fi + # -- Find a C++ compiler --------------------------------------------------- + + # Acquire the default C++ compiler search order. This will vary based on + # os_name. + cxx_search_list=$(get_cxx_search_list) + + # Select a C compiler from the default list, or from CC if it refers to a + # valid binary. + select_tool_w_env "${cxx_search_list}" "${CXX}" "CXX" \ + "C++ compiler" "no" found_cxx + + + # -- Find a Fortran compiler ----------------------------------------------- + + # Acquire the default Fortran compiler search order. + fc_search_list=$(get_fc_search_list) + + # Select a Fortran compiler from the default list, or from FC if it refers + # to a valid binary. + # NOTE: A Fortran compiler is not necessary for building BLIS. The only + # reason we might want to query it is to detect the style of returning + # complex values from functions. The 'gnu' style returns complex values + # from functions normally, via the C language return statement, while the + # 'intel' style returns them in a "hidden" parameter (inserted by the + # compiler) that precedes all other function parameters. + select_tool_w_env "${fc_search_list}" "${FC}" "FC" \ + "Fortran compiler" "no" found_fc + + + # -- Find a static library archiver ---------------------------------------- + + # Acquire the default archiver search order. + ar_search_list=$(get_ar_search_list) + + # Select an archiver from the default list, or from AR if it refers + # to a valid binary. + select_tool_w_env "${ar_search_list}" "${AR}" "AR" \ + "library archiver" "yes" found_ar + + + # -- Find an archive indexer ----------------------------------------------- + + # Acquire the default archive indexer search order. + ranlib_search_list=$(get_ranlib_search_list) + + # Select an archive indexer from the default list, or from RANLIB if it + # refers to a valid binary. + select_tool_w_env "${ranlib_search_list}" "${RANLIB}" "RANLIB" \ + "archive indexer" "yes" found_ranlib + + # -- Read the configuration registry --------------------------------------- # Make sure the config registry file exists and can be opened. @@ -2239,13 +3015,34 @@ main() # config_name. config_name=$(auto_detect) + # Debugging stuff. When confirming the behavior of auto_detect(), + # it is useful to output ${config_name}, which in theory could be + # set temoprarily to something other than the config_name, such as + # the compilation command. + if [ "${debug_auto_detect}" = "yes" ]; then + echo "auto-detect program compilation command: ${config_name}" + exit 1 + fi + echo "${script_name}: hardware detection driver returned '${config_name}'." + + # If the auto-detect code returned the "generic" string, it means we + # were unable to automatically detect the user's hardware type. While + # this is going to be a rare event, it will likely lead the user to + # experience much lower performance than expected, and thus we will + # warn them about it at the end of the configure output (to increase + # the chances that they see it). + if [ "${config_name}" = "generic" ]; then + + warn_user_generic=1 + else + warn_user_generic=0 + fi else # Use the command line argument as the configuration name. config_name=$1 - #echo "${script_name}: manual configuration requested." echo "${script_name}: manual configuration requested; configuring with '${config_name}'." fi @@ -2329,11 +3126,11 @@ main() reducedclist="${kernel}" - # Otherwise, use the first name. + # Otherwise, use the last name. else - first_config=${configs%% *} - reducedclist="${first_config}" + last_config=${configs##* } + reducedclist="${last_config}" fi # Create a new "kernel:subconfig" pair and add it to the kconfig_map @@ -2361,22 +3158,42 @@ main() # but we have it here just in case. if [ $1 = "auto" ]; then - echo "${script_name}: 'auto-detected configuration '${conf}' is NOT registered!" + echo "${script_name}: 'auto-detected configuration '${config_name}' is NOT registered!" echo "${script_name}: " - echo "${script_name}: *** Cannot continue with unregistered configuration '${conf}'. ***" + echo "${script_name}: *** Cannot continue with unregistered configuration '${config_name}'. ***" echo "${script_name}: " exit 1; else - echo "${script_name}: 'user-specified configuration '${conf}' is NOT registered!" - echo "${script_name}: " - echo "${script_name}: *** Cannot continue with unregistered configuration '${conf}'. ***" - echo "${script_name}: " - exit 1; + # At this point, we know: (a) config_list is empty; and (b) the user + # requested manual configuration. If the config_name given by the + # user is present in the configuration blacklist (config_blist), + # then we can deduce why the config_list is empty: because the only + # subconfig implied by config_name is blacklisted. Thus, we cannot + # proceed. + + if [ $(is_in_list "${config_name}" "${config_blist}") == "true" ]; then + + echo "${script_name}: 'user-specified configuration '${config_name}' is blacklisted!" + echo "${script_name}: " + echo "${script_name}: *** Cannot continue with blacklisted configuration '${config_name}'. ***" + echo "${script_name}: *** Try updating your compiler and/or assembler (binutils) versions. ***" + echo "${script_name}: " + exit 1; + else - fi + # If config_name is NOT present in config_blist, then we know + # that config_list is empty simply because config_name is + # unregistered. + echo "${script_name}: 'user-specified configuration '${config_name}' is NOT registered!" + echo "${script_name}: " + echo "${script_name}: *** Cannot continue with unregistered configuration '${config_name}'. ***" + echo "${script_name}: " + exit 1; + fi + fi else # This branch executes when the configuration is found to be present @@ -2459,54 +3276,49 @@ main() # -- Prepare variables for subsitution into template files ----------------- - # Parse the status of the install prefix and echo feedback. + # Parse the status of the prefix option and echo feedback. if [ -n "${prefix_flag}" ]; then - echo "${script_name}: detected --prefix='${install_prefix_user}'." + echo "${script_name}: detected --prefix='${prefix}'." else - echo "${script_name}: no install prefix option given; defaulting to '${install_prefix_user}'." + echo "${script_name}: no install prefix option given; defaulting to '${prefix}'." fi - # Set initial (candidate) values for the libdir and includedir using the - # install prefix that was determined above. - install_libdir=${install_prefix_user}/lib - install_incdir=${install_prefix_user}/include - install_sharedir=${install_prefix_user}/share + # Parse the status of the exec_prefix option and echo feedback. + if [ -n "${exec_prefix_flag}" ]; then + echo "${script_name}: detected --exec-prefix='${exec_prefix}'." + else + echo "${script_name}: no install exec_prefix option given; defaulting to PREFIX." + fi - # Set the install libdir, if it was specified. Note that this will override - # the default libdir implied by the install prefix, even if both options - # were given. + # Parse the status of the libdir option and echo feedback. if [ -n "${libdir_flag}" ]; then - echo "${script_name}: detected --libdir='${install_libdir_user}'." - install_libdir=${install_libdir_user} + echo "${script_name}: detected --libdir='${libdir}'." else - echo "${script_name}: no install libdir option given; defaulting to PREFIX/lib." + echo "${script_name}: no install libdir option given; defaulting to EXECPREFIX/lib." fi - # Set the install includedir, if it was specified. Note that this will - # override the default includedir implied by the install prefix, even if - # both options were given. - if [ -n "${incdir_flag}" ]; then - echo "${script_name}: detected --includedir='${install_incdir_user}'." - install_incdir=${install_incdir_user} + # Parse the status of the includedir option and echo feedback. + if [ -n "${includedir_flag}" ]; then + echo "${script_name}: detected --includedir='${includedir}'." else echo "${script_name}: no install includedir option given; defaulting to PREFIX/include." fi - # Set the install sharedir, if it was specified. Note that this will - # override the default sharedir implied by the install prefix, even if - # both options were given. + # Parse the status of the sharedir option and echo feedback. if [ -n "${sharedir_flag}" ]; then - echo "${script_name}: detected --sharedir='${install_sharedir_user}'." - install_sharedir=${install_sharedir_user} + echo "${script_name}: detected --sharedir='${sharedir}'." else echo "${script_name}: no install sharedir option given; defaulting to PREFIX/share." fi # Echo the installation directories that we settled on. echo "${script_name}: final installation directories:" - echo "${script_name}: libdir: ${install_libdir}" - echo "${script_name}: includedir: ${install_incdir}" - echo "${script_name}: sharedir: ${install_sharedir}" + echo "${script_name}: prefix: "${prefix} + echo "${script_name}: exec_prefix: "${exec_prefix} + echo "${script_name}: libdir: "${libdir} + echo "${script_name}: includedir: "${includedir} + echo "${script_name}: sharedir: "${sharedir} + echo "${script_name}: NOTE: the variables above can be overridden when running make." # Check if CFLAGS is non-empty. if [ -n "${CFLAGS}" ]; then @@ -2573,6 +3385,36 @@ main() exit 1 fi + # Check if the "export shared" flag was specified. + if [ "x${export_shared}" = "xall" ]; then + if [ "x${enable_shared}" = "xyes" ]; then + echo "${script_name}: exporting all symbols within shared library." + else + echo "${script_name}: ignoring request to export all symbols within shared library." + fi + elif [ "x${export_shared}" = "xpublic" ]; then + if [ "x${enable_shared}" = "xyes" ]; then + echo "${script_name}: exporting only public symbols within shared library." + fi + else + echo "${script_name}: *** Invalid argument '${export_shared}' to --export-shared option given." + echo "${script_name}: *** Please use 'public' or 'all'." + exit 1 + fi + + # Check if we are building with or without operating system support. + if [ "x${enable_system}" = "xyes" ]; then + echo "${script_name}: enabling operating system support." + enable_system_01=1 + else + echo "${script_name}: disabling operating system support." + echo "${script_name}: WARNING: all threading will be disabled!" + enable_system_01=0 + + # Force threading to be disabled. + threading_model='off' + fi + # Check the threading model flag and standardize its value, if needed. # NOTE: 'omp' is deprecated but still supported; 'openmp' is preferred. enable_openmp='no' @@ -2594,9 +3436,11 @@ main() enable_pthreads='yes' enable_pthreads_01=1 threading_model="pthreads" # Standardize the value. - elif [ "x${threading_model}" = "xno" ] || + elif [ "x${threading_model}" = "xoff" ] || + [ "x${threading_model}" = "xno" ] || [ "x${threading_model}" = "xnone" ]; then echo "${script_name}: threading is disabled." + threading_model="off" else echo "${script_name}: *** Unsupported threading model: ${threading_model}." exit 1 @@ -2707,6 +3551,20 @@ main() enable_mixed_dt_extra_mem_01=0 enable_mixed_dt_01=0 fi + if [ "x${enable_sup_handling}" = "xyes" ]; then + echo "${script_name}: small matrix handling is enabled." + enable_sup_handling_01=1 + else + echo "${script_name}: small matrix handling is disabled." + enable_sup_handling_01=0 + fi + if [ "x${enable_trsm_preinversion}" = "xyes" ]; then + echo "${script_name}: trsm diagonal element pre-inversion is enabled." + enable_trsm_preinversion_01=1 + else + echo "${script_name}: trsm diagonal element pre-inversion is disabled." + enable_trsm_preinversion_01=0 + fi # Report integer sizes. if [ "x${int_type_size}" = "x32" ]; then @@ -2731,6 +3589,57 @@ main() exit 1 fi + # Check whether we should use AMD-customized versions of certain framework + # files. + if [ "x${enable_amd_frame_tweaks}" = "xyes" ]; then + + echo "${script_name}: AMD-specific framework files will be considered." + echo "${script_name}: checking eligibility of target configuration." + + # Make sure we are targeting either one of the zen subconfigs or the + # amd64 umbrella family. + uconf=$(echo ${config_name} | grep -c 'zen\|amd64') + + if [[ $uconf == 0 ]]; then + echo "${script_name}: target configuration '${config_name}' is not eligible." + echo "${script_name}: disabling AMD-specific framework files." + enable_amd_frame_tweaks='no' + else + echo "${script_name}: target configuration '${config_name}' is eligible." + echo "${script_name}: enabling AMD-specific framework files." + fi + else + echo "${script_name}: AMD-specific framework files will not be considered." + fi + + # Check if addons were given. + if [ -n "${addon_flag}" ]; then + + # Remove duplicates in the addon list, if they exist. + addon_list=$(rm_duplicate_words_simple "${addon_list}") + + echo "${script_name}: configuring with addons:" + + for addon in ${addon_list}; do + + echo "${script_name}: ${addon_dir}/${addon}" + + addon_fullpath="${addon_dirpath}/${addon}" + + if [ ! -d "${addon_fullpath}" ]; then + echo "${script_name}: requested addon sub-directory does not exist! Cannot continue." + echo "${script_name}: *** Please verify addon existence and name." + exit 1 + fi + done + + enable_addons_01=1 + else + echo "${script_name}: configuring with no addons." + + enable_addons_01=0 + fi + # Check if a sandbox was given. if [ -n "${sandbox_flag}" ]; then @@ -2754,24 +3663,70 @@ main() enable_sandbox_01=0 fi + # Check the method used for returning complex numbers. + if [ "x${complex_return}" = "xdefault" ]; then + + # If we prevoiusly found a Fortran compiler, let's query it to see what + # kind of complex return type it uses (gnu or intel). The 'gnu' style + # returns complex values from functions normally, via the C language + # return statement, while the 'intel' style returns them in a "hidden" + # parameter (inserted by the compiler) that precedes all other function + # parameters. + if [ -n "${found_fc}" ]; then + + # Query the full vendor version string output. This includes the + # version number along with (potentially) a bunch of other textual + # clutter. + # NOTE: This maybe should use merged stdout/stderr rather than only + # stdout. But it works for now. + vendor_string="$(${FC} --version 2>/dev/null)" + + # Query the compiler "vendor" (ie: the compiler's simple name). + # The last part ({ read first rest ; echo $first ; }) is a workaround + # to OS X's egrep only returning the first match. + fc_vendor=$(echo "${vendor_string}" | egrep -o 'ifort|GNU' | { read first rest ; echo $first ; }) + + if [ "x${fc_vendor}" = "xifort" ]; then + complex_return='intel' + elif [ "x${fc_vendor}" = "xGNU" ]; then + complex_return='gnu' + else + echo "${script_name}: unable to determine Fortran compiler vendor!" + complex_return='gnu' + fi + else + complex_return='gnu' + fi + fi + + if [ "x${complex_return}" = "xgnu" ]; then + complex_return_intel01='0' + elif [ "x${complex_return}" = "xintel" ]; then + complex_return_intel01='1' + else + echo "${script_name}: unknown complex return type \"${complex_return}\"! Cannot continue." + echo "${script_name}: *** Acceptable values are \"gnu\" and \"intel\"." + exit 1 + fi + + echo "${script_name}: configuring complex return type as \"${complex_return}\"." # Variables that may contain forward slashes, such as paths, need extra # escaping when used in sed commands. We insert those extra escape # characters here so that the sed commands below do the right thing. - os_name_esc=$(echo "${os_name}" | sed 's/\//\\\//g') - install_libdir_esc=$(echo "${install_libdir}" | sed 's/\//\\\//g') - install_incdir_esc=$(echo "${install_incdir}" | sed 's/\//\\\//g') - install_sharedir_esc=$(echo "${install_sharedir}" | sed 's/\//\\\//g') - dist_path_esc=$(echo "${dist_path}" | sed 's/\//\\\//g') - cc_esc=$(echo "${found_cc}" | sed 's/\//\\\//g') - cxx_esc=$(echo "${found_cxx}" | sed 's/\//\\\//g') - #sandbox_relpath_esc=$(echo "${sandbox_relpath}" | sed 's/\//\\\//g') - - # For RANLIB, if the variable is not set, we use a default value of - # 'ranlib'. - ranlib_esc=$(echo "${RANLIB:-ranlib}" | sed 's/\//\\\//g') - # For AR, if the variable is not set, we use a default value of 'ar'. - ar_esc=$(echo "${AR:-ar}" | sed 's/\//\\\//g') + os_name_esc=$(echo "${os_name}" | sed 's/\//\\\//g') + prefix_esc=$(echo "${prefix}" | sed 's/\//\\\//g') + exec_prefix_esc=$(echo "${exec_prefix}" | sed 's/\//\\\//g') + libdir_esc=$(echo "${libdir}" | sed 's/\//\\\//g') + includedir_esc=$(echo "${includedir}" | sed 's/\//\\\//g') + sharedir_esc=$(echo "${sharedir}" | sed 's/\//\\\//g') + dist_path_esc=$(echo "${dist_path}" | sed 's/\//\\\//g') + cc_esc=$(echo "${found_cc}" | sed 's/\//\\\//g') + cxx_esc=$(echo "${found_cxx}" | sed 's/\//\\\//g') + ar_esc=$(echo "${found_ar}" | sed 's/\//\\\//g') + ranlib_esc=$(echo "${found_ranlib}" | sed 's/\//\\\//g') + python_esc=$(echo "${found_python}" | sed 's/\//\\\//g') + libpthread_esc=$(echo "${LIBPTHREAD--lpthread}" | sed 's/\//\\\//g') cflags_preset_esc=$(echo "${cflags_preset}" | sed 's/\//\\\//g') ldflags_preset_esc=$(echo "${ldflags_preset}" | sed 's/\//\\\//g') @@ -2779,7 +3734,13 @@ main() # For Windows builds, clear the libpthread_esc variable so that # no pthreads library is substituted into config.mk. (Windows builds # employ an implementation of pthreads that is internal to BLIS.) - if [ $is_win = yes ]; then + if [[ "$is_win" == "yes" && "$cc_vendor" == "clang" ]]; then + libpthread_esc= + fi + + # We also clear the libpthread_esc variable for systemless builds + # (--disable-system). + if [[ "$enable_system" == "no" ]]; then libpthread_esc= fi @@ -2793,7 +3754,7 @@ main() # Create a #define for the configuration family (config_name). uconf=$(echo ${config_name} | tr '[:lower:]' '[:upper:]') config_name_define="#define BLIS_FAMILY_${uconf}\n" - + # Create a list of #defines, one for each configuration in config_list. config_list_defines="" for conf in ${config_list}; do @@ -2818,16 +3779,25 @@ main() kernel_list_defines="${kernel_list_defines}#define ${kernel_define}\n" done + # Create a list of #includes, one for each addon in addon_list. + addon_list_includes="" + for addon in ${addon_list}; do + + # Create a #define and add it to the running list. + addon_header="\"${addon}.h\"" + addon_list_includes="${addon_list_includes}#include ${addon_header}\n" + done + # -- Determine whether we are performing an out-of-tree build -------------- - if [ ${dist_path} != "./" ]; then + if [ "${dist_path}" != "./" ]; then # At this point, we know the user did not run "./configure". But we # have not yet ruled out "/configure" or some # equivalent # that uses relative paths. To further rule out these possibilities, # we create a dummy file in the current build directory. - touch ./${dummy_file} + touch "./${dummy_file}" # If the dummy file we just created in the current directory does not # appear in the source distribution path, then we are in a different @@ -2845,7 +3815,7 @@ main() fi - # -- Instantiate config.mk, bli_config.h files from templates -------------- + # -- Instantiate config.mk file from template ------------------------------ # Begin substituting information into the config_mk_in file, outputting # to config_mk_out. @@ -2862,30 +3832,47 @@ main() | sed -e "s/@is_win@/${is_win}/g" \ | sed -e "s/@dist_path@/${dist_path_esc}/g" \ | sed -e "s/@CC_VENDOR@/${cc_vendor}/g" \ + | sed -e "s/@gcc_older_than_4_9_0@/${gcc_older_than_4_9_0}/g" \ + | sed -e "s/@gcc_older_than_6_1_0@/${gcc_older_than_6_1_0}/g" \ + | sed -e "s/@gcc_older_than_9_1_0@/${gcc_older_than_9_1_0}/g" \ + | sed -e "s/@gcc_older_than_10_1_0@/${gcc_older_than_10_1_0}/g" \ + | sed -e "s/@clang_older_than_9_0_0@/${clang_older_than_9_0_0}/g" \ + | sed -e "s/@clang_older_than_12_0_0@/${clang_older_than_12_0_0}/g" \ + | sed -e "s/@aocc_older_than_2_0_0@/${aocc_older_than_2_0_0}/g" \ + | sed -e "s/@aocc_older_than_3_0_0@/${aocc_older_than_3_0_0}/g" \ | sed -e "s/@CC@/${cc_esc}/g" \ | sed -e "s/@CXX@/${cxx_esc}/g" \ - | sed -e "s/@RANLIB@/${ranlib_esc}/g" \ | sed -e "s/@AR@/${ar_esc}/g" \ + | sed -e "s/@RANLIB@/${ranlib_esc}/g" \ + | sed -e "s/@PYTHON@/${python_esc}/g" \ | sed -e "s/@libpthread@/${libpthread_esc}/g" \ | sed -e "s/@cflags_preset@/${cflags_preset_esc}/g" \ | sed -e "s/@ldflags_preset@/${ldflags_preset_esc}/g" \ | sed -e "s/@debug_type@/${debug_type}/g" \ + | sed -e "s/@enable_system@/${enable_system}/g" \ | sed -e "s/@threading_model@/${threading_model}/g" \ - | sed -e "s/@install_libdir@/${install_libdir_esc}/g" \ - | sed -e "s/@install_incdir@/${install_incdir_esc}/g" \ - | sed -e "s/@install_sharedir@/${install_sharedir_esc}/g" \ + | sed -e "s/@prefix@/${prefix_esc}/g" \ + | sed -e "s/@exec_prefix@/${exec_prefix_esc}/g" \ + | sed -e "s/@libdir@/${libdir_esc}/g" \ + | sed -e "s/@includedir@/${includedir_esc}/g" \ + | sed -e "s/@sharedir@/${sharedir_esc}/g" \ | sed -e "s/@enable_verbose@/${enable_verbose}/g" \ | sed -e "s/@configured_oot@/${configured_oot}/g" \ | sed -e "s/@enable_arg_max_hack@/${enable_arg_max_hack}/g" \ | sed -e "s/@enable_static@/${enable_static}/g" \ | sed -e "s/@enable_shared@/${enable_shared}/g" \ + | sed -e "s/@enable_rpath@/${enable_rpath}/g" \ + | sed -e "s/@export_shared@/${export_shared}/g" \ | sed -e "s/@enable_blas@/${enable_blas}/g" \ | sed -e "s/@enable_cblas@/${enable_cblas}/g" \ + | sed -e "s/@enable_amd_frame_tweaks@/${enable_amd_frame_tweaks}/g" \ | sed -e "s/@enable_memkind@/${enable_memkind}/g" \ | sed -e "s/@pragma_omp_simd@/${pragma_omp_simd}/g" \ + | sed -e "s/@addon_list@/${addon_list}/g" \ | sed -e "s/@sandbox@/${sandbox}/g" \ > "${config_mk_out_path}" - + + # -- Instantiate bli_config.h file from template --------------------------- # Begin substituting information into the bli_config_h_in file, outputting # to bli_config_h_out. NOTE: We use perl instead of sed because the version @@ -2897,6 +3884,7 @@ main() | perl -pe "s/\@config_name_define\@/${config_name_define}/g" \ | perl -pe "s/\@config_list_defines\@/${config_list_defines}/g" \ | perl -pe "s/\@kernel_list_defines\@/${kernel_list_defines}/g" \ + | sed -e "s/@enable_system@/${enable_system_01}/g" \ | sed -e "s/@enable_openmp@/${enable_openmp_01}/g" \ | sed -e "s/@enable_pthreads@/${enable_pthreads_01}/g" \ | sed -e "s/@enable_jrir_slab@/${enable_jrir_slab_01}/g" \ @@ -2910,12 +3898,26 @@ main() | sed -e "s/@enable_cblas@/${enable_cblas_01}/g" \ | sed -e "s/@enable_mixed_dt@/${enable_mixed_dt_01}/g" \ | sed -e "s/@enable_mixed_dt_extra_mem@/${enable_mixed_dt_extra_mem_01}/g" \ + | sed -e "s/@enable_sup_handling@/${enable_sup_handling_01}/g" \ | sed -e "s/@enable_memkind@/${enable_memkind_01}/g" \ + | sed -e "s/@enable_trsm_preinversion@/${enable_trsm_preinversion_01}/g" \ | sed -e "s/@enable_pragma_omp_simd@/${enable_pragma_omp_simd_01}/g" \ | sed -e "s/@enable_sandbox@/${enable_sandbox_01}/g" \ | sed -e "s/@enable_shared@/${enable_shared_01}/g" \ + | sed -e "s/@complex_return_intel@/${complex_return_intel01}/g" \ > "${bli_config_h_out_path}" + # -- Instantiate bli_addon.h file from template ---------------------------- + + # Begin substituting information into the bli_addon_h_in file, outputting + # to bli_addon_h_out. NOTE: We use perl instead of sed because the version + # of sed used on OS X is old and does not handle the '\n' character + # intuitively, which was used when constructing ${addon_list_includes}. + echo "${script_name}: creating ${bli_addon_h_out_path} from ${bli_addon_h_in_path}" + cat "${bli_addon_h_in_path}" \ + | perl -pe "s/\@addon_list_includes\@/${addon_list_includes}/g" \ + | sed -e "s/@enable_addons@/${enable_addons_01}/g" \ + > "${bli_addon_h_out_path}" # -- Create top-level object directories ----------------------------------- @@ -2928,7 +3930,6 @@ main() obj_config_dirpath="${base_obj_dirpath}/${config_dir}" - #echo "${script_name}: creating ${obj_config_dirpath}" mkdir -p ${obj_config_dirpath} for conf in ${config_list}; do echo "${script_name}: creating ${obj_config_dirpath}/${conf}" @@ -2938,7 +3939,6 @@ main() obj_kernels_dirpath="${base_obj_dirpath}/${kernels_dir}" - #echo "${script_name}: creating ${obj_kernels_dirpath}" mkdir -p ${obj_kernels_dirpath} for kern in ${kernel_list}; do echo "${script_name}: creating ${obj_kernels_dirpath}/${kern}" @@ -2948,7 +3948,6 @@ main() obj_refkern_dirpath="${base_obj_dirpath}/${refkern_dir}" - #echo "${script_name}: creating ${obj_refkern_dirpath}" mkdir -p ${obj_refkern_dirpath} for conf in ${config_list}; do echo "${script_name}: creating ${obj_refkern_dirpath}/${conf}" @@ -2962,6 +3961,17 @@ main() mkdir -p ${obj_frame_dirpath} + if [ -n "${addon_flag}" ]; then + + obj_addon_dirpath="${base_obj_dirpath}/${addon_dir}" + + for addon in ${addon_list}; do + echo "${script_name}: creating ${obj_addon_dirpath}/${addon}" + mkdir -p ${obj_addon_dirpath}/${addon} + done + fi + + if [ -n "${sandbox_flag}" ]; then obj_sandbox_dirpath="${base_obj_dirpath}/${sandbox_dir}" @@ -2989,6 +3999,7 @@ main() echo "${script_name}: creating ${base_lib_dirpath}" mkdir -p ${base_lib_dirpath} + # Create include directory (if it does not already exist). base_include_dirpath="${include_dirpath}/${config_name}" @@ -3043,6 +4054,16 @@ main() echo "${script_name}: mirroring ${frame_dirpath} to ${obj_frame_dirpath}" ${mirror_tree_sh} ${frame_dirpath} ${obj_frame_dirpath} + # Mirror the chosen addon source tree to its object sub-directory. + if [ -n "${addon_flag}" ]; then + + for addon in ${addon_list}; do + + echo "${script_name}: mirroring ${addon_dirpath}/${addon} to ${obj_addon_dirpath}/${addon}" + ${mirror_tree_sh} "${addon_dirpath}/${addon}" "${obj_addon_dirpath}/${addon}" + done + fi + # Mirror the chosen sandbox source tree to its object sub-directory. if [ -n "${sandbox_flag}" ]; then @@ -3129,6 +4150,25 @@ main() ${gen_make_frags_dirpath}/suffix_list \ ${gen_make_frags_dirpath}/ignore_list + # Generate makefile fragments in the addon sub-directory. + if [ -n "${addon_flag}" ]; then + + for addon in ${addon_list}; do + + echo "${script_name}: creating makefile fragments in ${obj_addon_dirpath}/${addon}" + ${gen_make_frags_sh} \ + -h -r -v0 \ + -o ${script_name} \ + -p 'ADDON' \ + ${addon_dirpath}/${addon} \ + ${obj_addon_dirpath}/${addon} \ + ${gen_make_frags_dirpath}/fragment.mk \ + ${gen_make_frags_dirpath}/suffix_list \ + ${gen_make_frags_dirpath}/ignore_list + done + fi + + # Generate makefile fragments in the sandbox sub-directory. if [ -n "${sandbox_flag}" ]; then @@ -3168,6 +4208,23 @@ main() exit 1 fi + # If 'blis.pc.in' symlink does not already exist in the current + # directory, create a symbolic link to it. If one does exist, we + # use -f to force creation of a new link. + if [ ! -e "./blis.pc.in" ]; then + + echo "${script_name}: creating symbolic link to blis.pc.in." + ln -s "${dist_path}/blis.pc.in" + + elif [ -h "./blis.pc.in" ]; then + echo "${script_name}: symbolic link to blis.pc.in already exists; forcing creation of new link." + ln -sf "${dist_path}/blis.pc.in" + else + echo "${script_name}: Non-symbolic link file or directory 'blis.pc.in' blocks creation of symlink." + echo "${script_name}: *** Please remove this entity and re-run configure." + exit 1 + fi + # If 'common.mk' symlink does not already exist in the current # directory, create a symbolic link to it. If one does exist, we # use -f to force creation of a new link. @@ -3208,6 +4265,18 @@ main() echo "${script_name}: configured to build within top-level directory of source distribution." fi + if [ "${warn_user_generic}" = "1" ]; then + + echo "${script_name}: " + echo "${script_name}: *** Unable to automatically detect hardware type! ***" + echo "${script_name}: " + echo "${script_name}: NOTE: configure was unable to identify a subconfiguration" + echo "${script_name}: optimized for your hardware. As a result, the 'generic'" + echo "${script_name}: subconfiguration (with low-performance reference kernels)" + echo "${script_name}: will be used. For support, please open an issue on GitHub" + echo "${script_name}: at https://github.com/flame/blis/issues." + echo "${script_name}: " + fi # Exit peacefully. return 0 diff --git a/docs/Addons.md b/docs/Addons.md new file mode 100644 index 0000000000..bd4799fb76 --- /dev/null +++ b/docs/Addons.md @@ -0,0 +1,231 @@ +## Contents + +* **[Introduction](Addons.md#introduction)** +* **[Enabling addons](Addons.md#enabling-addons)** +* **[Addon rules](Addons.md#addon-rules)** +* **[Caveats](Addons.md#caveats)** +* **[Known issues](Addons.md#known-issues)** +* **[Conclusion](Addons.md#conclusion)** + + +## Introduction + +This file briefly describes the requirements for enabling or creating a +custom BLIS *addon*. + +Simply put, an addon in BLIS provides additional APIs, operations, and/or +implementations that may be useful to certain users. An addon can be +thought of as a standalone extension of BLIS that does not depend on any +other addon, although addons may utilize existing functionality or kernels +within the core framework. + +By definition, an addon should *never* provide APIs that conflict with +the interfaces that belong to either the [typed API](BLISTypedAPI.md) or the +[object API](BLISObjectAPI.md). Thus, you'll never have to worry about a +properly constructed (and properly functioning) addon interfering with or +otherwise changing core BLIS functionality. + +How does an addon differ from a [sandbox](Sandboxes.md)? Great question! +Sometimes you want to include additional BLIS-like functionality that does +not relate directly to `gemm` or any other BLIS operation. +(By contrast, a sandbox requires you to implement `gemm` whether you want +to or not.) +Furthermore, you may wish to enable multiple addons simultaneously. +(By contrast, only one sandbox may be enabled at a time.) +Thus, the addon feature provides additional flexibility to some +users in a way that sandboxes cannot, while still providing many of the +conveniences of sandboxes. + +## Enabling an addon + +To enable an existing addon at configure-time, you simply specify it as an +option to `configure`. Either of the following usages are accepted: +``` +$ ./configure --enable-addon=foobar auto +$ ./configure -a foobar auto +``` +Here, we tell `configure` that we want to use the `foobar` addon, which +corresponds to a subdirectory of the `addon` directory named `foobar`. +(Reminder: the `auto` argument is the configuration target and +unrelated to addons.) + +You may also enable multiple addons within the same build of BLIS: +``` +$ ./configure -a foobar -a thing1 -a thing2 auto +``` +Note that the default behavior of `configure` is that no addons are enabled. + +As `configure` runs, you should get output that includes lines +similar to: +``` +configure: configuring with addons: +configure: addon/foobar +configure: addon/thing1 +configure: addon/thing2 +``` +And when you build BLIS, the addon source code will be among the last files to +be compiled: +``` +Compiling obj/haswell/addon/foobar/foobar.o ('haswell' CFLAGS for addons) +Compiling obj/haswell/addon/thing1/thing1.o ('haswell' CFLAGS for addons) +Compiling obj/haswell/addon/thing1/thing1_api.o ('haswell' CFLAGS for addons) +Compiling obj/haswell/addon/thing2/thing2_api.o ('haswell' CFLAGS for addons) +... +``` +That's it! After the BLIS library is built, it will contain your chosen +addons. You can always confirm this by using `nm` to confirm the presence +of your API symbols: +``` +$ nm lib/haswell/libblis.a | grep foobar +foobar.o: +0000000000000000 T foobar +``` + +## Addon rules + +Please follow these guidelines for the best developer experience when +creating addons. + +1. As with sandboxes, you don't need to worry about creating makefiles. The +BLIS build system will take care of this for you. :) By configuring BLIS with +an addon enabled, `make` will scan your addon subdirectory and compile +all of its source code using similar compilation rules as were used for the rest +of the framework. In addition, the compilation command line will automatically +contain one `-I` option for every subdirectory in your addon, +so it doesn't matter where in your addon directory hierarchy you place your +header files -- they will be found! + +2. We recommend that you write your addon in C99. While you *may* use C++11 +to implement your addon, you should provide a C99 wrapper API to your +implementation so that others can interface with it. There is no guarantee +that the end-user will be using a C++11 compiler, and therefore you should +limit the definitions in your addon header to those that are C99 compliant. +If you write your addon in C++11, you must use one of the BLIS-approved file +extensions for your source files (`.cc`, `.cpp`, `.cxx`) and your local +header files (`.hh`, `.hpp`, `.hxx`). +Note that `blis.h` already contains all of its definitions inside of an +`extern "C"` block, so you should be able to `#include "blis.h"` from your +C++11 source code without any issues. + +3. All of your code related to the addon should reside within the named +addon directory, or some subdirectory therein. If your addon requires +new kernels, you should add kernel source code to an appropriate +microarchitecture-specific subdirectory within the top-level `kernels` +directory so that they are compiled with the correct +microarchitecture-specific optimization flags. + +4. If your addon is named `foobar`, the BLIS build system will expect to +find a header called `foobar.h` somewhere in the `addon/foobar` directory +(or one of its subdirectories). This `foobar.h` header will automatically +be inlined into the monolithic `blis.h` header that is produced by the +BLIS build system. `foobar.h` may `#include` other local headers, each of +which will also (recursively) get inlined into `blis.h`. However, you may +choose to omit some local addon headers from `foobar.h.` You might do this, +for example, because those headers define things that are not needed in +order for the end user to call your addon code. + +5. Your addon APIs will always be available within static library builds of +BLIS, but if you want your addon APIs to be exported as public APIs within +*shared* library builds of BLIS, you'll need to annotate the prototypes +accordingly. (BLIS makes its shared library symbols private by default; this +allows us to export only those functions that we consider to be part of the +public APIs.) This annotation can be done by prefixing function prototypes +with the `BLIS_EXPORT_ADDON` macro as follows: +```c +BLIS_EXPORT_ADDON void foobar_calc( void* a, void* b ); +``` + +6. Do not define any symbols in your addon that conflict with any symbols within +the core framework. For example, don't define a function called `bli_copym()` +in your addon since that function is already defined within BLIS. + +7. Do not define any symbols in your addon that conflict with any symbols within +the C99 standard libraries/headers. For example, don't define a function called +`printf()` since that function is already defined within the C99 standard library. + +8. *Try* to not define any symbols in your addon that conflict with symbols in any +other addon, unless your addon is meant to serve as an alternative to the +conflicting addon, in which case conflicting symbol names is okay (since you +will presumably never build with both addons enabled). + +9. When choosing names for your addon files, avoid source filenames that already +exist within BLIS. For example, don't name one of your files `bli_obj.c` +since that file would compile into `bli_obj.o`, which will have already been +placed into the library by the build system. + +10. Similarly, avoid header filenames that already exist within BLIS or C99. +For example, don't name one of your header files `bli_obj.h` since that file +already exists in BLIS. Also, don't name one of your header files `math.h` +since that name would conflict with the `math.h` defined by C99. (This also +means you shouldn't name your addon `math` since normally that name would +require that you provide a `math.h` header inside the addon directory.) + +If you follow these rules, you will be much more likely to have a pleasant +experience integrating your BLIS addon into the larger framework. + +## Caveats + +Notice that the BLIS addons are limited in what they can accomplish. Generally +speaking, addons cannot change existing implementations within BLIS. Instead, +addons aim to provide a way to quickly augment BLIS with additional bundles of +code that extend BLIS's set of functionality in some interesting way. If you +want to define new BLAS-like functions, but don't know where to start, creating +a new addon is an appropriate place to start experimenting. If you want to +change or refactor existing BLIS code, an addon is probably not suited for your +needs. + +Another important limitation is the fact that the build system currently uses +"framework `CFLAGS`" when compiling the addon source files. These are the same +`CFLAGS` used when compiling general framework source code, +``` +# Example framework CFLAGS used by 'haswell' sub-configuration +-O2 -Wall -Wno-unused-function -Wfatal-errors -fPIC -std=c99 +-D_POSIX_C_SOURCE=200112L -Iinclude/haswell -I./frame/3/ +-I./frame/1m/ -I./frame/1f/ -I./frame/1/ -I./frame/include +-DBLIS_VERSION_STRING=\"0.8.1-195\" -fvisibility=hidden +``` +which are likely more general-purpose than the `CFLAGS` used for, say, +optimized kernels or even reference kernels: +``` +# Example optimized kernel CFLAGS used by 'haswell' sub-configuration +-O3 -fomit-frame-pointer -mavx2 -mfma -mfpmath=sse -march=haswell -Wall +-Wno-unused-function -Wfatal-errors -fPIC -std=c99 -D_POSIX_C_SOURCE=200112L +-Iinclude/haswell -I./frame/3/ -I./frame/1m/ -I./frame/1f/ -I./frame/1/ +-I./frame/include -DBLIS_VERSION_STRING=\"0.8.1-195\" -fvisibility=hidden +``` +(To see precisely which flags are being employed for any given file, enable +verbosity at compile-time via `make V=1`.) Compiling addons with these more +versatile `CFLAGS` compiler options means that we only need to compile one +instance of each addon source file, even when targeting multiple +configurations (for example, via `./configure x86_64`). However, it also means +that addons are not ideal for microkernels, as they sometimes need additional +compiler flags in order to +yield the highest performance. If you have a new microkernel you would like to +use within an addon, you can always develop it within that addon. However, +once it is stable and ready for use by others, it's best to move the kernel(s) +to the appropriate microarchitecture-specific subdirectory of the `kernels` +directory the kernel(s). This will allow the kernel to be compiled with the +appropriate microarchitecture-specific compiler flags. +Please see the +[Configuration Guide](ConfigurationHowTo) +for more details, and when in doubt, please don't be shy about seeking +guidance from BLIS developers by opening a +[new issue](https://github.com/flame/blis/issues) or sending a message to the +[blis-devel](http://groups.google.com/d/forum/blis-devel) mailing list. + +Notwithstanding these limitations, hopefully you still find BLIS addons +useful! + +## Known issues + +* None yet. + +## Conclusion + +If you encounter any problems, please open +a new [issue on GitHub](https://github.com/flame/blis/issues). + +If you are unsure about how something works, you can still open an issue. Or, you +can send a message to +[blis-devel](https://groups.google.com/d/forum/blis-devel) mailing list. + diff --git a/docs/BLISObjectAPI.md b/docs/BLISObjectAPI.md index e68cc67749..5e8ed3d8fb 100644 --- a/docs/BLISObjectAPI.md +++ b/docs/BLISObjectAPI.md @@ -1,6 +1,7 @@ # Contents * **[Contents](BLISObjectAPI.md#contents)** +* **[Operation index](BLISObjectAPI.md#operation-index)** * **[Introduction](BLISObjectAPI.md#introduction)** * [BLIS types](BLISObjectAPI.md#blis-types) * [Integer-based types](BLISObjectAPI.md#integer-based-types) @@ -15,8 +16,9 @@ * **[Object management](BLISObjectAPI.md#object-management)** * [Object creation function reference](BLISObjectAPI.md#object-creation-function-reference) * [Object accessor function reference](BLISObjectAPI.md#object-accessor-function-reference) + * [Object mutator function reference](BLISObjectAPI.md#object-mutator-function-reference) + * [Other object function reference](BLISObjectAPI.md#other-object-function-reference) * **[Computational function reference](BLISObjectAPI.md#computational-function-reference)** - * [Operation index](BLISObjectAPI.md#operation-index) * [Level-1v operations](BLISObjectAPI.md#level-1v-operations) * [Level-1d operations](BLISObjectAPI.md#level-1d-operations) * [Level-1m operations](BLISObjectAPI.md#level-1m-operations) @@ -24,14 +26,37 @@ * [Level-2 operations](BLISObjectAPI.md#level-2-operations) * [Level-3 operations](BLISObjectAPI.md#level-3-operations) * [Utility operations](BLISObjectAPI.md#utility-operations) - * [Level-3 microkernels](BLISObjectAPI.md#level-3-microkernels) * **[Query function reference](BLISObjectAPI.md#query-function-reference)** * [General library information](BLISObjectAPI.md#general-library-information) * [Specific configuration](BLISObjectAPI.md#specific-configuration) * [General configuration](BLISObjectAPI.md#general-configuration) * [Kernel information](BLISObjectAPI.md#kernel-information) + * [Clock functions](BLISObjectAPI.md#clock-functions) * **[Example code](BLISObjectAPI.md#example-code)** + + +# Operation index + +This index provides a quick way to jump directly to the description for each operation discussed later in the [Computational function reference](BLISObjectAPI.md#computational-function-reference) section: + + * **[Level-1v](BLISObjectAPI.md#level-1v-operations)**: Operations on vectors: + * [addv](BLISObjectAPI.md#addv), [amaxv](BLISObjectAPI.md#amaxv), [axpyv](BLISObjectAPI.md#axpyv), [axpbyv](BLISObjectAPI.md#axpbyv), [copyv](BLISObjectAPI.md#copyv), [dotv](BLISObjectAPI.md#dotv), [dotxv](BLISObjectAPI.md#dotxv), [invertv](BLISObjectAPI.md#invertv), [scal2v](BLISObjectAPI.md#scal2v), [scalv](BLISObjectAPI.md#scalv), [setv](BLISObjectAPI.md#setv), [setrv](BLISObjectAPI.md#setrv), [setiv](BLISObjectAPI.md#setiv), [subv](BLISObjectAPI.md#subv), [swapv](BLISObjectAPI.md#swapv), [xpbyv](BLISObjectAPI.md#xpbyv) + * **[Level-1d](BLISObjectAPI.md#level-1d-operations)**: Element-wise operations on matrix diagonals: + * [addd](BLISObjectAPI.md#addd), [axpyd](BLISObjectAPI.md#axpyd), [copyd](BLISObjectAPI.md#copyd), [invertd](BLISObjectAPI.md#invertd), [scald](BLISObjectAPI.md#scald), [scal2d](BLISObjectAPI.md#scal2d), [setd](BLISObjectAPI.md#setd), [setid](BLISObjectAPI.md#setid), [shiftd](BLISObjectAPI.md#shiftd), [subd](BLISObjectAPI.md#subd), [xpbyd](BLISObjectAPI.md#xpbyd) + * **[Level-1m](BLISObjectAPI.md#level-1m-operations)**: Element-wise operations on matrices: + * [addm](BLISObjectAPI.md#addm), [axpym](BLISObjectAPI.md#axpym), [copym](BLISObjectAPI.md#copym), [scalm](BLISObjectAPI.md#scalm), [scal2m](BLISObjectAPI.md#scal2m), [setm](BLISObjectAPI.md#setm), [setrm](BLISObjectAPI.md#setrm), [setim](BLISObjectAPI.md#setim), [subm](BLISObjectAPI.md#subm) + * **[Level-1f](BLISObjectAPI.md#level-1f-operations)**: Fused operations on multiple vectors: + * [axpy2v](BLISObjectAPI.md#axpy2v), [dotaxpyv](BLISObjectAPI.md#dotaxpyv), [axpyf](BLISObjectAPI.md#axpyf), [dotxf](BLISObjectAPI.md#dotxf), [dotxaxpyf](BLISObjectAPI.md#dotxaxpyf) + * **[Level-2](BLISObjectAPI.md#level-2-operations)**: Operations with one matrix and (at least) one vector operand: + * [gemv](BLISObjectAPI.md#gemv), [ger](BLISObjectAPI.md#ger), [hemv](BLISObjectAPI.md#hemv), [her](BLISObjectAPI.md#her), [her2](BLISObjectAPI.md#her2), [symv](BLISObjectAPI.md#symv), [syr](BLISObjectAPI.md#syr), [syr2](BLISObjectAPI.md#syr2), [trmv](BLISObjectAPI.md#trmv), [trsv](BLISObjectAPI.md#trsv) + * **[Level-3](BLISObjectAPI.md#level-3-operations)**: Operations with matrices that are multiplication-like: + * [gemm](BLISObjectAPI.md#gemm), [hemm](BLISObjectAPI.md#hemm), [herk](BLISObjectAPI.md#herk), [her2k](BLISObjectAPI.md#her2k), [symm](BLISObjectAPI.md#symm), [syrk](BLISObjectAPI.md#syrk), [syr2k](BLISObjectAPI.md#syr2k), [trmm](BLISObjectAPI.md#trmm), [trmm3](BLISObjectAPI.md#trmm3), [trsm](BLISObjectAPI.md#trsm) + * **[Utility](BLISObjectAPI.md#Utility-operations)**: Miscellaneous operations on matrices and vectors: + * [asumv](BLISObjectAPI.md#asumv), [norm1v](BLISObjectAPI.md#norm1v), [normfv](BLISObjectAPI.md#normfv), [normiv](BLISObjectAPI.md#normiv), [norm1m](BLISObjectAPI.md#norm1m), [normfm](BLISObjectAPI.md#normfm), [normim](BLISObjectAPI.md#normim), [mkherm](BLISObjectAPI.md#mkherm), [mksymm](BLISObjectAPI.md#mksymm), [mktrim](BLISObjectAPI.md#mktrim), [fprintv](BLISObjectAPI.md#fprintv), [fprintm](BLISObjectAPI.md#fprintm),[printv](BLISObjectAPI.md#printv), [printm](BLISObjectAPI.md#printm), [randv](BLISObjectAPI.md#randv), [randm](BLISObjectAPI.md#randm), [sumsqv](BLISObjectAPI.md#sumsqv), [getsc](BLISObjectAPI.md#getsc), [getijv](BLISObjectAPI.md#getijv), [getijm](BLISObjectAPI.md#getijm), [setsc](BLISObjectAPI.md#setsc), [setijv](BLISObjectAPI.md#setijv), [setijm](BLISObjectAPI.md#setijm), [eqsc](BLISObjectAPI.md#eqsc), [eqv](BLISObjectAPI.md#eqv), [eqm](BLISObjectAPI.md#eqm) + + + # Introduction This document summarizes one of the primary native APIs in BLIS--the object API. Here, we also discuss BLIS-specific type definitions, header files, and prototypes to auxiliary functions. @@ -40,6 +65,9 @@ There are many functions that BLIS implements that are not listed here, either b The object API was given its name (a) because it abstracts the floating-point types of its operands (along with many other properties) within a `typedef struct {...}` data structure, and (b) to contrast it with the other native API in BLIS, the typed API, which is [documented here](BLISTypedAPI.md). (The third API supported by BLIS is the BLAS compatibility layer, which mimics conventional Fortran-77 BLAS.) +In general, this document should be treated more as a reference than a place to learn how to use BLIS in your application. Thus, we highly encourage all readers to first study the [example code](BLISObjectAPI.md#example-code) provided within the BLIS source distribution. + + ## BLIS types The following tables list various types used throughout the BLIS object API. @@ -53,7 +81,6 @@ The following tables list various types used throughout the BLIS object API. | `dim_t` | `gint_t` | matrix and vector dimensions. | | `inc_t` | `gint_t` | matrix row/column strides and vector increments. | | `doff_t` | `gint_t` | matrix diagonal offset: if _k_ < 0, diagonal begins at element (-_k_,0); otherwise diagonal begins at element (0,_k_). | -| `bool_t` | `gint_t` | boolean values: `TRUE` or `FALSE`. | | `siz_t` | `guint_t` | a byte size or byte offset. | ### Floating-point types @@ -394,9 +421,7 @@ Objects initialized via this function should **never** be passed to `bli_obj_fre Notes for interpreting function descriptions: * Object accessor functions allow the caller to query certain properties of objects. * These functions are only guaranteed to return meaningful values when called upon objects that have been fully initialized/created. - * Many specialized functions are omitted from this section for brevity. For a full list of accessor functions, please see [frame/include/bli_obj_macro_defs.h](https://github.com/flame/blis/tree/master/frame/include/bli_obj_macro_defs.h). - -**Note**: For now, we mostly omit documentation for the corresponding functions used to modify object properties because those functions can easily invalidate the state of an `obj_t` and should be used only in specific instances. If you think you need to manually set the fields of an `obj_t`, please contact BLIS developers so we can give you personalized guidance. + * Many specialized functions are omitted from this section for brevity. For a full list of accessor functions, please see [frame/include/bli_obj_macro_defs.h](https://github.com/flame/blis/tree/master/frame/include/bli_obj_macro_defs.h), though most users will most likely not need methods beyond those documented below. --- @@ -424,7 +449,7 @@ Return the precision component of the storage datatype property of `obj`. ```c trans_t bli_obj_conjtrans_status( obj_t* obj ); ``` -Return the `trans_t` property of `obj`, which may indicate transposition, conjugation, both, or neither. +Return the `trans_t` property of `obj`, which may indicate transposition, conjugation, both, or neither. Thus, possible return values are `BLIS_NO_TRANSPOSE`, `BLIS_CONJ_NO_TRANSPOSE`, `BLIS_TRANSPOSE`, or `BLIS_CONJ_TRANSPOSE`. --- @@ -445,23 +470,30 @@ Thus, possible return values are `BLIS_NO_CONJUGATE` or `BLIS_CONJUGATE`. --- ```c -uplo_t bli_obj_uplo( obj_t* obj ); +struc_t bli_obj_struc( obj_t* obj ); ``` -Return the `uplo_t` property of `obj`. +Return the structure property of `obj`. --- ```c -struc_t bli_obj_struc( obj_t* obj ); +uplo_t bli_obj_uplo( obj_t* obj ); ``` -Return the `struc_t` property of `obj`. +Return the uplo (i.e., storage) property of `obj`. --- ```c diag_t bli_obj_diag( obj_t* obj ); ``` -Return the `diag_t` property of `obj`. +Return the diagonal property of `obj`. + +--- + +```c +doff_t bli_obj_diag_offset( obj_t* obj ); +``` +Return the diagonal offset of `obj`. Note that the diagonal offset will be negative, `-i`, if the diagonal begins at element `(-i,0)` and positive `j` if the diagonal begins at element `(0,j)`. --- @@ -493,13 +525,6 @@ Return the number of columns (or _n_ dimension) of `obj` after taking into accou --- -```c -doff_t bli_obj_diag_offset( obj_t* obj ); -``` -Return the diagonal offset of `obj`. Note that the diagonal offset will be negative, `-i`, if the diagonal begins at element `(-i,0)` and positive `j` if the diagonal begins at element `(0,j)`. - ---- - ```c inc_t bli_obj_row_stride( obj_t* obj ); ``` @@ -543,6 +568,90 @@ siz_t bli_obj_elem_size( obj_t* obj ); ``` Return the size, in bytes, of the storage datatype as indicated by `bli_obj_dt()`. + + +## Object mutator function reference + +Notes for interpreting function descriptions: + * Object mutator functions allow the caller to modify certain properties of objects. + * The user should be extra careful about modifying properties after objects are created. For typical use of these functions, please study the example code provided in [examples/oapi](https://github.com/flame/blis/tree/master/examples/oapi). + * The list of mutators below is much shorter than the list of accessor functions provided in the previous section. Most mutator functions should *not* be called by users (unless you know what you are doing). For a full list of mutator functions, please see [frame/include/bli_obj_macro_defs.h](https://github.com/flame/blis/tree/master/frame/include/bli_obj_macro_defs.h), though most users will most likely not need methods beyond those documented below. + +--- + +```c +void bli_obj_set_conjtrans( trans_t trans, obj_t* obj ); +``` +Set both conjugation and transposition properties of `obj` using the corresponding components of `trans`. + +--- + +```c +void bli_obj_set_onlytrans( trans_t trans, obj_t* obj ); +``` +Set the transposition property of `obj` using the transposition component of `trans`. Leaves the conjugation property of `obj` unchanged. + +--- + +```c +void bli_obj_set_conj( conj_t conj, obj_t* obj ); +``` +Set the conjugation property of `obj` using `conj`. Leaves the transposition property of `obj` unchanged. + +--- + +```c +void bli_obj_apply_trans( trans_t trans, obj_t* obj ); +``` +Apply `trans` to the transposition property of `obj`. For example, applying `BLIS_TRANSPOSE` will toggle the transposition property of `obj` but leave the conjugation property unchanged; applying `BLIS_CONJ_TRANSPOSE` will toggle both the conjugation and transposition properties of `obj`. + +--- + +```c +void bli_obj_apply_conj( conj_t conj, obj_t* obj ); +``` +Apply `conj` to the conjugation property of `obj`. Specifically, applying `BLIS_CONJUGATE` will toggle the conjugation property of `obj`; applying `BLIS_NO_CONJUGATE` will have no effect. Leaves the transposition property of `obj` unchanged. + +--- + +```c +void bli_obj_set_struc( struc_t struc, obj_t* obj ); +``` +Set the structure property of `obj` to `struc`. + +--- + +```c +void bli_obj_set_uplo( uplo_t uplo, obj_t* obj ); +``` +Set the uplo (i.e., storage) property of `obj` to `uplo`. + +--- + +```c +void bli_obj_set_diag( diag_t diag, obj_t* obj ); +``` +Set the diagonal property of `obj` to `diag`. + +--- + +```c +void bli_obj_set_diag_offset( doff_t doff, obj_t* obj ); +``` +Set the diagonal offset property of `obj` to `doff`. Note that `doff_t` may be typecast from any signed integer. + +--- + + +## Other object function reference + +--- + +```c +void bli_obj_induce_trans( obj_t* obj ); +``` +Modify the properties of `obj` to induce a logical transposition. This function operates without regard to whether the transposition property is already set. Therefore, depending on the circumstance, the caller may or may not wish to clear the transposition property after calling this function. + --- ```c @@ -568,13 +677,6 @@ void bli_obj_imag_part( obj_t* c, obj_t* i ); ``` Initialize `i` to be a modified shallow copy of `c` that refers only to the imaginary part of `c`. ---- - -```c -void bli_obj_induce_trans( obj_t* obj ); -``` -Modify the properties of `obj` to induce a logical transposition. This function operations without regard to whether the transposition property is already set. Therefore, depending on the circumstance, the caller may or may not wish to clear the transposition property after calling this function. (If needed, the user may call `bli_obj_toggle_trans( obj )` to toggle the transposition status.) - # Computational function reference @@ -592,26 +694,6 @@ Notes for interpreting function descriptions: --- -## Operation index - - * **[Level-1v](BLISObjectAPI.md#level-1v-operations)**: Operations on vectors: - * [addv](BLISObjectAPI.md#addv), [amaxv](BLISObjectAPI.md#amaxv), [axpyv](BLISObjectAPI.md#axpyv), [axpbyv](BLISObjectAPI.md#axpbyv), [copyv](BLISObjectAPI.md#copyv), [dotv](BLISObjectAPI.md#dotv), [dotxv](BLISObjectAPI.md#dotxv), [invertv](BLISObjectAPI.md#invertv), [scal2v](BLISObjectAPI.md#scal2v), [scalv](BLISObjectAPI.md#scalv), [setv](BLISObjectAPI.md#setv), [setrv](BLISObjectAPI.md#setrv), [setiv](BLISObjectAPI.md#setiv), [subv](BLISObjectAPI.md#subv), [swapv](BLISObjectAPI.md#swapv), [xpbyv](BLISObjectAPI.md#xpbyv) - * **[Level-1d](BLISObjectAPI.md#level-1d-operations)**: Element-wise operations on matrix diagonals: - * [addd](BLISObjectAPI.md#addd), [axpyd](BLISObjectAPI.md#axpyd), [copyd](BLISObjectAPI.md#copyd), [invertd](BLISObjectAPI.md#invertd), [scald](BLISObjectAPI.md#scald), [scal2d](BLISObjectAPI.md#scal2d), [setd](BLISObjectAPI.md#setd), [setid](BLISObjectAPI.md#setid), [shiftd](BLISObjectAPI.md#shiftd), [subd](BLISObjectAPI.md#subd), [xpbyd](BLISObjectAPI.md#xpbyd) - * **[Level-1m](BLISObjectAPI.md#level-1m-operations)**: Element-wise operations on matrices: - * [addm](BLISObjectAPI.md#addm), [axpym](BLISObjectAPI.md#axpym), [copym](BLISObjectAPI.md#copym), [scalm](BLISObjectAPI.md#scalm), [scal2m](BLISObjectAPI.md#scal2m), [setm](BLISObjectAPI.md#setm), [setrm](BLISObjectAPI.md#setrm), [setim](BLISObjectAPI.md#setim), [subm](BLISObjectAPI.md#subm) - * **[Level-1f](BLISObjectAPI.md#level-1f-operations)**: Fused operations on multiple vectors: - * [axpy2v](BLISObjectAPI.md#axpy2v), [dotaxpyv](BLISObjectAPI.md#dotaxpyv), [axpyf](BLISObjectAPI.md#axpyf), [dotxf](BLISObjectAPI.md#dotxf), [dotxaxpyf](BLISObjectAPI.md#dotxaxpyf) - * **[Level-2](BLISObjectAPI.md#level-2-operations)**: Operations with one matrix and (at least) one vector operand: - * [gemv](BLISObjectAPI.md#gemv), [ger](BLISObjectAPI.md#ger), [hemv](BLISObjectAPI.md#hemv), [her](BLISObjectAPI.md#her), [her2](BLISObjectAPI.md#her2), [symv](BLISObjectAPI.md#symv), [syr](BLISObjectAPI.md#syr), [syr2](BLISObjectAPI.md#syr2), [trmv](BLISObjectAPI.md#trmv), [trsv](BLISObjectAPI.md#trsv) - * **[Level-3](BLISObjectAPI.md#level-3-operations)**: Operations with matrices that are multiplication-like: - * [gemm](BLISObjectAPI.md#gemm), [hemm](BLISObjectAPI.md#hemm), [herk](BLISObjectAPI.md#herk), [her2k](BLISObjectAPI.md#her2k), [symm](BLISObjectAPI.md#symm), [syrk](BLISObjectAPI.md#syrk), [syr2k](BLISObjectAPI.md#syr2k), [trmm](BLISObjectAPI.md#trmm), [trmm3](BLISObjectAPI.md#trmm3), [trsm](BLISObjectAPI.md#trsm) - * **[Utility](BLISObjectAPI.md#Utility-operations)**: Miscellaneous operations on matrices and vectors: - * [asumv](BLISObjectAPI.md#asumv), [norm1v](BLISObjectAPI.md#norm1v), [normfv](BLISObjectAPI.md#normfv), [normiv](BLISObjectAPI.md#normiv), [norm1m](BLISObjectAPI.md#norm1m), [normfm](BLISObjectAPI.md#normfm), [normim](BLISObjectAPI.md#normim), [mkherm](BLISObjectAPI.md#mkherm), [mksymm](BLISObjectAPI.md#mksymm), [mktrim](BLISObjectAPI.md#mktrim), [fprintv](BLISObjectAPI.md#fprintv), [fprintm](BLISObjectAPI.md#fprintm),[printv](BLISObjectAPI.md#printv), [printm](BLISObjectAPI.md#printm), [randv](BLISObjectAPI.md#randv), [randm](BLISObjectAPI.md#randm), [sumsqv](BLISObjectAPI.md#sumsqv), [getijm](BLISObjectAPI.md#getijm), [setijm](BLISObjectAPI.md#setijm) - ---- - - ## Level-1v operations Level-1v operations perform various level-1 BLAS-like operations on vectors (hence the _v_). @@ -708,6 +790,8 @@ Perform ``` where `x` and `y` are vectors of length _n_. +Observed object properties: `conj?(x)`. + --- #### dotv @@ -725,6 +809,8 @@ Perform ``` where `x` and `y` are vectors of length _n_, and `rho` is a scalar. +Observed object properties: `conj?(x)`, `conj?(y)`. + --- #### dotxv @@ -744,6 +830,8 @@ Perform ``` where `x` and `y` are vectors of length _n_, and `alpha`, `beta`, and `rho` are scalars. +Observed object properties: `conj?(alpha)`, `conj?(beta)`, `conj?(x)`, `conj?(y)`. + --- #### invertv @@ -997,7 +1085,7 @@ void bli_setd ); ``` -Observed object properties: `conj?(alpha)`, `diagoff(A)`, `diag(A)`. +Observed object properties: `conj?(alpha)`, `diagoff(A)`. --- @@ -1600,6 +1688,27 @@ Observed object properties: `trans?(A)`, `trans?(B)`. --- +#### gemmt +```c +void bli_gemmt + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c + ); +``` +Perform +``` + C := beta * C + alpha * trans?(A) * trans?(B) +``` +where `C` is an _m x m_ matrix, `trans?(A)` is an _m x k_ matrix, and `trans?(B)` is a _k x m_ matrix. This operation is similar to `bli_gemm()` except that it only updates the lower or upper triangle of `C` as specified by `uplo(C)`. + +Observed object properties: `trans?(A)`, `trans?(B)`, `uplo(C)`. + +--- + #### hemm ```c void bli_hemm @@ -2022,6 +2131,34 @@ where, on entry, `scale` and `sumsq` contain `scale_old` and `sumsq_old`, respec --- +#### getsc +```c +void bli_getsc + ( + obj_t* chi, + double* zeta_r, + double* zeta_i + ) +``` +Copy the real and imaginary values from the scalar object `chi` to `zeta_r` and `zeta_i`. If `chi` is stored as a real type, then `zeta_i` is set to zero. (If `chi` is stored in single precision, the corresponding elements are typecast/promoted during the copy.) + +--- + +#### getijv +```c +err_t bli_getijv + ( + dim_t i, + obj_t* b, + double* ar, + double* ai + ) +``` +Copy the real and imaginary values at the `i`th element of vector object `x` to `ar` and `ai`. If elements of `x` are stored as real types, then only `ar` is overwritten and `ai` is left unchanged. (If `x` contains elements stored in single precision, the corresponding elements are typecast/promoted during the copy.) +If either the element offset `i` is beyond the vector dimension of `x` or less than zero, the function returns `BLIS_FAILURE` without taking any action. Similarly, if `x` is a global scalar constant such as `BLIS_ONE`, the function returns `BLIS_FAILURE`. + +--- + #### getijm ```c err_t bli_getijm @@ -2033,8 +2170,38 @@ err_t bli_getijm double* ai ) ``` -Copy the real and imaginary values at the (`i`,`j`) element of object `b` to `ar` and `ai`. f elements of `b` are stored as real types, then only `ar` is overwritten and `ai` is left unchanged. (If `b` contains elements stored in single precision, the corresponding elements are typecast/promoted during the copy.) -If either the row offset `i` is beyond the _m_ dimension of `b`, or column offset `j` is beyond the _n_ dimension of `b`, the function does not perform any copy and returns `BLIS_FAILURE`. Similarly, if `b` is a global scalar constant such as `BLIS_ONE`, `BLIS_FAILURE` is returned. +Copy the real and imaginary values at the (`i`,`j`) element of object `b` to `ar` and `ai`. If elements of `b` are stored as real types, then only `ar` is overwritten and `ai` is left unchanged. (If `b` contains elements stored in single precision, the corresponding elements are typecast/promoted during the copy.) +If either the row offset `i` is beyond the _m_ dimension of `b` or less than zero, or column offset `j` is beyond the _n_ dimension of `b` or less than zero, the function returns `BLIS_FAILURE` without taking any action. Similarly, if `b` is a global scalar constant such as `BLIS_ONE`, the function returns `BLIS_FAILURE`. + +--- + +#### setsc +```c +void bli_setsc + ( + double* zeta_r, + double* zeta_i, + obj_t* chi + ); +``` +Copy real and imaginary values `zeta_r` and `zeta_i` to the scalar object `chi`. If `chi` is stored as a real type, then `zeta_i` is ignored. (If `chi` is stored in single precision, the contents are typecast/demoted during the copy.) + +--- + +#### setijv +```c +err_t bli_setijv + ( + double ar, + double ai, + dim_t i, + obj_t* x + ); +``` +Copy real and imaginary values `ar` and `ai` to the `i`th element of vector object `x`. If elements of `x` are stored as real types, then only `ar` is copied and `ai` is ignored. (If `x` contains elements stored in single precision, the corresponding elements are typecast/demoted during the copy.) +If the element offset `i` is beyond the vector dimension of `x` or less than zero, the function returns `BLIS_FAILURE` without taking any action. Similarly, if `x` is a global scalar constant such as `BLIS_ONE`, the function returns `BLIS_FAILURE`. + +--- #### setijm ```c @@ -2048,7 +2215,59 @@ err_t bli_setijm ); ``` Copy real and imaginary values `ar` and `ai` to the (`i`,`j`) element of object `b`. If elements of `b` are stored as real types, then only `ar` is copied and `ai` is ignored. (If `b` contains elements stored in single precision, the corresponding elements are typecast/demoted during the copy.) -If either the row offset `i` is beyond the _m_ dimension of `b`, or column offset `j` is beyond the _n_ dimension of `b`, the function does not perform any copy and returns `BLIS_FAILURE`. Similarly, if `b` is a global scalar constant such as `BLIS_ONE`, `BLIS_FAILURE` is returned. +If either the row offset `i` is beyond the _m_ dimension of `b` or less than zero, or column offset `j` is beyond the _n_ dimension of `b` or less than zero, the function returns `BLIS_FAILURE` without taking any action. Similarly, if `b` is a global scalar constant such as `BLIS_ONE`, the function returns `BLIS_FAILURE`. + +--- + +#### eqsc +```c +void bli_eqsc + ( + obj_t chi, + obj_t psi, + bool* is_eq + ); +``` +Perform an element-wise comparison between scalars `chi` and `psi` and store the boolean result in the `bool` pointed to by `is_eq`. +If exactly one of `conj(chi)` or `conj(psi)` (but not both) indicate a conjugation, then one of the scalars will be implicitly conjugated for purposes of the comparision. + +Observed object properties: `conj?(chi)`, `conj?(psi)`. + +--- + +#### eqv +```c +void bli_eqv + ( + obj_t x, + obj_t y, + bool* is_eq + ); +``` +Perform an element-wise comparison between vectors `x` and `y` and store the boolean result in the `bool` pointed to by `is_eq`. +If exactly one of `conj(x)` or `conj(y)` (but not both) indicate a conjugation, then one of the vectors will be implicitly conjugated for purposes of the comparision. + +Observed object properties: `conj?(x)`, `conj?(y)`. + +--- + +#### eqm +```c +void bli_eqm + ( + obj_t a, + obj_t b, + bool* is_eq + ); +``` +Perform an element-wise comparison between matrices `A` and `B` and store the boolean result in the `bool` pointed to by `is_eq`. +Here, `A` is stored as a dense matrix, or lower- or upper-triangular/trapezoidal matrix with arbitrary diagonal offset and unit or non-unit diagonal. +If `diag(A)` indicates a unit diagonal, the diagonals of both matrices will be ignored for purposes of the comparision. +If `uplo(A)` indicates lower or upper storage, only that part of both matrices `A` and `B` will be referenced. +If exactly one of `trans(A)` or `trans(B)` (but not both) indicate a transposition, then one of the matrices will be transposed for purposes of the comparison. +Similarly, if exactly one of `trans(A)` or `trans(B)` (but not both) indicate a conjugation, then one of the matrices will be implicitly conjugated for purposes of the comparision. + +Observed object properties: `diagoff(A)`, `diag(A)`, `uplo(A)`, `trans?(A)`, `trans?(B)`. @@ -2117,23 +2336,64 @@ char* bli_info_get_trsm_u_ukr_impl_string( ind_t method, num_t dt ) ``` Possible implementation (ie: the `ind_t method` argument) types are: - * `BLIS_3MH`: Implementation based on the 3m method applied at the highest level, outside the 5th loop around the microkernel. - * `BLIS_3M1`: Implementation based on the 3m method applied within the 1st loop around the microkernel. - * `BLIS_4MH`: Implementation based on the 4m method applied at the highest level, outside the 5th loop around the microkernel. - * `BLIS_4M1B`: Implementation based on the 4m method applied within the 1st loop around the microkernel. Computation is ordered such that the 1st loop is fissured into two loops, the first of which multiplies the real part of the current micropanel of packed matrix B (against all real and imaginary parts of packed matrix A), and the second of which multiplies the imaginary part of the current micropanel of packed matrix B. - * `BLIS_4M1A`: Implementation based on the 4m method applied within the 1st loop around the microkernel. Computation is ordered such that real and imaginary components of the current micropanels are completely used before proceeding to the next virtual microkernel invocation. * `BLIS_1M`: Implementation based on the 1m method. (This is the default induced method when real domain kernels are present but complex kernels are missing.) * `BLIS_NAT`: Implementation based on "native" execution (ie: NOT an induced method). -**NOTE**: `BLIS_3M3` and `BLIS_3M2` have been deprecated from the `typedef enum` of `ind_t`, and `BLIS_4M1B` is also effectively no longer available, though the `typedef enum` value still exists. - Possible microkernel types (ie: the return values for `bli_info_get_*_ukr_impl_string()`) are: * `BLIS_REFERENCE_UKERNEL` (`"refrnce"`): This value is returned when the queried microkernel is provided by the reference implementation. * `BLIS_VIRTUAL_UKERNEL` (`"virtual"`): This value is returned when the queried microkernel is driven by a the "virtual" microkernel provided by an induced method. This happens for any `method` value that is not `BLIS_NAT` (ie: native), but only applies to the complex domain. * `BLIS_OPTIMIZED_UKERNEL` (`"optimzd"`): This value is returned when the queried microkernel is provided by an implementation that is neither reference nor virtual, and thus we assume the kernel author would deem it to be "optimized". Such a microkernel may not be optimal in the literal sense of the word, but nonetheless is _intended_ to be optimized, at least relative to the reference microkernels. * `BLIS_NOTAPPLIC_UKERNEL` (`"notappl"`): This value is returned usually when performing a `gemmtrsm` or `trsm` microkernel type query for any `method` value that is not `BLIS_NAT` (ie: native). That is, induced methods cannot be (purely) used on `trsm`-based microkernels because these microkernels perform more a triangular inversion, which is not matrix multiplication. + +## Clock functions + +--- + +#### clock +```c +double bli_clock + ( + void + ); +``` +Return the amount of time that has elapsed since some fixed time in the past. The return values of `bli_clock()` typically feature nanosecond precision, though this is not guaranteed. + +**Note:** On Linux, `bli_clock()` is implemented in terms of `clock_gettime()` using the `clockid_t` value of `CLOCK_MONOTONIC`. On OS X, `bli_clock` is implemented in terms of `mach_absolute_time()`. And on Windows, `bli_clock` is implemented in terms of `QueryPerformanceFrequency()`. Please see [frame/base/bli_clock.c](https://github.com/flame/blis/blob/master/frame/base/bli_clock.c) for more details. +**Note:** This function is returns meaningless values when BLIS is configured with `--disable-system`. + +--- + +#### clock_min_diff +```c +double bli_clock_min_diff + ( + double time_prev_min, + double time_start + ); +``` +This function computes an intermediate value, `time_diff`, equal to `bli_clock() - time_start`, and then tentatively prepares to return the minimum value of `time_diff` and `time_min`. If that minimum value is extremely small (close to zero), the function returns `time_min` instead. + +This function is meant to be used in conjuction with `bli_clock()` for +performance timing within applications--specifically in loops where only +the fastest timing is of interest. For example: +```c +double t_save = DBL_MAX; +for( i = 0; i < 3; ++i ) +{ + double t = bli_clock(); + bli_gemm( ... ); + t_save = bli_clock_min_diff( t_save, t ); +} +double gflops = ( 2.0 * m * k * n ) / ( t_save * 1.0e9 ); +``` +This code calls `bli_gemm()` three times and computes the performance, in GFLOPS, of the fastest of the three executions. + +--- + + + # Example code -BLIS provides lots of example code in the [examples/oapi](https://github.com/flame/blis/tree/master/examples/oapi) directory of the BLIS source distribution. The example code in this directory is set up like a tutorial, and so we recommend starting from the beginning. Topics include creating and managing objects, printing vectors and matrices, setting and querying object properties, and calling a representative subset of the computational level-1v, -1m, -2, -3, and utility operations documented above. +BLIS provides lots of example code in the [examples/oapi](https://github.com/flame/blis/tree/master/examples/oapi) directory of the BLIS source distribution. The example code in this directory is set up like a tutorial, and so we recommend starting from the beginning. Topics include creating and managing objects, printing vectors and matrices, setting and querying object properties, and calling a representative subset of the computational level-1v, -1m, -2, -3, and utility operations documented above. Please read the `README` contained within the `examples/oapi` directory for further details. diff --git a/docs/BLISTypedAPI.md b/docs/BLISTypedAPI.md index c20da0fa17..76d7ef8f63 100644 --- a/docs/BLISTypedAPI.md +++ b/docs/BLISTypedAPI.md @@ -1,6 +1,7 @@ # Contents * **[Contents](BLISTypedAPI.md#contents)** +* **[Operation index](BLISTypedAPI.md#operation-index)** * **[Introduction](BLISTypedAPI.md#introduction)** * [BLIS types](BLISTypedAPI.md#blis-types) * [Integer-based types](BLISTypedAPI.md#integer-based-types) @@ -12,7 +13,6 @@ * [BLIS header file](BLISTypedAPI.md#blis-header-file) * [Initialization and cleanup](BLISTypedAPI.md#initialization-and-cleanup) * **[Computational function reference](BLISTypedAPI.md#computational-function-reference)** - * [Operation index](BLISTypedAPI.md#operation-index) * [Level-1v operations](BLISTypedAPI.md#level-1v-operations) * [Level-1d operations](BLISTypedAPI.md#level-1d-operations) * [Level-1m operations](BLISTypedAPI.md#level-1m-operations) @@ -26,8 +26,32 @@ * [Specific configuration](BLISTypedAPI.md#specific-configuration) * [General configuration](BLISTypedAPI.md#general-configuration) * [Kernel information](BLISTypedAPI.md#kernel-information) + * [Clock functions](BLISTypedAPI.md#clock-functions) * **[Example code](BLISTypedAPI.md#example-code)** + + +# Operation index + +This index provides a quick way to jump directly to the description for each operation discussed later in the [Computational function reference](BLISTypedAPI.md#computational-function-reference) section: + + * **[Level-1v](BLISTypedAPI.md#level-1v-operations)**: Operations on vectors: + * [addv](BLISTypedAPI.md#addv), [amaxv](BLISTypedAPI.md#amaxv), [axpyv](BLISTypedAPI.md#axpyv), [axpbyv](BLISTypedAPI.md#axpbyv), [copyv](BLISTypedAPI.md#copyv), [dotv](BLISTypedAPI.md#dotv), [dotxv](BLISTypedAPI.md#dotxv), [invertv](BLISTypedAPI.md#invertv), [scal2v](BLISTypedAPI.md#scal2v), [scalv](BLISTypedAPI.md#scalv), [setv](BLISTypedAPI.md#setv), [subv](BLISTypedAPI.md#subv), [swapv](BLISTypedAPI.md#swapv), [xpbyv](BLISTypedAPI.md#xpbyv) + * **[Level-1d](BLISTypedAPI.md#level-1d-operations)**: Element-wise operations on matrix diagonals: + * [addd](BLISTypedAPI.md#addd), [axpyd](BLISTypedAPI.md#axpyd), [copyd](BLISTypedAPI.md#copyd), [invertd](BLISTypedAPI.md#invertd), [scald](BLISTypedAPI.md#scald), [scal2d](BLISTypedAPI.md#scal2d), [setd](BLISTypedAPI.md#setd), [setid](BLISTypedAPI.md#setid), [shiftd](BLISTypedAPI.md#shiftd), [subd](BLISTypedAPI.md#subd), [xpbyd](BLISTypedAPI.md#xpbyd) + * **[Level-1m](BLISTypedAPI.md#level-1m-operations)**: Element-wise operations on matrices: + * [addm](BLISTypedAPI.md#addm), [axpym](BLISTypedAPI.md#axpym), [copym](BLISTypedAPI.md#copym), [scalm](BLISTypedAPI.md#scalm), [scal2m](BLISTypedAPI.md#scal2m), [setm](BLISTypedAPI.md#setm), [subm](BLISTypedAPI.md#subm) + * **[Level-1f](BLISTypedAPI.md#level-1f-operations)**: Fused operations on multiple vectors: + * [axpy2v](BLISTypedAPI.md#axpy2v), [dotaxpyv](BLISTypedAPI.md#dotaxpyv), [axpyf](BLISTypedAPI.md#axpyf), [dotxf](BLISTypedAPI.md#dotxf), [dotxaxpyf](BLISTypedAPI.md#dotxaxpyf) + * **[Level-2](BLISTypedAPI.md#level-2-operations)**: Operations with one matrix and (at least) one vector operand: + * [gemv](BLISTypedAPI.md#gemv), [ger](BLISTypedAPI.md#ger), [hemv](BLISTypedAPI.md#hemv), [her](BLISTypedAPI.md#her), [her2](BLISTypedAPI.md#her2), [symv](BLISTypedAPI.md#symv), [syr](BLISTypedAPI.md#syr), [syr2](BLISTypedAPI.md#syr2), [trmv](BLISTypedAPI.md#trmv), [trsv](BLISTypedAPI.md#trsv) + * **[Level-3](BLISTypedAPI.md#level-3-operations)**: Operations with matrices that are multiplication-like: + * [gemm](BLISTypedAPI.md#gemm), [hemm](BLISTypedAPI.md#hemm), [herk](BLISTypedAPI.md#herk), [her2k](BLISTypedAPI.md#her2k), [symm](BLISTypedAPI.md#symm), [syrk](BLISTypedAPI.md#syrk), [syr2k](BLISTypedAPI.md#syr2k), [trmm](BLISTypedAPI.md#trmm), [trmm3](BLISTypedAPI.md#trmm3), [trsm](BLISTypedAPI.md#trsm) + * **[Utility](BLISTypedAPI.md#Utility-operations)**: Miscellaneous operations on matrices and vectors: + * [asumv](BLISTypedAPI.md#asumv), [norm1v](BLISTypedAPI.md#norm1v), [normfv](BLISTypedAPI.md#normfv), [normiv](BLISTypedAPI.md#normiv), [norm1m](BLISTypedAPI.md#norm1m), [normfm](BLISTypedAPI.md#normfm), [normim](BLISTypedAPI.md#normim), [mkherm](BLISTypedAPI.md#mkherm), [mksymm](BLISTypedAPI.md#mksymm), [mktrim](BLISTypedAPI.md#mktrim), [fprintv](BLISTypedAPI.md#fprintv), [fprintm](BLISTypedAPI.md#fprintm),[printv](BLISTypedAPI.md#printv), [printm](BLISTypedAPI.md#printm), [randv](BLISTypedAPI.md#randv), [randm](BLISTypedAPI.md#randm), [sumsqv](BLISTypedAPI.md#sumsqv), [getsc](BLISTypedAPI.md#getsc), [getijv](BLISTypedAPI.md#getijv), [getijm](BLISTypedAPI.md#getijm), [setsc](BLISTypedAPI.md#setsc), [setijv](BLISTypedAPI.md#setijv), [setijm](BLISTypedAPI.md#setijm), [eqsc](BLISTypedAPI.md#eqsc), [eqv](BLISTypedAPI.md#eqv), [eqm](BLISTypedAPI.md#eqm) + + + # Introduction This document summarizes one of the primary native APIs in BLIS--the "typed" API. Here, we also discuss BLIS-specific type definitions, header files, and prototypes to auxiliary functions. This document also includes APIs to key kernels which are used to accelerate and optimize various level-2 and level-3 operations, though the [Kernels Guide](KernelsHowTo.md) goes into more detail, especially for level-3 microkernels. @@ -36,6 +60,8 @@ There are many functions that BLIS implements that are not listed here, either b For curious readers, the typed API was given its name (a) because it exposes the floating-point types in the names of its functions, and (b) to contrast it with the other native API in BLIS, the object API, which is [documented here](BLISObjectAPI.md). (The third API supported by BLIS is the BLAS compatibility layer, which mimics conventional Fortran-77 BLAS.) +In general, this document should be treated more as a reference than a place to learn how to use BLIS in your application. Thus, we highly encourage all readers to first study the [example code](BLISTypedAPI.md#example-code) provided within the BLIS source distribution. + ## BLIS types The following tables list various types used throughout the BLIS typed API. @@ -190,26 +216,6 @@ Notes for interpreting function descriptions: --- -## Operation index - - * **[Level-1v](BLISTypedAPI.md#level-1v-operations)**: Operations on vectors: - * [addv](BLISTypedAPI.md#addv), [amaxv](BLISTypedAPI.md#amaxv), [axpyv](BLISTypedAPI.md#axpyv), [axpbyv](BLISTypedAPI.md#axpbyv), [copyv](BLISTypedAPI.md#copyv), [dotv](BLISTypedAPI.md#dotv), [dotxv](BLISTypedAPI.md#dotxv), [invertv](BLISTypedAPI.md#invertv), [scal2v](BLISTypedAPI.md#scal2v), [scalv](BLISTypedAPI.md#scalv), [setv](BLISTypedAPI.md#setv), [subv](BLISTypedAPI.md#subv), [swapv](BLISTypedAPI.md#swapv), [xpbyv](BLISTypedAPI.md#xpbyv) - * **[Level-1d](BLISTypedAPI.md#level-1d-operations)**: Element-wise operations on matrix diagonals: - * [addd](BLISTypedAPI.md#addd), [axpyd](BLISTypedAPI.md#axpyd), [copyd](BLISTypedAPI.md#copyd), [invertd](BLISTypedAPI.md#invertd), [scald](BLISTypedAPI.md#scald), [scal2d](BLISTypedAPI.md#scal2d), [setd](BLISTypedAPI.md#setd), [setid](BLISTypedAPI.md#setid), [shiftd](BLISObjectAPI.md#shiftd), [subd](BLISTypedAPI.md#subd), [xpbyd](BLISTypedAPI.md#xpbyd) - * **[Level-1m](BLISTypedAPI.md#level-1m-operations)**: Element-wise operations on matrices: - * [addm](BLISTypedAPI.md#addm), [axpym](BLISTypedAPI.md#axpym), [copym](BLISTypedAPI.md#copym), [scalm](BLISTypedAPI.md#scalm), [scal2m](BLISTypedAPI.md#scal2m), [setm](BLISTypedAPI.md#setm), [subm](BLISTypedAPI.md#subm) - * **[Level-1f](BLISTypedAPI.md#level-1f-operations)**: Fused operations on multiple vectors: - * [axpy2v](BLISTypedAPI.md#axpy2v), [dotaxpyv](BLISTypedAPI.md#dotaxpyv), [axpyf](BLISTypedAPI.md#axpyf), [dotxf](BLISTypedAPI.md#dotxf), [dotxaxpyf](BLISTypedAPI.md#dotxaxpyf) - * **[Level-2](BLISTypedAPI.md#level-2-operations)**: Operations with one matrix and (at least) one vector operand: - * [gemv](BLISTypedAPI.md#gemv), [ger](BLISTypedAPI.md#ger), [hemv](BLISTypedAPI.md#hemv), [her](BLISTypedAPI.md#her), [her2](BLISTypedAPI.md#her2), [symv](BLISTypedAPI.md#symv), [syr](BLISTypedAPI.md#syr), [syr2](BLISTypedAPI.md#syr2), [trmv](BLISTypedAPI.md#trmv), [trsv](BLISTypedAPI.md#trsv) - * **[Level-3](BLISTypedAPI.md#level-3-operations)**: Operations with matrices that are multiplication-like: - * [gemm](BLISTypedAPI.md#gemm), [hemm](BLISTypedAPI.md#hemm), [herk](BLISTypedAPI.md#herk), [her2k](BLISTypedAPI.md#her2k), [symm](BLISTypedAPI.md#symm), [syrk](BLISTypedAPI.md#syrk), [syr2k](BLISTypedAPI.md#syr2k), [trmm](BLISTypedAPI.md#trmm), [trmm3](BLISTypedAPI.md#trmm3), [trsm](BLISTypedAPI.md#trsm) - * **[Utility](BLISTypedAPI.md#Utility-operations)**: Miscellaneous operations on matrices and vectors: - * [asumv](BLISTypedAPI.md#asumv), [norm1v](BLISTypedAPI.md#norm1v), [normfv](BLISTypedAPI.md#normfv), [normiv](BLISTypedAPI.md#normiv), [norm1m](BLISTypedAPI.md#norm1m), [normfm](BLISTypedAPI.md#normfm), [normim](BLISTypedAPI.md#normim), [mkherm](BLISTypedAPI.md#mkherm), [mksymm](BLISTypedAPI.md#mksymm), [mktrim](BLISTypedAPI.md#mktrim), [fprintv](BLISTypedAPI.md#fprintv), [fprintm](BLISTypedAPI.md#fprintm),[printv](BLISTypedAPI.md#printv), [printm](BLISTypedAPI.md#printm), [randv](BLISTypedAPI.md#randv), [randm](BLISTypedAPI.md#randm), [sumsqv](BLISTypedAPI.md#sumsqv) - ---- - - ## Level-1v operations Level-1v operations perform various level-1 BLAS-like operations on vectors (hence the _v_). @@ -845,7 +851,7 @@ void bli_?axpy2v ``` Perform ``` - y := y + alphax * conjx(x) + alphay * conjy(y) + z := y + alphax * conjx(x) + alphay * conjy(y) ``` where `x`, `y`, and `z` are vectors of length _m_. The kernel, if optimized, is implemented as a fused pair of calls to [axpyv](BLISTypedAPI.md#axpyv). @@ -1051,6 +1057,7 @@ void bli_?her2 ( uplo_t uploa, conj_t conjx, + conj_t conjy, dim_t m, ctype* alpha, ctype* x, inc_t incx, @@ -1115,6 +1122,7 @@ void bli_?syr2 ( uplo_t uploa, conj_t conjx, + conj_t conjy, dim_t m, ctype* alpha, ctype* x, inc_t incx, @@ -1206,6 +1214,30 @@ where C is an _m x n_ matrix, `transa(A)` is an _m x k_ matrix, and `transb(B)` --- +#### gemmt +```c +void bli_?gemmt + ( + uplo_t uploc, + trans_t transa, + trans_t transb, + dim_t m, + dim_t k, + ctype* alpha, + ctype* a, inc_t rsa, inc_t csa, + ctype* b, inc_t rsb, inc_t csb, + ctype* beta, + ctype* c, inc_t rsc, inc_t csc + ); +``` +Perform +``` + C := beta * C + alpha * transa(A) * transb(B) +``` +where C is an _m x m_ matrix, `transa(A)` is an _m x k_ matrix, and `transb(B)` is a _k x m_ matrix. This operation is similar to `bli_?gemm()` except that it only updates the lower or upper triangle of `C` as specified by `uploc`. + +--- + #### hemm ```c void bli_?hemm @@ -1264,7 +1296,8 @@ where C is an _m x m_ Hermitian matrix stored in the lower or upper triangle as void bli_?her2k ( uplo_t uploc, - trans_t transab, + trans_t transa, + trans_t transb, dim_t m, dim_t k, ctype* alpha, @@ -1276,9 +1309,9 @@ void bli_?her2k ``` Perform ``` - C := beta * C + alpha * transab(A) * transab(B)^H + conj(alpha) * transab(B) * transab(A)^H + C := beta * C + alpha * transa(A) * transb(B)^H + conj(alpha) * transb(B) * transa(A)^H ``` -where C is an _m x m_ Hermitian matrix stored in the lower or upper triangle as specified by `uploc` and `transab(A)` and `transab(B)` are _m x k_ matrices. +where C is an _m x m_ Hermitian matrix stored in the lower or upper triangle as specified by `uploc` and `transa(A)` and `transb(B)` are _m x k_ matrices. **Note:** The floating-point type of `beta` is always the real projection of the floating-point types of `A` and `C`. @@ -1340,7 +1373,8 @@ where C is an _m x m_ symmetric matrix stored in the lower or upper triangle as void bli_?syr2k ( uplo_t uploc, - trans_t transab, + trans_t transa, + trans_t transb, dim_t m, dim_t k, ctype* alpha, @@ -1352,9 +1386,9 @@ void bli_?syr2k ``` Perform ``` - C := beta * C + alpha * transab(A) * transab(B)^T + alpha * transab(B) * transab(A)^T + C := beta * C + alpha * transa(A) * transb(B)^T + alpha * transb(B) * transa(A)^T ``` -where C is an _m x m_ symmetric matrix stored in the lower or upper triangle as specified by `uploa` and `transab(A)` and `transab(B)` are _m x k_ matrices. +where C is an _m x m_ symmetric matrix stored in the lower or upper triangle as specified by `uploa` and `transa(A)` and `transb(B)` are _m x k_ matrices. --- @@ -1661,6 +1695,149 @@ where, on entry, `scale` and `sumsq` contain `scale_old` and `sumsq_old`, respec --- +#### getsc +```c +void bli_getsc + ( + ctype* chi, + double* zeta_r, + double* zeta_i + ) +``` +Copy the real and imaginary values from the scalar object `chi` to `zeta_r` and `zeta_i`. If `chi` is stored as a real type, then `zeta_i` is set to zero. (If `chi` is stored in single precision, the corresponding elements are typecast/promoted during the copy.) + +--- + +#### getijv +```c +err_t bli_?getijv + ( + dim_t i, + ctype* x, incx, + double* ar, + double* ai + ) +``` +Copy the real and imaginary values at the `i`th element of vector `x` to `ar` and `ai`. For real domain invocations, only `ar` is overwritten and `ai` is left unchanged. (If `x` contains elements stored in single precision, the corresponding elements are typecast/promoted during the copy.) +Note that the object-based analogue of [getijv](BLISObjectAPI.md#getijv) does bounds checking of the vector element offset `i` against the vector length while the typed functions specified above do not (since the vector length is not given). + +--- + +#### getijm +```c +err_t bli_?getijm + ( + dim_t i, + dim_t j, + ctype* b, inc_t rs_b, inc_t cs_b, + double* ar, + double* ai + ) +``` +Copy the real and imaginary values at the (`i`,`j`) element of object `b` to `ar` and `ai`. For real domain invocations, only `ar` is overwritten and `ai` is left unchanged. (If `b` contains elements stored in single precision, the corresponding elements are typecast/promoted during the copy.) +Note that the object-based analogue of [getijm](BLISObjectAPI.md#getijm) does bounds checking of the matrix element offsets (`i`,`j`) against the matrix dimensions while the typed functions specified above do not (since the matrix dimensions are not given). + +--- + +#### setsc +```c +void bli_setsc + ( + double* zeta_r, + double* zeta_i, + ctype* chi + ); +``` +Copy real and imaginary values `zeta_r` and `zeta_i` to the scalar object `chi`. If `chi` is stored as a real type, then `zeta_i` is ignored. (If `chi` is stored in single precision, the contents are typecast/demoted during the copy.) + +--- + +#### setijv +```c +err_t bli_?setijv + ( + double ar, + double ai, + dim_t i, + ctype* x, incx + ); +``` +Copy real and imaginary values `ar` and `ai` to the `i`th element of vector object `x`. For real domain invocations, only `ar` is copied and `ai` is ignored. (If `x` contains elements stored in single precision, the corresponding elements are typecast/demoted during the copy.) +Note that the object-based analogue of [setijv](BLISObjectAPI.md#setijv) does bounds checking of the vector element offset `i` against the vector length while the typed functions specified above do not (since the vector length is not given). + +--- + +#### setijm +```c +err_t bli_?setijm + ( + double ar, + double ai, + dim_t i, + dim_t j, + ctype* b, inc_t rs_b, inc_t cs_b + ); +``` +Copy real and imaginary values `ar` and `ai` to the (`i`,`j`) element of object `b`. For real domain invocations, only `ar` is copied and `ai` is ignored. (If `b` contains elements stored in single precision, the corresponding elements are typecast/demoted during the copy.) +Note that the object-based analogue of [setijm](BLISObjectAPI.md#setijm) does bounds checking of the matrix element offsets (`i`,`j`) against the matrix dimensions while the typed functions specified above do not (since the matrix dimensions are not given). + +--- + +#### eqsc +```c +void bli_?eqsc + ( + conj_t conjchi, + ctype* chi, + ctype* psi, + bool* is_eq + ); +``` +Perform an element-wise comparison between scalars `chi` and `psi` and store the boolean result in the `bool` pointed to by `is_eq`. +If `conjchi` indicates a conjugation, `chi` will be implicitly conjugated for purposes of the comparision. + +--- + +#### eqv +```c +void bli_?eqv + ( + conj_t conjx, + dim_t n, + ctype* x, inc_t incx, + ctype* y, inc_t incy, + bool* is_eq + ); +``` +Perform an element-wise comparison between length _n_ vectors `x` and `y` and store the boolean result in the `bool` pointed to by `is_eq`. +If `conjx` indicates a conjugation, `x` will be implicitly conjugated for purposes of the comparision. + +--- + +#### eqm +```c +void bli_?eqm + ( + doff_t diagoffa, + diag_t diaga, + uplo_t uploa, + trans_t transa, + dim_t m, + dim_t n, + ctype* a, inc_t rs_a, inc_t cs_a, + ctype* b, inc_t rs_b, inc_t cs_b, + bool* is_eq + ) +``` +Perform an element-wise comparison between matrices `A` and `B` and store the boolean result in the `bool` pointed to by `is_eq`. +Here, `B` is an _m x n_ matrix, `A` is stored as a dense matrix, or lower- or upper-triangular/trapezoidal matrix with arbitrary diagonal offset and unit or non-unit diagonal. +If `diaga` indicates a unit diagonal, the diagonals of both matrices will be ignored for purposes of the comparision. +If `uploa` indicates lower or upper storage, only that part of matrix `A` will be referenced in the comparison. +If `transa` indicates a conjugation and/or transposition, then `A` will be conjugated and/or transposed for purposes of the comparison. + + + + ## Level-3 microkernels @@ -1838,16 +2015,9 @@ char* bli_info_get_trsm_u_ukr_impl_string( ind_t method, num_t dt ) ``` Possible implementation (ie: the `ind_t method` argument) types are: - * `BLIS_3MH`: Implementation based on the 3m method applied at the highest level, outside the 5th loop around the microkernel. - * `BLIS_3M1`: Implementation based on the 3m method applied within the 1st loop around the microkernel. - * `BLIS_4MH`: Implementation based on the 4m method applied at the highest level, outside the 5th loop around the microkernel. - * `BLIS_4M1B`: Implementation based on the 4m method applied within the 1st loop around the microkernel. Computation is ordered such that the 1st loop is fissured into two loops, the first of which multiplies the real part of the current micropanel of packed matrix B (against all real and imaginary parts of packed matrix A), and the second of which multiplies the imaginary part of the current micropanel of packed matrix B. - * `BLIS_4M1A`: Implementation based on the 4m method applied within the 1st loop around the microkernel. Computation is ordered such that real and imaginary components of the current micropanels are completely used before proceeding to the next virtual microkernel invocation. * `BLIS_1M`: Implementation based on the 1m method. (This is the default induced method when real domain kernels are present but complex kernels are missing.) * `BLIS_NAT`: Implementation based on "native" execution (ie: NOT an induced method). -**NOTE**: `BLIS_3M3` and `BLIS_3M2` have been deprecated from the `typedef enum` of `ind_t`, and `BLIS_4M1B` is also effectively no longer available, though the `typedef enum` value still exists. - Possible microkernel types (ie: the return values for `bli_info_get_*_ukr_impl_string()`) are: * `BLIS_REFERENCE_UKERNEL` (`"refrnce"`): This value is returned when the queried microkernel is provided by the reference implementation. * `BLIS_VIRTUAL_UKERNEL` (`"virtual"`): This value is returned when the queried microkernel is driven by a the "virtual" microkernel provided by an induced method. This happens for any `method` value that is not `BLIS_NAT` (ie: native), but only applies to the complex domain. @@ -1871,7 +2041,55 @@ char* bli_info_get_trmm3_impl_string( num_t dt ); char* bli_info_get_trsm_impl_string( num_t dt ); ``` + +## Clock functions + +--- + +#### clock +```c +double bli_clock + ( + void + ); +``` +Return the amount of time that has elapsed since some fixed time in the past. The return values of `bli_clock()` typically feature nanosecond precision, though this is not guaranteed. + +**Note:** On Linux, `bli_clock()` is implemented in terms of `clock_gettime()` using the `clockid_t` value of `CLOCK_MONOTONIC`. On OS X, `bli_clock` is implemented in terms of `mach_absolute_time()`. And on Windows, `bli_clock` is implemented in terms of `QueryPerformanceFrequency()`. Please see [frame/base/bli_clock.c](https://github.com/flame/blis/blob/master/frame/base/bli_clock.c) for more details. +**Note:** This function is returns meaningless values when BLIS is configured with `--disable-system`. + +--- + +#### clock_min_diff +```c +double bli_clock_min_diff + ( + double time_prev_min, + double time_start + ); +``` +This function computes an intermediate value, `time_diff`, equal to `bli_clock() - time_start`, and then tentatively prepares to return the minimum value of `time_diff` and `time_min`. If that minimum value is extremely small (close to zero), the function returns `time_min` instead. + +This function is meant to be used in conjuction with `bli_clock()` for +performance timing within applications--specifically in loops where only +the fastest timing is of interest. For example: +```c +double t_save = DBL_MAX; +for( i = 0; i < 3; ++i ) +{ + double t = bli_clock(); + bli_gemm( ... ); + t_save = bli_clock_min_diff( t_save, t ); +} +double gflops = ( 2.0 * m * k * n ) / ( t_save * 1.0e9 ); +``` +This code calls `bli_gemm()` three times and computes the performance, in GFLOPS, of the fastest of the three executions. + +--- + + + # Example code -BLIS provides lots of example code in the [examples/tapi](https://github.com/flame/blis/tree/master/examples/tapi) directory of the BLIS source distribution. The example code in this directory is set up like a tutorial, and so we recommend starting from the beginning. Topics include printing vectors and matrices and calling a representative subset of the computational level-1v, -1m, -2, -3, and utility operations documented above. +BLIS provides lots of example code in the [examples/tapi](https://github.com/flame/blis/tree/master/examples/tapi) directory of the BLIS source distribution. The example code in this directory is set up like a tutorial, and so we recommend starting from the beginning. Topics include printing vectors and matrices and calling a representative subset of the computational level-1v, -1m, -2, -3, and utility operations documented above. Please read the `README` contained within the `examples/tapi` directory for further details. diff --git a/docs/BuildSystem.md b/docs/BuildSystem.md index 84004f8869..5e290d9bbf 100644 --- a/docs/BuildSystem.md +++ b/docs/BuildSystem.md @@ -9,6 +9,9 @@ * **[Step 3b: Testing (optional)](BuildSystem.md#step-3b-testing-optional)** * **[Step 4: Installation](BuildSystem.md#step-4-installation)** * **[Cleaning out build products](BuildSystem.md#cleaning-out-build-products)** +* **[Compiling with BLIS](BuildSystem.md#compiling-with-blis)** + * [Disabling BLAS prototypes](BuildSystem.md#disabling-blas-prototypes) + * [CBLAS](BuildSystem.md#cblas) * **[Linking against BLIS](BuildSystem.md#linking-against-blis)** * **[Uninstalling](BuildSystem.md#uninstalling)** * **[make targets](BuildSystem.md#make-targets)** @@ -24,6 +27,8 @@ The BLIS build system was designed for use with GNU/Linux (or some other sane UN * GNU `bash` (3.2 or later) * GNU `make` (3.81 or later) * a working C99 compiler + * Perl (any version) + * `git` (1.8.5 or later, only required if cloning from Github) BLIS also requires a POSIX threads library at link-time (`-lpthread` or `libpthread.so`). This requirement holds even when configuring BLIS with multithreading disabled (the default) or with multithreading via OpenMP (`--enable-multithreading=openmp`). (Note: BLIS implements basic pthreads functionality automatically for Windows builds via [AppVeyor](https://ci.appveyor.com/project/shpc/blis/).) @@ -71,6 +76,12 @@ Another special configuration (one that, unlike `auto`, _is_ present in `config` If you are a BLIS developer and wish to create your own configuration, either from scratch or using an existing configuration as a starting point, please read the BLIS [Configuration Guide](ConfigurationHowTo.md). +### Multithreading + +Multithreading in BLIS is disabled by default. For more information on enabling multithreading, please read the section of the [Multithreading](Multithreading.md) document titled ["Enabling Multithreading"](Multithreading.md#enabling-multithreading). + +**IMPORTANT**: Even when multithreading is enabled at configure-time, BLIS will default to single-threaded execution at runtime. For more information on the various ways of specifying multithreading at runtime, please read the section titled ["Specifying Multithreading"](Multithreading.md#specifying-multithreading). + ## Step 2: Running `configure` This step should be somewhat familiar to many people who use open source software. To configure the build system, simply run: @@ -83,11 +94,11 @@ Alternatively, `configure` can automatically select a configuration based on you ``` $ ./configure auto ``` -However, as of this writing, only a limited number of architectures are detected. If the `configure` script is not able to detect your architecture, the `generic` configuration will be used. +However, as of this writing, BLIS lacks support for automatically detecting some architectures. If the `configure` script is not able to detect your architecture, the `generic` configuration will be used. Upon running configure, you will get output similar to the following. The exact output will depend on whether you cloned BLIS from a `git` repository or whether you obtained BLIS via a downloadable tarball from the [releases](https://github.com/flame/blis/releases) page. ``` -$ ./configure haswell +$ ./configure --prefix=$HOME/blis haswell configure: using 'gcc' compiler. configure: found gcc version 5.4.0 (maj: 5, min: 4, rev: 0). configure: checking for blacklisted configurations due to gcc 5.4.0. @@ -166,17 +177,11 @@ The installation prefix can be specified via the `--prefix=PREFIX` option: ``` $ ./configure --prefix=/usr ``` -This will cause libraries to eventually be installed (via `make install`) to `PREFIX/lib` and development headers to be installed to `PREFIX/include`. (The default value of `PREFIX` is `$(HOME)/blis`.) You can also specify the library install directory separately from the development header install directory with the `--libdir=LIBDIR` and `--includedir=INCDIR` options, respectively: +This will cause libraries to eventually be installed (via `make install`) to `PREFIX/lib` and development headers to be installed to `PREFIX/include`. (The default value of `PREFIX` is `/usr/local`.) You can also specify the library install directory separately from the development header install directory with the `--libdir=LIBDIR` and `--includedir=INCDIR` options, respectively: ``` $ ./configure --libdir=/usr/lib --includedir=/usr/include ``` -The `--libdir=LIBDIR` and `--includedir=INCDIR` options will override any `PREFIX` path, whether it was specified explicitly via `--prefix` or implicitly (via the default). That is, `LIBDIR` defaults to `PREFIX/lib` and `INCDIR` defaults to `PREFIX/include`, but each will be overriden by their respective `--libdir`/`--includedir` options. So, -``` -$ ./configure --libdir=/usr/lib - -``` -will configure BLIS to install libraries to `/usr/lib` and header files to the default location (`$HOME/blis/include`). -Also, note that `configure` will create any installation directories that do not already exist. +The `--libdir=LIBDIR` and `--includedir=INCDIR` options will override any path implied by `PREFIX`, whether it was specified explicitly via `--prefix` or implicitly (via the default). That is, `LIBDIR` defaults to `EXECPREFIX/lib` (where `EXECPREFIX`, set via `--exec-prefix=EXECPREFIX`, defaults to `PREFIX`) and `INCDIR` defaults to `PREFIX/include`, but `LIBDIR` and `INCDIR` will each be overriden by their respective `--libdir`/`--includedir` options. There is a third related option, `--sharedir=SHAREDIR`, where `SHAREDIR` defaults to `PREFIX/share`. This option specifies the installation directory for certain makefile fragments that contain variables determined by `configure` (e.g. `CC`, `CFLAGS`, `LDFLAGS`, etc.). These files allow certain BLIS makefiles, such as those in the `examples` or `testsuite` directories, to operate on an installed copy of BLIS rather than a local (and possibly uninstalled) copy. For a complete list of supported `configure` options and arguments, run `configure` with the `-h` option: ``` @@ -184,7 +189,6 @@ $ ./configure -h ``` The output from this invocation of `configure` should give you an up-to-date list of options and their descriptions. - ## Step 3: Compilation Once `configure` is finished, you are ready to instantiate (compile) BLIS into a library by running `make`. Running `make` will result in output similar to: @@ -338,6 +342,47 @@ Removing include. Running the `distclean` target is like saying, "Remove anything ever created by the build system." +## Compiling with BLIS + +All BLIS definitions and prototypes may be included in your C source file by including a single header file, `blis.h`: +```c +#include "stdio.h" +#include "stdlib.h" +#include "otherstuff.h" +#include "blis.h" +``` +If the BLAS compatibility layer was enabled at configure-time (as it is by default), then `blis.h` will also provide BLAS prototypes to your source code. + + +### Disabling BLAS prototypes + +Some applications already `#include` a header that contains BLAS prototypes. This can cause problems if those applications also try to `#include` the BLIS header file, as shown above. Suppose for a moment that `otherstuff.h` in the example above already provides BLAS prototypes. +``` +$ gcc -I/path/to/blis -I/path/to/otherstuff -c main.c -o main.o +In file included from main.c:41:0: +/path/to/blis/blis.h:36900:111: error: conflicting declaration of C function ‘int xerbla_(const bla_character*, const bla_integer*, ftnlen)’ + TEF770(xerbla)(const bla_character *srname, const bla_integer *info, ftnlen srname_len); +``` +If your application is already declaring (prototyping) BLAS functions, then you may disable those prototypes from being defined included within `blis.h`. This prevents `blis.h` from re-declaring those prototypes, or, allows your other header to declare those functions for the first time, depending on the order that you `#include` the headers. +```c +#include "stdio.h" +#include "stdlib.h" +#include "otherstuff.h" +#define BLIS_DISABLE_BLAS_DEFS // disable BLAS prototypes within BLIS. +#include "blis.h" +``` +By `#defining` the `BLIS_DISABLE_BLAS_DEFS` macro, we signal to `blis.h` that it should skip over the BLAS prototypes, but otherwise `#include` everything else as it normally would. Note that `BLIS_DISABLE_BLAS_DEFS` must be `#defined` *prior* to the `#include "blis.h"` directive in order for it to have any effect. + + +### CBLAS + +If you build BLIS with CBLAS enabled and you wish to access CBLAS function prototypes from within your application, you will have to `#include` the `cblas.h` header separately from `blis.h`. +``` +#include "blis.h" +#include "cblas.h" +``` + + ## Linking against BLIS Once you have instantiated (configured and compiled, and perhaps installed) a BLIS library, you can link to it in your application's makefile as you would any other library. The following is an abbreviated makefile for a small hypothetical application that has just two external dependencies: BLIS and the standard C math library. We also link against libpthread since that library has been a runtime dependency of BLIS since 70640a3 (December 2017). @@ -357,7 +402,7 @@ OBJS = main.o util.o other.o %.o: %.c $(CC) $(CFLAGS) -c $< -o $@ -all: $(OBJS) +all: $(OBJS) $(LINKER) $(OBJS) $(BLIS_LIB) $(OTHER_LIBS) -o my_program.x ``` The above example assumes you will want to include BLIS definitions and function prototypes into your application via `#include blis.h`. (If you are only using the BLIS via the BLAS compatibility layer, including `blis.h` is not necessary.) Since BLIS headers are installed into a `blis` subdirectory of `PREFIX/include`, you must make sure that the compiler knows where to find the `blis.h` header file. This is typically accomplished by inserting `#include "blis.h"` into your application's source code files and compiling the code with `-I PREFIX/include/blis`. diff --git a/docs/CodingConventions.md b/docs/CodingConventions.md index 117cc29602..39dd2ab169 100644 --- a/docs/CodingConventions.md +++ b/docs/CodingConventions.md @@ -17,6 +17,8 @@ This wiki describes the coding conventions used in BLIS. Please try to adhere to these conventions when submitting pull requests and/or (if you have permission) committing directly to the repository. +There is some support for these conventions for Emacs editing in the `.dir-locals.el` file, which will affect editing with CC mode in the blis directory. + ## C99 Most of the code in BLIS is written in C, and specifically in ISO C99. This section describes the C coding standards used within BLIS. @@ -32,23 +34,29 @@ Please either use braces to denote the indentation limits of scope, or to enclos foo = 1; } - // This is also fine. + // This is also fine. (Ideal for short conditional bodies.) if ( bli_obj_is_real( x ) ) { foo = 1; return; } // This is bad. Please use one of the two forms above. if ( bli_obj_is_real( x ) ) { foo = 1; } + + // This is (much) worse. Please no. + if ( bli_obj_is_real( x ) ) + { + foo = 1; + } } ``` ### Indentation -If at all possible, **please use tabs to denote changing levels of scope!** If you can't use tabs or doing so would be very inconvenient given your editor and setup, please set your indentation to use exactly four spaces per level of indentation. Below is what it would look like if you used tabs (with a tab width set to four spaces), or four actual spaces per indentation level. +If at all possible, **please use tabs to denote changing levels of scope!** If you can't use tabs or doing so would be very inconvenient given your editor and setup, please set your indentation to use exactly four spaces per level of indentation. Below is what it would look like if you used tabs (with a tab width set to occupy four spaces), or four actual spaces per indentation level. ```c -bool_t bli_obj_is_real( obj_t* x ) +bool bli_obj_is_real( obj_t* x ) { - bool_t r_val; + bool r_val; if ( bli_obj_is_real( x ) ) r_val = TRUE; @@ -59,9 +67,9 @@ bool_t bli_obj_is_real( obj_t* x ) Ideally, tabs should be used to indicate changes in levels of scope, but then spaces should be used for multi-line statements within the same scope. In the example below, I've marked the characters that should be spaces with `.` (with tabs used for the first level of indentation): ```c -bool_t bli_obj_is_complex( obj_t* x ) +bool bli_obj_is_complex( obj_t* x ) { - bool_t r_val; + bool r_val; if ( bli_obj_is_scomplex( x ) || .....bli_obj_is_dcomplex( x ) ) r_val = TRUE; @@ -103,6 +111,10 @@ Please use blank lines to separate lines of code from the next line of code. How // Set the matrix dimensions. bli_obj_set_length( 10, x ); bli_obj_set_width( 5, x ); + + // Set the matrix structure. + bli_obj_set_struc( BLIS_GENERAL, x ); + bli_obj_set_uplo( BLIS_DENSE, x ); } ``` @@ -111,8 +123,8 @@ Please use blank lines to separate lines of code from the next line of code. How Sometimes, to more efficiently display code on the screen, it's helpful to skip certain newlines, such as those in conditional statements. This is fine, just try to line things up in a way that is visually appealing. ```c { - bool_t r_val; - dim_t foo; + bool r_val; + dim_t foo; // This is fine. if ( bli_obj_is_real( x ) ) r_val = TRUE; @@ -132,7 +144,7 @@ Sometimes, to more efficiently display code on the screen, it's helpful to skip ### Whitespace in function calls -For single-line function calls, please **avoid** a space between the last character in the function/macro name and the open parentheses. Also, please do not insert any spaces before commas that separate arguments to a function/macro invocation. +For single-line function calls, **please avoid** a space between the last character in the function/macro name and the open parentheses. Also, please do not insert any spaces before commas that separate arguments to a function/macro invocation. But please **do** insert at least once space after each comma. (I say "at least one" because sometimes it looks nicer to align the commas with those of function calls on lines above or below the function call in question.) Also, please include one space between the opening parentheses and the first argument, and also between the last argument and closing parentheses ```c { obj_t x; @@ -141,11 +153,14 @@ For single-line function calls, please **avoid** a space between the last charac bli_obj_create( BLIS_DOUBLE, 3, 4, 0, 0, &x ); bli_obj_set_length( 10, x ); - // Bad. Please avoid. + // Bad. Please avoid these. bli_obj_set_dt ( BLIS_FLOAT, x ); - - // Bad. Please avoid. bli_obj_set_dt( BLIS_FLOAT , x ); + bli_obj_set_dt(BLIS_FLOAT, x); + bli_obj_set_dt(BLIS_FLOAT,x); + + // Good. + bli_obj_set_dt( BLIS_FLOAT, x ); } ``` For multi-line function calls, please use the following template: @@ -170,15 +185,23 @@ Notice that here, the parentheses are formatted similar to braces. However, noti When defining a function with few arguments, insert a single space after commas and types, and after the first parentheses and before the last parentheses: ```c +// Please write "short" function signatures like this. void bli_obj_set_length( dim_t m, obj_t* a ) { // Body of function } ``` As with single-line function calls, please do not place a space between the last character of the function name and the open parentheses to the argument list! - +```c +// Please avoid this. +void bli_obj_set_length ( dim_t m, obj_t* a ) +{ + // Body of function +} +``` When defining a function with many arguments, especially those that would not comfortably fit in a single 80-character line, you can split the type signature into multiple lines: ```c +// Please write "long" function signatures like this. void bli_gemm ( obj_t* alpha, @@ -194,6 +217,7 @@ void bli_gemm ``` If you are going to use this style of function definition, please indent the parentheses exactly five spaces (don't use tabs here). Then, indent the arguments with an additional two spaces. Thus, parentheses should be in column 6 (counting from 1) and argument types should begin in column 8. Also notice that the number of spaces after each argument's type specifier varies so that the argument names are aligned. If you insert qualifiers such as `restrict`, please right-justify them: ```c +// Please align 'restrict' keywords and variables, as appropriate. void bli_gemm ( obj_t* restrict alpha, @@ -213,18 +237,39 @@ void bli_gemm Please insert whitespace into conditional expressions. ```c { - // Good. - if ( m == 10 && n > 0 ) return; + // Good. + if ( m == 10 && n > 0 ) return; + + // Bad. + if ( m==10 && n>0 ) return; - // Bad. - if ( m==10 && n>0 ) return; + // Worse! + if (m==10&&n>0) return; - // Worse! - if (m==10&&n>0) return; + // Okay, now you're just messing with me. + if(m==10&&n>0)return; } ``` Unlike with the parentheses that surround the argument list of a function call, there should be exactly one space after conditional keywords and the open parentheses for its associated conditional statement: `if (...)`, `else if (...)`, and `while (...)`. +```c +{ + // Good. + if ( ... ) return 0; + else if ( ... ) return 1; + // Good. + while ( ... ) + { + // loop body. + } + + // Good. + do + { + // loop body. + } while ( ... ); +} +``` Sometimes, extra spaces for alignment are desired: ```c { diff --git a/docs/ConfigurationHowTo.md b/docs/ConfigurationHowTo.md index 20aa87e711..dcec7754c6 100644 --- a/docs/ConfigurationHowTo.md +++ b/docs/ConfigurationHowTo.md @@ -212,32 +212,35 @@ Furthermore, if a header file needs to be included, such as `my_malloc.h`, it sh _**SIMD register file.**_ BLIS allows you to specify the _maximum_ number of SIMD registers available for use by your kernels, as well as the _maximum_ size (in bytes) of those registers. These values default to: ```c -#define BLIS_SIMD_NUM_REGISTERS 32 -#define BLIS_SIMD_SIZE 64 +#define BLIS_SIMD_MAX_NUM_REGISTERS 32 +#define BLIS_SIMD_MAX_SIZE 64 ``` These macros are used in computing the maximum amount of temporary storage (typically allocated statically, on the function stack) that will be needed to hold a single micro-tile of any datatype (and for any induced method): ```c -#define BLIS_STACK_BUF_MAX_SIZE ( BLIS_SIMD_NUM_REGISTERS * BLIS_SIMD_SIZE * 2 ) +#define BLIS_STACK_BUF_MAX_SIZE ( BLIS_SIMD_MAX_NUM_REGISTERS * BLIS_SIMD_MAX_SIZE * 2 ) ``` -These temporary buffers are used when handling edge cases (m % _MR_ != 0 || n % _NR_ != 0) within the level-3 macrokernels, and also in the virtual microkernels of various implementations of induced methods for complex matrix multiplication. It is **very important** that these values be set correctly; otherwise, you may experience undefined behavior as stack data is overwritten at run-time. A kernel developer may set `BLIS_SIMD_NUM_REGISTERS` and `BLIS_SIMD_SIZE`, which will indirectly affect `BLIS_STACK_BUF_MAX_SIZE`, or he may set `BLIS_STACK_BUF_MAX_SIZE` directly. Notice that the default values are already set to work with modern x86_64 systems. +These temporary buffers are used when handling edge cases (m % _MR_ != 0 || n % _NR_ != 0) within the level-3 macrokernels, and also in the virtual microkernels of various implementations of induced methods for complex matrix multiplication. It is **very important** that these values be set correctly; otherwise, you may experience undefined behavior as stack data is overwritten at run-time. A kernel developer may set `BLIS_SIMD_MAX_NUM_REGISTERS` and `BLIS_SIMD_MAX_SIZE`, which will indirectly affect `BLIS_STACK_BUF_MAX_SIZE`, or he may set `BLIS_STACK_BUF_MAX_SIZE` directly. Notice that the default values are already set to work with modern x86_64 systems. _**Memory alignment.**_ BLIS implements memory alignment internally, rather than relying on a function such as `posix_memalign()`, and thus it can provide aligned memory even with functions that adhere to the `malloc()` and `free()` API in the standard C library. ```c -#define BLIS_SIMD_ALIGN_SIZE BLIS_SIMD_SIZE +#define BLIS_SIMD_ALIGN_SIZE BLIS_SIMD_MAX_SIZE #define BLIS_PAGE_SIZE 4096 #define BLIS_STACK_BUF_ALIGN_SIZE BLIS_SIMD_ALIGN_SIZE #define BLIS_HEAP_ADDR_ALIGN_SIZE BLIS_SIMD_ALIGN_SIZE #define BLIS_HEAP_STRIDE_ALIGN_SIZE BLIS_SIMD_ALIGN_SIZE -#define BLIS_POOL_ADDR_ALIGN_SIZE BLIS_PAGE_SIZE +#define BLIS_POOL_ADDR_ALIGN_SIZE_A BLIS_PAGE_SIZE +#define BLIS_POOL_ADDR_ALIGN_SIZE_B BLIS_PAGE_SIZE +#define BLIS_POOL_ADDR_ALIGN_SIZE_C BLIS_PAGE_SIZE +#define BLIS_POOL_ADDR_ALIGN_SIZE_GEN BLIS_PAGE_SIZE ``` -The value `BLIS_STACK_BUF_ALIGN_SIZE` defines the alignment of stack memory used as temporary internal buffers, such as for output matrices to the microkernel when computing edge cases. (See [implementation notes](KernelsHowTo#implementation-notes-for-gemm) for the `gemm` microkernel for details.) This value defaults to `BLIS_SIMD_ALIGN_SIZE`, which defaults to `BLIS_SIMD_SIZE`. +The value `BLIS_STACK_BUF_ALIGN_SIZE` defines the alignment of stack memory used as temporary internal buffers, such as for output matrices to the microkernel when computing edge cases. (See [implementation notes](KernelsHowTo#implementation-notes-for-gemm) for the `gemm` microkernel for details.) This value defaults to `BLIS_SIMD_ALIGN_SIZE`, which defaults to `BLIS_SIMD_MAX_SIZE`. The value `BLIS_HEAP_ADDR_ALIGN_SIZE` defines the alignment used when allocating memory via the `malloc()` function defined by `BLIS_MALLOC_USER`. Setting this value to `BLIS_SIMD_ALIGN_SIZE` may speed up certain level-1v and -1f kernels. The value `BLIS_HEAP_STRIDE_ALIGN_SIZE` defines the alignment used for so-called "leading dimensions" (i.e. column strides for column-stored matrices, and row strides for row-stored matrices) when creating BLIS matrices via the object-based API (e.g. `bli_obj_create()`). While setting `BLIS_HEAP_ADDR_ALIGN_SIZE` guarantees alignment for the first column (or row), creating a matrix with certain dimension values (_m_ and _n_) may cause subsequent columns (or rows) to be misaligned. Setting this value to `BLIS_SIMD_ALIGN_SIZE` is usually desirable. Additional alignment may or may not be beneficial. -The value `BLIS_POOL_ADDR_ALIGN_SIZE` defines the alignment used when allocating blocks to the memory pools used to manage internal packing buffers. Any block of memory returned by the memory allocator is guaranteed to be aligned to this value. Aligning these blocks to the virtual memory page size (usually 4096 bytes) is standard practice. +The value `BLIS_POOL_ADDR_ALIGN_SIZE_*` define the alignments used when allocating blocks to the memory pools used to manage internal packing buffers for matrices A, B, C, and for general use. Any block of memory returned by the memory allocator is guaranteed to be aligned to this value. Aligning these blocks to the virtual memory page size (usually 4096 bytes) is standard practice. @@ -635,8 +638,8 @@ Adding support for a new-subconfiguration to BLIS is similar to adding support f ``` and while we're editing the file, we can make any other changes to compiler flags we wish (if any). Similarly, the `bli_family_knl.h` header file should be updated as needed. Since the number of vector registers and the vector register size on `knl` differ from the defaults, we must explicitly set them. (The role of these parameters was explained in a [previous section](ConfigurationHowTo.md#bli_family_h).) Furthermore, provided that a macro `BLIS_NO_HBWMALLOC` is not set, we use a different implementation of `malloc()` and `free()` and `#include` that implementation's header file. ```c - #define BLIS_SIMD_NUM_REGISTERS 32 - #define BLIS_SIMD_SIZE 64 + #define BLIS_SIMD_MAX_NUM_REGISTERS 32 + #define BLIS_SIMD_MAX_SIZE 64 #ifdef BLIS_NO_HBWMALLOC #include @@ -677,14 +680,14 @@ Adding support for a new-subconfiguration to BLIS is similar to adding support f BLIS_ARCH_POWER7, BLIS_ARCH_BGQ, - BLIS_ARCH_GENERIC + BLIS_ARCH_GENERIC, + + BLIS_NUM_ARCHS } arch_t; ``` - Additionally, you'll need to update the definition of `BLIS_NUM_ARCHS` to reflect the new total number of enumerated `arch_t` values: - ```c - #define BLIS_NUM_ARCHS 16 - ``` + Notice that the total number of `arch_t` values, `BLIS_NUM_ARCHS`, is updated automatically. + * **`frame/base/bli_gks.c`**. We must also update the global kernel structure, or gks, to register the new sub-configuration during library initialization. Sub-configuration registration occurs in `bli_gks_init()`. For `knl`, updating this function amounts to inserting the following lines diff --git a/docs/FAQ.md b/docs/FAQ.md index 1e9cd3be4a..3d0852d36f 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -8,6 +8,8 @@ project, as well as those we think a new user or developer might ask. If you do * [Why did you create BLIS?](FAQ.md#why-did-you-create-blis) * [Why should I use BLIS instead of GotoBLAS / OpenBLAS / ATLAS / MKL / ESSL / ACML / Accelerate?](FAQ.md#why-should-i-use-blis-instead-of-gotoblas--openblas--atlas--mkl--essl--acml--accelerate) * [How is BLIS related to FLAME / libflame?](FAQ.md#how-is-blis-related-to-flame--libflame) + * [What is the difference between BLIS and the AMD fork of BLIS found in AOCL?](FAQ.md#what-is-the-difference-between-blis-and-the-amd-fork-of-blis-found-in-aocl) + * [Who do I contact if I have a question about the AMD version of BLIS?](FAQ.md#who-do-i-contact-if-i-have-a-question-about-the-amd-version-of-blis) * [Does BLIS automatically detect my hardware?](FAQ.md#does-blis-automatically-detect-my-hardware) * [I understand that BLIS is mostly a tool for developers?](FAQ.md#i-understand-that-blis-is-mostly-a-tool-for-developers) * [How do I link against BLIS?](FAQ.md#how-do-i-link-against-blis) @@ -16,6 +18,8 @@ project, as well as those we think a new user or developer might ask. If you do * [What is a macrokernel?](FAQ.md#what-is-a-macrokernel) * [What is a context?](FAQ.md#what-is-a-context) * [I am used to thinking in terms of column-major/row-major storage and leading dimensions. What is a "row stride" / "column stride"?](FAQ.md#im-used-to-thinking-in-terms-of-column-majorrow-major-storage-and-leading-dimensions-what-is-a-row-stride--column-stride) + * [I'm somewhat new to this matrix stuff. Can you remind me, what is the difference between a matrix row and a matrix column?](FAQ.md#im-somewhat-new-to-this-matrix-stuff-can-you-remind-me-what-is-the-difference-between-a-matrix-row-and-a-matrix-column) + * [Why does BLIS have vector (level-1v) and matrix (level-1m) variations of most level-1 operations?](FAQ.md#why-does-blis-have-vector-level-1v-and-matrix-level-1m-variations-of-most-level-1-operations) * [What does it mean when a matrix with general stride is column-tilted or row-tilted?](FAQ.md#what-does-it-mean-when-a-matrix-with-general-stride-is-column-tilted-or-row-tilted) * [I am not really interested in all of these newfangled features in BLIS. Can I just use BLIS as a BLAS library?](FAQ.md#im-not-really-interested-in-all-of-these-newfangled-features-in-blis-can-i-just-use-blis-as-a-blas-library) * [What about CBLAS?](FAQ.md#what-about-cblas) @@ -26,17 +30,17 @@ project, as well as those we think a new user or developer might ask. If you do * [Does BLIS work with GPUs?](FAQ.md#does-blis-work-with-gpus) * [Does BLIS work on (some architecture)?](FAQ.md#does-blis-work-on-some-architecture) * [What about distributed-memory parallelism?](FAQ.md#what-about-distributed-memory-parallelism) - * [Can I build BLIS on Windows / Mac OS X?](FAQ.md#can-i-build-blis-on-windows--mac-os-x) + * [Can I build BLIS on Mac OS X?](FAQ.md#can-i-build-blis-on-mac-os-x) + * [Can I build BLIS on Windows?](FAQ.md#can-i-build-blis-on-windows) * [Can I build BLIS as a shared library?](FAQ.md#can-i-build-blis-as-a-shared-library) * [Can I use the mixed domain / mixed precision support in BLIS?](FAQ.md#can-i-use-the-mixed-domain--mixed-precision-support-in-blis) * [Who is involved in the project?](FAQ.md#who-is-involved-in-the-project) * [Who funded the development of BLIS?](FAQ.md#who-funded-the-development-of-blis) * [I found a bug. How do I report it?](FAQ.md#i-found-a-bug-how-do-i-report-it) * [How do I request a new feature?](FAQ.md#how-do-i-request-a-new-feature) + * [I'm a developer and I'd like to study the way matrix multiplication is implemented in BLIS. Where should I start?](FAQ.md#im-a-developer-and-id-like-to-study-the-way-matrix-multiplication-is-implemented-in-blis-where-should-i-start) * [Where did you get the photo for the BLIS logo / mascot?](FAQ.md#where-did-you-get-the-photo-for-the-blis-logo--mascot) - - ### Why did you create BLIS? Initially, BLIS was conceived as simply "BLAS with a more flexible interface". The original BLIS was written as a wrapper layer around BLAS that allowed generalized matrix storage (i.e., separate row and column strides). We also took the opportunity to implement some complex domain features that were missing from the BLAS (mostly related to conjugating input operands). This "proto-BLIS" was deployed in [libflame](http://shpc.ices.utexas.edu/libFLAME.html) to facilitate cleaner implementations of some LAPACK-level operations. @@ -57,7 +61,19 @@ homepage](https://github.com/flame/blis#key-features). But here are a few reason ### How is BLIS related to FLAME / `libflame`? -As explained [above](FAQ.md#why-did-you-create-blis?), BLIS was initially a layer within `libflame` that allowed more convenient interfacing to the BLAS. So in some ways, BLIS is a spin-off project. Prior to developing BLIS, [its author](http://www.cs.utexas.edu/users/field/) worked as the primary maintainer of `libflame`. If you look closely, you can also see that the design of BLIS was influenced by some of the more useful and innovative aspects of `libflame`, such as internal object abstractions and control trees. Also, various members of the [SHPC research group](http://shpc.ices.utexas.edu/people.html) and its [collaborators](http://shpc.ices.utexas.edu/collaborators.html) routinely provide insight, feedback, and also contribute code (especially kernels) to the BLIS project. +As explained [above](FAQ.md#why-did-you-create-blis?), BLIS was initially a layer within `libflame` that allowed more convenient interfacing to the BLAS. So in some ways, BLIS is a spin-off project. Prior to developing BLIS, [its primary author](http://www.cs.utexas.edu/users/field/) worked as the primary maintainer of `libflame`. If you look closely, you can also see that the design of BLIS was influenced by some of the more useful and innovative aspects of `libflame`, such as internal object abstractions and control trees. + +Note that various members of the [SHPC research group](http://shpc.ices.utexas.edu/people.html) and its [collaborators](http://shpc.ices.utexas.edu/collaborators.html) routinely provide insight, feedback, and also contribute code (especially kernels) to the BLIS project. + +### What is the difference between BLIS and the AMD fork of BLIS found in AOCL? + +BLIS, also known as "vanilla BLIS" or "upstream BLIS," is maintained by its [original developer](https://github.com/fgvanzee) (with the [support of others](http://shpc.ices.utexas.edu/collaborators.html)) in the [Science of High-Performance Computing](http://shpc.ices.utexas.edu/) (SHPC) group within the [The Oden Institute for Computational Engineering and Sciences](http://www.oden.utexas.edu/) at [The University of Texas at Austin](http://www.utexas.edu/). In 2015, [AMD](https://www.amd.com/) reorganized many of their software library efforts around existing open source projects. BLIS was chosen as the basis for their [CPU BLAS library](https://developer.amd.com/amd-aocl/blas-library/), and an AMD-maintained [fork of BLIS](https://github.com/amd/blis) was established. + +AMD BLIS sometimes contains certain optimizations specific to AMD hardware. Many of these optimizations are (eventually) merged back into upstream BLIS. However, for various reasons, some changes may remain unique to AMD BLIS for quite some time. Thus, if you want the latest optimizations for AMD hardware, feel free to try AMD BLIS. However, please note that neither The University of Texas at Austin nor BLIS's developers can endorse or offer direct support for any outside fork of BLIS, including AMD BLIS. + +### Who do I contact if I have a question about the AMD version of BLIS? + +For questions or support regarding [AMD's fork of BLIS](https://github.com/amd/blis), please contact the [AMD Optimizing CPU Libraries](https://developer.amd.com/amd-aocl/) group at aoclsupport@amd.com. ### Does BLIS automatically detect my hardware? @@ -67,9 +83,9 @@ If automatic hardware detection is requested at configure-time and the build pro ### I understand that BLIS is mostly a tool for developers? -Yes. In order to achieve high performance, BLIS requires that hand-coded kernels and microkernels be written and referenced in a valid [BLIS configuration](ConfigurationHowTo.md). These components are usually written by developers and then included within BLIS for use by others. +It is certainly the case that BLIS began as a tool targeted at developers. In order to achieve high performance, BLIS requires that hand-coded kernels and microkernels be written and referenced in a valid [BLIS configuration](ConfigurationHowTo.md). These components are usually written by developers and then included within BLIS for use by others. -The good news, however, is that end-users can use BLIS too. Once the aforementioned kernels are integrated into BLIS, they can be used without any developer-level knowledge. Usually, `./configure auto; make; make install` is sufficient for the typical users with typical hardware. +The good news, however, is that BLIS has matured to the point where end-users can use it too! Once the aforementioned kernels are integrated into BLIS, they can be used without any developer-level knowledge, and many kernels have already been added! Usually, `./configure auto; make; make install` is sufficient for the typical users with typical hardware. ### How do I link against BLIS? @@ -77,8 +93,7 @@ Linking against BLIS is easy! Most people can link to it as if it were a generic ### Must I use git? Can I download a tarball? -We **strongly encourage** you to obtain the BLIS source code by cloning a `git` repository (via the [git -clone](BuildSystem.md#obtaining-blis) command). The reason for this is that it will allow you to easily update your local copy of BLIS by executing `git pull`. +We **strongly encourage** you to obtain the BLIS source code by cloning a `git` repository (via the [git clone](BuildSystem.md#obtaining-blis) command). The reason for this is that it will allow you to easily update your local copy of BLIS by executing `git pull`. Tarballs and zip files may be obtained from the [releases](https://github.com/flame/blis/releases) page. @@ -90,33 +105,53 @@ For a more thorough explanation of the microkernel and its role in the overall l ### What is a macrokernel? -The macrokernels are portable codes within the BLIS framework that implement relatively small subproblems within an overall level-3 operation. The overall problem (say, general matrix-matrix multiplication, or `gemm`) is partitioned down, according to cache blocksizes, such that its operands are (1) a suitable size and (2) stored in a special packed format. At that time, the macrokernel is called. The macrokernel is implemented as two loops around the microkernel. +The macrokernels are portable codes within the BLIS framework that implement relatively small subproblems within an overall level-3 operation. The overall problem (say, general matrix-matrix multiplication, or `gemm`) is partitioned down, according to cache blocksizes, such that its `A` and `B` operands are (1) a suitable size and (2) stored in a special packed format. At that time, the macrokernel is called. The macrokernel is implemented as two loops around the microkernel. -The macrokernels in BLIS correspond to the so-called "inner kernels" (or simply "kernels") that formed the fundamental unit of computation in Kazushige Goto's GotoBLAS (and now in the successor library, OpenBLAS). +The macrokernels, along with the microkernel that they call, correspond to the so-called "inner kernels" (or simply "kernels") that formed the fundamental unit of computation in Kazushige Goto's GotoBLAS (and now in the successor library, OpenBLAS). For more information on macrokernels, please read our [ACM TOMS papers](https://github.com/flame/blis#citations). ### What is a context? -As of 0.2.0, BLIS contains a new infrastructure for communicating runtime information (such as kernel addresses and blocksizes) from the highest levels of code all the way down the function stack, even into the kernels themselves. This new data structure is called a *context*, and together with its API, it helped us clean up some hacks and other awkwardness that existed in BLIS prior to 0.2.0. Contexts also lays the groundwork for managing kernels and related kernel information at runtime. +As of 0.2.0, BLIS contains a new infrastructure for communicating runtime information (such as kernel addresses and blocksizes) from the highest levels of code all the way down the function stack, even into the kernels themselves. This new data structure is called a *context* (defined in code as a `cntx_t` type), and together with its API it helped us clean up some hacks and other awkwardness that existed in BLIS prior to 0.2.0. Contexts also lay the groundwork for managing kernels and related kernel information at runtime. If you are a kernel developer, you can usually ignore the `cntx_t*` argument that is passed into each kernel, since the kernels already inherently "know" this information (such as register blocksizes). And if you are a user, and the function you want to call takes a `cntx_t*` argument, you can safely pass in `NULL` and BLIS will automatically build a suitable context for you at runtime. ### I'm used to thinking in terms of column-major/row-major storage and leading dimensions. What is a "row stride" / "column stride"? -Traditional BLAS assumes that matrices are stored in column-major order, where a leading dimension measures the distance from one element to the next element in the same row. But column-major order is really just a special case of BLIS's more generalized storage scheme. +Traditional BLAS assumes that matrices are stored in column-major order (or, as we often say, matrices that are "column-stored"), where a leading dimension measures the distance from one element to the next element in the same row. But column-major order is really just a special case of BLIS's more generalized storage scheme. In generalized storage, we have a row stride and a column stride. The row stride measures the distance in memory between rows (within a single column) while the column stride measures the distance between columns (within a single row). Column-major storage corresponds to the situation where the row stride equals 1. Since the row stride is unit, you only have to track the column stride (i.e., the leading dimension). Similarly, in row-major order, the column stride is equal to 1 and only the row stride must be tracked. BLIS also supports situations where both the row stride and column stride are non-unit. We call this situation "general stride". +### I'm somewhat new to this matrix stuff. Can you remind me, what is the difference between a matrix row and a matrix column? + +Of course! (BLIS's primary author remembers what it was like to get columns and rows confused.) + +Matrix columns consist of elements that are vertically aligned. Matrix rows consist of elements that are horizontally aligned. (One way to remember this distinction is that real-life columns are vertical structures that hold up buildings. A row of seats in a stadium, by contrast, is horizontal to the ground.) + +Furthermore, it is helpful to know that the number of rows in a matrix constitutes its so-called *m* dimension, and the number of columns constitutes its *n* dimension. + +Matrix dimension are always stated as *m x n*: the number of rows *by* the number of columns. + +So, a *3 x 4* matrix contains three rows (each of length four) and four columns (each of length three). + +### Why does BLIS have vector (level-1v) and matrix (level-1m) variations of most level-1 operations? + +At first glance, it might appear that an element-wise operation such as `copym` or `axpym` would be sufficiently general purpose to cover the cases where the operands are vectors. After all, an *m x 1* matrix can be viewed as a vector of length m and vice versa. But in BLIS, operations on vectors are treated slightly differently than operations on matrices. + +If an application wishes to perform an element-wise operation on two objects, and the application calls a level-1m operation, the dimensions of those objects must be conformal, or "match up" (after any transposition implied by the object properties). This includes situations where one of the dimensions is unit. + +However, if an application instead decides to perform an element-wise operation on two objects, and the application calls a level-1v operation, the dimension constraints are slightly relaxed. In this scenario, BLIS only checks that the vector *lengths* are equal. This allows for the vectors to have different orientations (row vs column) while still being considered conformal. So, you could perform a `copyv` operation to copy from an *m x 1* vector to a *1 x m* vector. A `copym` operation on such objects would not be allowed (unless it was executed with the source object containing an implicit transposition). + ### What does it mean when a matrix with general stride is column-tilted or row-tilted? When a matrix is stored with general stride, both the row stride and column stride (let's call them `rs` and `cs`) are non-unit. When `rs` < `cs`, we call the general stride matrix "column-tilted" because it is "closer" to being column-stored (than row-stored). Similarly, when `rs` > `cs`, the matrix is "row-tilted" because it is closer to being row-stored. ### I'm not really interested in all of these newfangled features in BLIS. Can I just use BLIS as a BLAS library? -Absolutely. Just link your application to BLIS the same way you would link to a BLAS library. For a simple linking example, see the [Linking to BLIS](KernelsHowTo.md#linking-to-blis) section of the BLIS [Build System](BuildSystem.md) guide. +Absolutely! Just link your application to BLIS the same way you would link to a BLAS library. For a simple linking example, see the [Linking to BLIS](KernelsHowTo.md#linking-to-blis) section of the BLIS [Build System](BuildSystem.md) guide. ### What about CBLAS? @@ -126,17 +161,19 @@ BLIS also contains an optional CBLAS compatibility layer, which leverages the BL In principle, BLIS's native (and BLAS-like) [typed API](BLISTypedAPI) can be called from Fortran. However, you must ensure that the size of the integer in BLIS is equal to the size of integer used by your Fortran program/compiler/environment. The size of BLIS integers is determined at configure-time. Please see `./configure --help` for the syntax for options related to integer sizes. +You may also want to confirm that your Fortran compiler doesn't perform any name-mangling of called functions or subroutines (such as with additional underscores beyond the single trailing underscore found in the BLAS APIs), and if so, take steps to disable this additional name-mangling. For example, if your source code calls `dgemm()` but your Fortran compiler name-mangles that call to `_dgemm_()` or `dgemm__()`, your program will fail to link against BLIS since BLIS only defines `dgemm_()`. + As for bindings to other languages, please contact the [blis-devel](http://groups.google.com/group/blis-devel) mailing list. ### Do I need to call initialization/finalization functions before being able to use BLIS from my application? -Originally, BLIS did indeed require the application to explicitly setup (initialize) various internal data structures via `bli_init()`. Likewise, calling `bli_finalize()` was recommended to cleanup (finalize) the library. However, since commit 9804adf (circa December 2017), BLIS has implemented self-initialization. These explicit calls to `bli_init()` and `bli_finalize()` are no longer necessary, though experts may still use them in special cases to control the allocation and freeing of resources. This topic is discussed in the BLIS [typed API reference](BLISTypedAPI.md#initialization-and-cleanup). +Originally, BLIS did indeed require the application to explicitly setup (initialize) various internal data structures via `bli_init()`. Likewise, calling `bli_finalize()` was recommended to cleanup (finalize) the library. However, since commit `9804adf` (circa December 2017), BLIS has implemented self-initialization. These explicit calls to `bli_init()` and `bli_finalize()` are no longer necessary, though experts may still use them in special cases to control the allocation and freeing of resources. This topic is discussed in the BLIS [typed API reference](BLISTypedAPI.md#initialization-and-cleanup). ### Does BLIS support multithreading? Yes! BLIS supports multithreading (via OpenMP or POSIX threads) for all of its level-3 operations. For more information on enabling and controlling multithreading, please see the [Multithreading](Multithreading.md) guide. -BLIS is also thread-safe so that you can call BLIS from threads within a multithreaded library or application. BLIS derives is thread-safety via unconditional use of features present in POSIX threads (pthreads). These pthreads features are employed for thread-safety regardless of whether BLIS is configured for OpenMP multithreading, pthreads multithreading, or single-threaded execution. +BLIS is also thread-safe so that you can call BLIS from threads within a multithreaded library or application. BLIS derives its thread-safety via unconditional use of features present in POSIX threads (pthreads). These pthreads features are employed for thread-safety regardless of whether BLIS is configured for OpenMP multithreading, pthreads multithreading, or single-threaded execution. ### Does BLIS support NUMA environments? @@ -144,7 +181,7 @@ We have integrated some early foundational support for NUMA *development*, but c ### Does BLIS work with GPUs? -BLIS does not currently support graphical processing units (GPUs). +BLIS does not currently support graphical processing units (GPUs). However, others have applied the BLIS approach towards frameworks that provide BLAS-like functionality on GPUs. To see how NVIDIA's implementation compares to an analogous approach based on the principles that underlie BLIS, please see a paper by some of our collaborators, ["Implementing Strassen’s Algorithm with CUTLASS on NVIDIA Volta GPUs"](https://apps.cs.utexas.edu/apps/sites/default/files/tech_reports/GPUStrassen.pdf). ### Does BLIS work on _(some architecture)_? @@ -154,13 +191,30 @@ Please see the BLIS [Hardware Support](HardwareSupport.md) guide for a full list No. BLIS is a framework for sequential and shared-memory/multicore implementations of BLAS-like operations. If you need distributed-memory dense linear algebra implementations, we recommend the [Elemental](http://libelemental.org/) library. -### Can I build BLIS on Windows / Mac OS X? +### Can I build BLIS on Mac OS X? + +BLIS was designed for use in a GNU/Linux environment. However, we've gone to great lengths to keep BLIS compatible with other UNIX-like systems as well, such as BSD and OS X. System software requirements for UNIX-like systems are discussed in the BLIS [Build System](BuildSystem.md) guide. + +### Can I build BLIS on Windows? -BLIS was designed for use in a GNU/Linux environment. However, we've gone to greath lengths to keep BLIS compatible with other UNIX-like systems as well, such as BSD and OS X. System software requirements for UNIX-like systems are discussed in the BLIS [Build System](BuildSystem.md) guide. +If all you need is a Windows DLL of BLIS, you may be in luck! BLIS uses [AppVeyor](https://ci.appveyor.com/) to automatically produces dynamically-linked libraries, which are preserved on the site as "artifacts". To try it out, just visit the [BLIS AppVeyor page](https://ci.appveyor.com/project/shpc/blis/), click on the `LIB_TYPE=shared` link for the most recent build, and then click on "Artifacts". If you would like to provide us feedback, you may do so by [opening an issue](http://github.com/flame/blis/issues), or you can join the [blis-devel](http://groups.google.com/group/blis-devel) mailing list and send us a message. -Support for building in Windows is not directly supported. However, Windows 10 now provides a Linux-like environment. We suspect this is the best route for those trying to build BLIS in Windows. +If you want to build on Windows, there are two options: -If all you need is a Windows DLL of BLIS, you may be in luck! BLIS uses [AppVeyor](https://ci.appveyor.com/) to automatically produces dynamically-linked libraries, which are preserved on the site as "artifacts". To try it out, just visit the [BLIS AppVeyor page](https://ci.appveyor.com/project/shpc/blis/), click on the `LIB_TYPE=shared` link for the most recent build, and then click on "Artifacts". And if you'd like to share your experiences, please join the [blis-devel](http://groups.google.com/group/blis-devel) mailing list and send us a message! +1. MSVC ABI compatible DLL with clang + + If you want BLIS to be compatible with DLLs built by MSVC, you need to use `clang.exe` to build BLIS as BLIS does not support building with Visual Studio C compiler (``cl.exe``). To build BLIS, you need a recent clang from [LLVM](https://releases.llvm.org/download.html), an [MSYS2](https://www.msys2.org/) environment (for build tools like `sed`, `bash`), a Visual Studio 2015 or later environment (for C standard library) and Windows SDK. + To build `BLIS`, + * Activate the Visual Studio environment from a command prompt + Run `call C:\Program Files (x86)\Microsoft Visual Studio\2019\Professional\VC\Auxiliary\Build\vcvarsall.bat x64` + * Start the bash shell from the same command prompt. (Run `bash.exe`) + * Run `export AR=llvm-ar AS=llvm-as RANLIB=echo CC=clang CXX=clang++` + * Run `./configure --prefix=/c/blis/ --disable-static --enable-shared auto` + * Run `make -j install` + +2. MinGW DLL + + This is the easiest option to compile BLIS on windows, but the DLL might not be compatible with other programs compiled with MSVC. To build `BLIS`, install [MSYS2](https://www.msys2.org) and `mingw-w64` compilers. Then start a `bash` shell from MSYS2 and follow the instructions for the Linux build. ### Can I build BLIS as a shared library? @@ -168,7 +222,7 @@ Yes. By default, most configurations output only a static library archive (e.g. ### Can I use the mixed domain / mixed precision support in BLIS? -Yes! As of 5fec95b (circa October 2018), BLIS supports mixed-datatype (mixed domain and/or mixed precision) computation via the `gemm` operation. Documentation on utilizing this new functionality is provided via the [MixedDatatype.md](docs/MixedDatatypes.md) document in the source distribution. +Yes! As of 5fec95b (circa October 2018), BLIS supports mixed-datatype (mixed domain and/or mixed precision) computation via the `gemm` operation. Documentation on utilizing this new functionality is provided via the [MixedDatatype.md](MixedDatatypes.md) document in the source distribution. If this feature is important or useful to your work, we would love to hear from you. Please contact us via the [blis-devel](http://groups.google.com/group/blis-devel) mailing list and tell us about your application and why you need/want support for BLAS-like operations with mixed-domain/mixed-precision operands. @@ -179,21 +233,27 @@ Lots of people! For a full list of those involved, see the ### Who funded the development of BLIS? -BLIS was primarily funded by grants from [Microsoft](https://www.microsoft.com/), -[Intel](https://www.intel.com/), [Texas -Instruments](https://www.ti.com/), [AMD](https://www.amd.com/), [Huawei](https://www.hauwei.com/us/), and [Oracle](https://www.oracle.com/) as well as grants from the [National Science Foundation](http://www.nsf.gov/) (Awards CCF-0917167 ACI-1148125/1340293, and CCF-1320112). +BLIS was primarily funded by a variety of gifts/grants from industry and the National Science Foundation. Please see the "Funding" section of the [BLIS homepage](https://github.com/flame/blis#funding) for more details. Reminder: _Any opinions, findings and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation (NSF)._ ### I found a bug. How do I report it? -If you think you've found a bug, we request that you [open an issue](http://github.com/flame/blis/issues). Don't be shy! Really, it's the best and most convenient way for us to track your issues/bugs/concerns. Other discussions that are not primarily bug-reports should take place via the [blis-devel](http://groups.google.com/group/blis-devel) mailing list. +If you think you've found a bug, we request that you [open an issue](http://github.com/flame/blis/issues). Don't be shy! Really, it's the best and most convenient way for us to track your issues/bugs/concerns. ### How do I request a new feature? Feature requests should also be submitted by [opening a new issue](http://github.com/flame/blis/issues). +### I'm a developer and I'd like to study the way matrix multiplication is implemented in BLIS. Where should I start? + +Great question! The first thing you should know is that the core framework of [level-3 operations](BLISTypedAPI.md#operation-index) was *not* designed to be used to teach or explain a high-performance implementation of matrix multiplication. Rather, it was designed to encode the family of level-3 operations with as little code duplication as possible. Because of this, and also for historical/evolutionary reasons, it can be a little difficult to trace the execution of, say, `gemm` from within the core framework. + +Thankfully, we have an alternative environment in which experts, application developers, and other curious individuals can study BLIS's matrix multiplication implementation. This so-called "sandbox" is a simplified collection of code that strips away much of the framework complexity while also maintaining local definitions for many of the interesting bits. You may find this `gemmlike` sandbox in `sandbox/gemmlike`. + +Sandboxes go beyond the scope of this FAQ. For an introduction, please refer to the [Sandboxes](Sandboxes.md) document, and/or contact the BLIS developers for more information. + ### Where did you get the photo for the BLIS logo / mascot? -The sleeping ["BLIS cat"](https://github.com/flame/blis/blob/master/README.md) photo was taken by Petar Mitchev and is used with his permission. +The sleeping ["BLIS cat"](README.md) photo was taken by Petar Mitchev and is used with his permission. diff --git a/docs/HardwareSupport.md b/docs/HardwareSupport.md index 41036d51c8..944cfa8ee1 100644 --- a/docs/HardwareSupport.md +++ b/docs/HardwareSupport.md @@ -12,9 +12,10 @@ The following table lists architectures for which there exist optimized level-3 A few remarks / reminders: * Optimizing only the [gemm microkernel](KernelsHowTo.md#gemm-microkernel) will result in optimal performance for all [level-3 operations](BLISTypedAPI#level-3-operations) except `trsm` (which will typically achieve 60 - 80% of attainable peak performance). * The [trsm](BLISTypedAPI#trsm) operation needs the [gemmtrsm microkernel(s)](KernelsHowTo.md#gemmtrsm-microkernels), in addition to the aforementioned [gemm microkernel](KernelsHowTo.md#gemm-microkernel), in order reach optimal performance. - * Induced complex (1m) implementations are employed in all situations where the real domain [gemm microkernel](KernelsHowTo.md#gemm-microkernel) of the corresponding precision is available. Please see our [ACM TOMS article on the 1m method](https://github.com/flame/blis#citations) for more info on this topic. - * Some microarchitectures use the same sub-configuration. This is not a typo. For example, Haswell and Broadwell systems as well as "desktop" (non-server) versions of Skylake, Kabylake, and Coffeelake all use the `haswell` sub-configuration and the kernels registered therein. + * Induced complex (1m) implementations are employed in all situations where the real domain [gemm microkernel](KernelsHowTo.md#gemm-microkernel) of the corresponding precision is available, but the "native" complex domain gemm microkernel is unavailable. Note that the table below lists native kernels, so if a microarchitecture lists only `sd`, support for both `c` and `z` datatypes will be provided via the 1m method. (Note: most people cannot tell the difference between native and 1m-based performance.) Please see our [ACM TOMS article on the 1m method](https://github.com/flame/blis#citations) for more info on this topic. + * Some microarchitectures use the same sub-configuration. *This is not a typo.* For example, Haswell and Broadwell systems as well as "desktop" (non-server) versions of Skylake, Kaby Lake, and Coffee Lake all use the `haswell` sub-configuration and the kernels registered therein. Microkernels can be recycled in this manner because the key detail that determines level-3 performance outcomes is actually the vector ISA, not the microarchitecture. In the previous example, all of the microarchitectures listed support AVX2 (but not AVX-512), and therefore they can reuse the same microkernels. * Remember that you (usually) don't have to choose your sub-configuration manually! Instead, you can always request configure-time hardware detection via `./configure auto`. This will defer to internal logic (based on CPUID for x86_64 systems) that will attempt to choose the appropriate sub-configuration automatically. + * There is a difficulty in automatically choosing the ideal sub-configuration for use on Skylake-X systems, which may have one or two FMA units. The `skx` sub-configuration is only beneficial when used on hardware with two FMA units. Otherwise the hardware is treated as a "desktop" Skylake system, which uses the `haswell` sub-configuration. Furthermore, the number of units can't be queried directly; instead, we rely on a manually-maintained list of CPU models (via logic in `frame/base/bli_cpuid.c`), which may be incorrect for new processors, particularly Gold models. In that case, you can either fix the code (and please raise an issue!) or manually target the `skx` at configure-time (i.e., `./configure [options] skx`). If your performance seems low, you can set `export BLIS_ARCH_DEBUG=1`, which will cause BLIS to output some basic debugging info to `stderr` that will reveal whether your system was detected as having one or two VPUs (FMA units). | Vendor/Microarchitecture | BLIS sub-configuration | `gemm` | `gemmtrsm` | |:-------------------------------------|:-----------------------|:-------|:-----------| @@ -23,16 +24,19 @@ A few remarks / reminders: | AMD Steamroller (AVX/FMA3) | `steamroller` | `sdcz` | | | AMD Excavator (AVX/FMA3) | `excavator` | `sdcz` | | | AMD Zen (AVX/FMA3) | `zen` | `sdcz` | `sd` | -| Intel Core2 (SSE3) | `penryn` | `sd` | `d` | +| Intel Core2 (SSE3) | `penryn` | `sd` | `d` | | Intel Sandy/Ivy Bridge (AVX/FMA3) | `sandybridge` | `sdcz` | | | Intel Haswell, Broadwell (AVX/FMA3) | `haswell` | `sdcz` | `sd` | -| Intel Sky/Kaby/Coffeelake (AVX/FMA3) | `haswell` | `sdcz` | `sd` | +| Intel Sky/Kaby/CoffeeLake (AVX/FMA3) | `haswell` | `sdcz` | `sd` | | Intel Knights Landing (AVX-512/FMA3) | `knl` | `sd` | | -| Intel SkylakeX (AVX-512/FMA3) | `skx` | `sd` | | +| Intel SkylakeX (AVX-512/2×FMA3) | `skx` | `sd` | | +| Intel SkylakeX (AVX-512/1×FMA3) | `haswell` | `sdcz` | `sd` | | ARMv7 Cortex-A9 (NEON) | `cortex-a9` | `sd` | | | ARMv7 Cortex-A15 (NEON) | `cortex-a15` | `sd` | | | ARMv8 Cortex-A53 (NEON) | `cortex-a53` | `sd` | | | ARMv8 Cortex-A57 (NEON) | `cortex-a57` | `sd` | | +| ARMv8.1 ThunderX2 (NEON) | `thunderx2` | `sd` | | +| ARMv8.1 A64FX (SVE) | `a64fx` | `d` | | | IBM Blue Gene/Q (QPX int) | `bgq` | `d` | | | IBM Power7 (QPX int) | `power7` | `d` | | | template (C99) | `template` | `sdcz` | `sdcz` | diff --git a/docs/KernelsHowTo.md b/docs/KernelsHowTo.md index e3dfd125f1..6e84db8e76 100644 --- a/docs/KernelsHowTo.md +++ b/docs/KernelsHowTo.md @@ -113,7 +113,7 @@ Note that all kernels, whether they be reference implementations or based on ful The first step is to obtain a valid context. Contexts store all of the information specific to a particular sub-configuration (usually loosely specific to a -microarchitecture or group of closely-related microarchitectuers). If a context is +microarchitecture or group of closely-related microarchitectures). If a context is not already available in your current scope, a default context for the hardware for which BLIS was configured (or, in the case of multi-configuration builds, the hardware on which BLIS is currently running) may be queried via: @@ -229,7 +229,7 @@ This section seeks to provide developers with a complete reference for each of t The function prototypes in this section follow the same guidelines as those listed in the [BLIS typed API reference](BLISTypedAPI.md#Notes_for_using_this_reference). Namely: * Any occurrence of `?` should be replaced with `s`, `d`, `c`, or `z` to form an actual function name. - * Any occurrence of `ctype` should be replaced with the actual C type corresponding to the datatype instance in question. + * Any occurrence of `ctype` should be replaced with the actual C99 language type corresponding to the datatype instance in question. * Some matrix arguments have associated row and column strides arguments that proceed them, typically listed as `rsX` and `csX` for a given matrix `X`. Row strides are always listed first, and column strides are always listed second. The semantic meaning of a row stride is "the distance, in units of elements, from any given element to the corresponding element (within the same column) of the next row," and the meaning of a column stride is "the distance, in units of elements, from any given element to the corresponding element (within the same row) of the next column." Thus, unit row stride implies column-major storage and unit column stride implies row-major storage. * All occurrences of `alpha` and `beta` parameters are scalars. @@ -248,6 +248,8 @@ This section describes in detail the various level-3 microkernels supported by B ```c void bli_?gemm_ ( + dim_t m, + dim_t n, dim_t k, ctype* restrict alpha, ctype* restrict a1, @@ -264,6 +266,8 @@ where `` is implementation-dependent. (Recall that the precise ` ```c void bli_?gemm_ukernel ( + dim_t m, + dim_t n, dim_t k, ctype* restrict alpha, ctype* restrict a1, @@ -274,23 +278,28 @@ void bli_?gemm_ukernel cntx_t* restrict cntx ); ``` +This function simply queries a microkernel function pointer from the context specified by `cntx`. Note that in the case of either method of calling the microkernel, `cntx` must be a valid pointer. (Passing in `NULL` will *not* result in a default context being used.) The `gemm` microkernel, sometimes simply referred to as "the BLIS microkernel" or "the microkernel", performs the following operation: ``` - C11 := beta * C11 + A1 * B1 + C11 := beta * C11 + alpha * A1 * B1 ``` -where `A1` is an _MR x k_ "micropanel" matrix stored in packed (column-wise) format, `B1` is a _k x NR_ "micropanel" matrix stored in packed (row-wise) format, `C11` is an _MR x NR_ general matrix stored according to its row and column strides `rsc` and `csc`, and `alpha` and beta are scalars. +where `A1` is an _m x k_ "micropanel" matrix stored in packed (column-wise) format, `B1` is a _k x n_ "micropanel" matrix stored in packed (row-wise) format, `C11` is an _m x n_ "microtile" matrix stored according to its row and column strides `rsc` and `csc`, and `alpha` and beta are scalars. -_MR_ and _NR_ are the register blocksizes associated with the microkernel. They are chosen by the developer when the microkernel is written and then encoded into a BLIS configuration, which will reference the microkernel when the BLIS framework is instantiated into a library. For more information on setting register blocksizes and related constants, please see the [BLIS developer configuration guide](ConfigurationHowTo.md). +Here, _m <= MR_ and _n <= NR_, where _MR_ and _NR_ are the register blocksizes associated with the microkernel. They are chosen by the developer when the microkernel is written and then encoded into a BLIS configuration, which will reference the microkernel when the BLIS framework is instantiated into a library. For more information on setting register blocksizes and related constants, please see the [BLIS developer configuration guide](ConfigurationHowTo.md). + +**Note:** For many years, BLIS defined its microkernel to operate on microtiles whose dimensions were *exactly* _MR x NR_. However, as of commit 54fa28b, we have augmented the `gemm` microkernel API to pass in _m_ and _n_ dimensions as well as _k_. This change was made as part of our decision to move edge-case handling into the microkernel, whereas previously it was handled outside of the microkernel, within the portable parts of BLIS framework. And while this does mean additional complexity for microkernel authors, adding generic edge-case handling can be done in a relatively painless manner by employing some pre-defined preprocessor macros (which are defined in `bli_edge_case_macro_defs.h`). For examples of how to use these macros, please see the beginning and end of existing microkernel functions residing within the `kernels` directory. Parameters: + * `m`: The number of rows of `C11` and `A1`. + * `n`: The number of columns of `C11` and `B1`. * `k`: The number of columns of `A1` and rows of `B1`. * `alpha`: The address of a scalar to the `A1 * B1` product. - * `a1`: The address of a micropanel of matrix `A` of dimension _MR x k_, stored by columns with leading dimension _PACKMR_, where typically _PACKMR_ = _MR_. (See [Implementation Notes for gemm](KernelsHowTo.md#implementation-notes-for-gemm) for a discussion of _PACKMR_.) - * `b1`: The address of a micropanel of matrix `B` of dimension _k x NR_, stored by rows with leading dimension _PACKNR_, where typically _PACKNR_ = _NR_. (See [Implementation Notes for gemm](KernelsHowTo.md#implementation-notes-for-gemm) for a discussion of _PACKNR_.) + * `a1`: The address of a micropanel of matrix `A` of dimension _m x k_ (where _m <= MR_), stored by columns with leading dimension _PACKMR_, where typically _PACKMR_ = _MR_. (See [Implementation Notes for gemm](KernelsHowTo.md#implementation-notes-for-gemm) for a discussion of _PACKMR_.) + * `b1`: The address of a micropanel of matrix `B` of dimension _k x n_ (where _n <= NR_), stored by rows with leading dimension _PACKNR_, where typically _PACKNR_ = _NR_. (See [Implementation Notes for gemm](KernelsHowTo.md#implementation-notes-for-gemm) for a discussion of _PACKNR_.) * `beta`: The address of a scalar to the input value of matrix `C11`. * `c11`: The address of a matrix `C11` of dimension _MR x NR_, stored according to `rsc` and `csc`. * `rsc`: The row stride of matrix `C11` (ie: the distance to the next row, in units of matrix elements). @@ -321,24 +330,24 @@ The diagram below shows the packed micropanel operands and how elements of each #### Implementation Notes for gemm - * **Register blocksizes.** The register blocksizes `MR` and `NR`, corresponding to the number of *logical* rows in `a1` and columns in `b1`, respectively, are defined in the context and may be queried via `bli_cntx_get_blksz_def_dt()`. However, you shouldn't need to query these values since the implementation inherently "knows" them already. - * **Leading dimensions of `a1` and `b1`: _PACKMR_ and _PACKNR_.** The packed micropanels `a1` and `b1` are simply stored in column-major and row-major order, respectively. Usually, the width of either micropanel (ie: the number of logical rows of `a1`, or _MR_, and the number of columns of `b1`, or _NR_) is equal to that micropanel's so-called "leading dimension", or number of *physical* rows. Sometimes, it may be beneficial to specify a leading dimension that is larger than the panel width. This may be desirable because it allows each column of `a1` or row of `b1` to maintain a certain alignment in memory that would not otherwise be maintained by _MR_ and/or _NR_. In this case, you should index through `a1` and `b1` using the values _PACKMR_ and _PACKNR_, respectively (which are stored in the context as the blocksize "maximums" associated with the `bszid_t` values `BLIS_MR` and `BLIS_NR`). These values are defined in the context and may be queried via `bli_cntx_get_blksz_max_dt()`. However, you shouldn't need to query these values since the implementation inherently "knows" them already. - * **Storage preference of `c11`.** Usually, an optimized `gemm` microkernel will have a "preferred" storage format for `C11`--typically either contiguous row-storage (i.e. `cs_c` = 1) or contiguous column-storage (i.e. `rs_c` = 1). This preference comes from how the microkernel is most efficiently able to load/store elements of `C11` from/to memory. Most microkernels use vector instructions to access contiguous columns (or column segments) of `C11`. However, the developer may decide that accessing contiguous rows (or row segments) is more desirable. If this is the case, this preference should be indicated via the `bool_t` argument when registering microkernels via `bli_cntx_set_l3_nat_ukrs()`--`TRUE` indicating a row preference and `FALSE` indicating a column preference. Properly setting this property allows the framework to perform a runtime optimization that will ensure the microkernel preference is honored, if at all possible. - * **Edge cases in _MR_, _NR_ dimensions.** Sometimes the microkernel will be called with micropanels `a1` and `b1` that correspond to edge cases, where only partial results are needed. Zero-padding is handled automatically by the packing function to facilitate reuse of the same microkernel. Similarly, the logic for computing to temporary storage and then saving only the elements that correspond to elements of `C11` that exist (at the edges) is handled automatically within the macrokernel. - * **Alignment of `a1` and `b1`.** By default, the alignment of addresses `a1` and `b1` are aligned only to `sizeof(type)`. If `BLIS_POOL_ADDR_ALIGN_SIZE` is set to some larger multiple of `sizeof(type)`, such as the page size, then the *first* `a1` and `b1` micropanels will be aligned to that value, but subsequent micropanels will only be aligned to `sizeof(type)`, or, if `BLIS_POOL_ADDR_ALIGN_SIZE` is a multiple of `PACKMR` and `PACKNR`, then subsequent micropanels `a1` and `b1` will be aligned to `PACKMR * sizeof(type)` and `PACKNR * sizeof(type)`, respectively. - * **Unrolling loops.** As a general rule of thumb, the loop over _k_ is sometimes moderately unrolled; for example, in our experience, an unrolling factor of _u_ = 4 is fairly common. If unrolling is applied in the _k_ dimension, edge cases must be handled to support values of _k_ that are not multiples of _u_. It is nearly universally true that there should be no loops in the _MR_ or _NR_ directions; in other words, iteration over these dimensions should always be fully unrolled (within the loop over _k_). + * **Register blocksizes.** The register blocksizes `MR` and `NR`, corresponding to the maximum number of *logical* rows in `a1` and columns in `b1`, respectively, are defined in the context and may be queried via `bli_cntx_get_blksz_def_dt()`. However, you shouldn't need to query these values since the implementation inherently "knows" them already. + * **Leading dimensions of `a1` and `b1`: _PACKMR_ and _PACKNR_.** The packed micropanels `a1` and `b1` are simply stored in column-major and row-major order, respectively. Usually, the width of either micropanel (ie: the number of *logical* rows of `a1` and the number of columns of `b1`) is equal to that micropanel's so-called "leading dimension", or number of *physical* rows. Sometimes, it may be beneficial to specify a leading dimension that is larger than the panel width. This may be desirable because it allows each column of `a1` or row of `b1` to maintain a certain alignment in memory that would not otherwise be maintained by _MR_ and/or _NR_, which would othewise serve as the maximum value for each micropanel, respectively. If you want your microkernel to support _MR < PACKMR_ or _NR < PACKNR_, you should index through columns of `a1` and rows of `b1` using the values _PACKMR_ and _PACKNR_, respectively (which are stored in the context as the blocksize "maximums" associated with the `bszid_t` values `BLIS_MR` and `BLIS_NR`). These values are defined in the context and may be queried via `bli_cntx_get_blksz_max_dt()`. However, you shouldn't need to query these values since the microkernel implementation inherently must "know" them already. + * **Storage preference of `c11`.** Usually, an optimized `gemm` microkernel will have a "preferred" storage format for `C11`--typically either contiguous row-storage (i.e. `cs_c` = 1) or contiguous column-storage (i.e. `rs_c` = 1). This preference comes from how the microkernel is most efficiently able to load/store elements of `C11` from/to memory. Most microkernels use vector instructions to access contiguous columns (or column segments) of `C11`. However, the developer may decide that accessing contiguous rows (or row segments) is more desirable. If this is the case, this preference should be indicated via the `bool` argument when registering microkernels via `bli_cntx_set_l3_nat_ukrs()`--`TRUE` indicating a row preference and `FALSE` indicating a column preference. Properly setting this property allows the framework to perform a runtime optimization that will ensure the microkernel preference is honored, if at all possible. + * **Edge cases in _MR_, _NR_ dimensions.** Sometimes the microkernel will be called with micropanels `a1` and `b1` that correspond to edge cases, where only partial results are needed. This edge-case handling was once performed by the framework automatically. However, as of commit 54fa28b, edge-case handling is the responsiblity of the microkernel. This means that the kernel author will need to handle all possible values of _m_ and _n_ that are equal to **or** less than _MR_ and _NR_, respectively. Fortunately, this can be implemented outside of the assembly region of the microkernel with preprocessor macros. Please reference the existing microkernels in the `kernels` directory for examples of how this is done. (The macros that are now employed by most of BLIS's microkernels are defined in `bli_edge_case_macro_defs.h`.) + * **Alignment of `a1` and `b1`.** By default, the alignment of addresses `a1` and `b1` are aligned to the page size (4096 bytes). These alignment factors are set by `BLIS_POOL_ADDR_ALIGN_SIZE_A` and `BLIS_POOL_ADDR_ALIGN_SIZE_B`, respectively. Note that these alignment factors control only the alignment of the *first* micropanel within a given packed blockof matrix `A` or packed row-panel of matrix `B`. Subsequent micropanels will only be aligned to `sizeof(type)`, or, if `BLIS_POOL_ADDR_ALIGN_SIZE_A` is a multiple of `PACKMR` and/or `BLIS_POOL_ADDR_ALIGN_SIZE_B` is a multiple of `PACKNR`, then subsequent micropanels `a1` and/or `b1` will be aligned to `PACKMR * sizeof(type)` and/or `PACKNR * sizeof(type)`, respectively. + * **Unrolling loops.** As a general rule of thumb, the loop over _k_ is sometimes moderately unrolled; for example, in our experience, an unrolling factor of _u_ = 4 is fairly common. If unrolling is applied in the _k_ dimension, edge cases must be handled to support values of _k_ that are not multiples of _u_. It is nearly universally true that the microkernel should not contain loops in the _m_ or _n_ directions; in other words, iteration over these dimensions should always be fully unrolled (within the loop over _k_). * **Zero `beta`.** If `beta` = 0.0 (or 0.0 + 0.0i for complex datatypes), then the microkernel should NOT use it explicitly, as `C11` may contain uninitialized memory (including elements containing `NaN` or `Inf`). This case should be detected and handled separately by overwriting `C11` with the `alpha * A1 * B1` product. #### Using the auxinfo\_t object -Each microkernel ([gemm](KernelsHowTo.md#gemm-microkernel), [trsm](KernelsHowTo.md#trsm_microkernels), and [gemmtrsm](KernelsHowTo.md#gemmtrsm-microkernels)) takes as its last argument a pointer of type `auxinfo_t`. This BLIS-defined type is defined as a `struct` whose fields contain auxiliary values that may be useful to some microkernel authors, particularly when implementing certain optimization techniques. BLIS provides kernel authors access to the fields of the `auxinfo_t` object via the following function-like preprocessor macros. Each macro takes a single argument, the `auxinfo_t` pointer, and returns one of the values stored within the object. +Each microkernel ([gemm](KernelsHowTo.md#gemm-microkernel), [trsm](KernelsHowTo.md#trsm_microkernels), and [gemmtrsm](KernelsHowTo.md#gemmtrsm-microkernels)) takes as its last argument a pointer of type `auxinfo_t`. This BLIS-defined type is defined as a `struct` whose fields contain auxiliary values that may be useful to some microkernel authors, particularly when implementing certain optimization techniques. BLIS provides kernel authors access to the fields of the `auxinfo_t` object via the following static inline functions. Each function takes a single argument, the `auxinfo_t` pointer, and returns one of the values stored within the object. * `bli_auxinfo_next_a()`. Returns the address (`void*`) of the micropanel of `A` that will be used the next time the microkernel will be called. * `bli_auxinfo_next_b()`. Returns the address (`void*`) of the micropanel of `B` that will be used the next time the microkernel will be called. * `bli_auxinfo_ps_a()`. Returns the panel stride (`inc_t`) of the current micropanel of `A`. * `bli_auxinfo_ps_b()`. Returns the panel stride (`inc_t`) of the current micropanel of `B`. -The addresses of the next micropanels of `A` and `B` may be used by the microkernel to perform prefetching, if prefetching is supported by the architecture. Similarly, it may be useful to know the precise distance in memory to the next micropanel. (Note that sometimes the next micropanel to be used is **not** the same as the next micropanel in memory.) +The addresses of the next micropanels of `A` and `B` may be used by the microkernel to perform prefetching, if prefetching is supported by the architecture. Similarly, it may be useful to know the precise distance in memory to the next micropanel. (Note that occasionally the next micropanel to be used is **not** the same as the next micropanel in memory.) Any and all of these values may be safely ignored; they are completely optional. However, BLIS guarantees that all values accessed via the macros listed above will **always** be initialized and meaningful, for every invocation of each microkernel (`gemm`, `trsm`, and `gemmtrsm`). @@ -348,8 +357,7 @@ Any and all of these values may be safely ignored; they are completely optional. An example implementation of the `gemm` microkernel may be found in the `template` configuration directory in: * [config/template/kernels/3/bli\_gemm_opt\_mxn.c](https://github.com/flame/blis/tree/master/config/template/kernels/3/bli_gemm_opt_mxn.c) - -Note that this implementation is coded in C99 and lacks several kinds of optimization that are typical of real-world optimized microkernels, such as vector instructions (or intrinsics) and loop unrolling in _MR_ or _NR_. It is meant to serve only as a starting point for a microkernel developer. +Note that this implementation is coded in C99 and lacks several kinds of optimization that are typical of real-world optimized microkernels, such as vector instructions (or intrinsics) and loop unrolling in the _m_ or _n_ dimensions. It is meant to serve only as a starting point for a microkernel developer. @@ -411,6 +419,8 @@ where `A11` is _MR x MR_ and lower (`trsm_l`) or upper (`trsm_u`) triangular, `B _MR_ and _NR_ are the register blocksizes associated with the microkernel. They are chosen by the developer when the microkernel is written and then encoded into a BLIS configuration, which will reference the microkernel when the BLIS framework is instantiated into a library. For more information on setting register blocksizes and related constants, please see the [BLIS developer configuration guide](ConfigurationHowTo.md). +**Note:** Although the `gemm` microkernel must handle edge-cases, and therefore must take _m_ and _n_ parameters, the `trsm` microkernels are simpler in that they still assume _m = MR_ and _n = NR_, and therefore do not need these _m_ and _n_ parameters passed in. + Parameters: * `a11`: The address of `A11`, which is the _MR x MR_ lower (`trsm_l`) or upper (`trsm_u`) triangular submatrix within the packed micropanel of matrix `A`. `A11` is stored by columns with leading dimension _PACKMR_, where typically _PACKMR_ = _MR_. (See [Implementation Notes for gemm](KernelsHowTo.md#implementation-notes-for-gemm) for a discussion of _PACKMR_.) Note that `A11` contains elements in both triangles, though elements in the unstored triangle are not guaranteed to be zero and thus should not be referenced. @@ -454,6 +464,8 @@ Note that these implementations are coded in C99 and lack several kinds of optim ```c void bli_?gemmtrsm_l_ ( + dim_t m, + dim_t n, dim_t k, ctype* restrict alpha, ctype* restrict a10, @@ -467,6 +479,8 @@ void bli_?gemmtrsm_l_ void bli_?gemmtrsm_u_ ( + dim_t m, + dim_t n, dim_t k, ctype* restrict alpha, ctype* restrict a12, @@ -484,6 +498,8 @@ where `` is implementation-dependent. (Recall that the precise ` ```c void bli_?gemmtrsm_l_ukernel ( + dim_t m, + dim_t n, dim_t k, ctype* restrict alpha, ctype* restrict a10, @@ -497,6 +513,8 @@ void bli_?gemmtrsm_l_ukernel void bli_?gemmtrsm_u_ukernel ( + dim_t m, + dim_t n, dim_t k, ctype* restrict alpha, ctype* restrict a12, @@ -517,7 +535,7 @@ The `gemmtrsm_l` microkernel performs the following compound operation: C11 := B11 ``` -where `A11` is _MR_ x _MR_ and lower triangular, `A10` is _MR_ x _k_, and `B01` is _k_ x _NR_. +where `A11` is _MR x MR_ and lower triangular, `A10` is _MR x k_, and `B01` is _k x NR_. The `gemmtrsm_u` microkernel performs: ``` @@ -526,20 +544,22 @@ The `gemmtrsm_u` microkernel performs: C11 := B11 ``` -where `A11` is _MR_ x _MR_ and upper triangular, `A12` is _MR_ x _k_, and `B21` is _k_ x _NR_. -In both cases, `B11` is _MR_ x _NR_ and `alpha` is a scalar. Here, `inv()` denotes matrix inverse. +where `A11` is _MR x MR_ and upper triangular, `A12` is _MR x k_, and `B21` is _k x NR_. +In both cases, `B11` is _MR x NR_ and `alpha` is a scalar. However, `C11` is _m x n_, and therefore the `C11 := B11` statements amount to a copy of only the top-leftmost _m x n_ elements of `B11`. (Recall that A11 and B11 are packed and therefore guaranteed to reside within fully-sized micropanels, whereas `C11` exists in the caller-provided output matrix and may represent a bottom-right edge case.) Here, `inv()` denotes matrix inverse. _MR_ and _NR_ are the register blocksizes associated with the microkernel. They are chosen by the developer when the microkernel is written and then encoded into a BLIS configuration, which will reference the microkernel when the BLIS framework is instantiated into a library. For more information on setting register blocksizes and related constants, please see the [BLIS developer configuration guide](ConfigurationHowTo.md). Parameters: + * `m`: The number of rows of `C11`. + * `n`: The number of columns of `C11`. * `k`: The number of columns of `A10` and rows of `B01` (`trsm_l`); the number of columns of `A12` and rows of `B21` (`trsm_u`). * `alpha`: The address of a scalar to be applied to `B11`. * `a10`, `a12`: The address of `A10` or `A12`, which is the _MR x k_ submatrix of the packed micropanel of `A` that is situated to the left (`trsm_l`) or right (`trsm_u`) of the _MR x MR_ triangular submatrix `A11`. `A10` and `A12` are stored by columns with leading dimension _PACKMR_, where typically _PACKMR_ = _MR_. (See [Implementation Notes for gemm](KernelsHowTo.md#implementation-notes-for-gemm) for a discussion of _PACKMR_.) * `a11`: The address of `A11`, which is the _MR x MR_ lower (`trsm_l`) or upper (`trsm_u`) triangular submatrix within the packed micropanel of matrix `A` that is situated to the right of `A10` (`trsm_l`) or the left of `A12` (`trsm_u`). `A11` is stored by columns with leading dimension _PACKMR_, where typically _PACKMR_ = _MR_. (See [Implementation Notes for gemm](KernelsHowTo.md#implementation-notes-for-gemm) for a discussion of _PACKMR_.) Note that `A11` contains elements in both triangles, though elements in the unstored triangle are not guaranteed to be zero and thus should not be referenced. * `b01`, `b21`: The address of `B01` and `B21`, which is the _k x NR_ submatrix of the packed micropanel of `B` that is situated above (`trsm_l`) or below (`trsm_u`) the _MR x NR_ block `B11`. `B01` and `B21` are stored by rows with leading dimension _PACKNR_, where typically _PACKNR_ = _NR_. (See [Implementation Notes for gemm](KernelsHowTo.md#implementation-notes-for-gemm) for a discussion of _PACKNR_.) * `b11`: The address of `B11`, which is the _MR x NR_ submatrix of the packed micropanel of `B`, situated below `B01` (`trsm_l`) or above `B21` (`trsm_u`). `B11` is stored by rows with leading dimension _PACKNR_, where typically _PACKNR_ = _NR_. (See [Implementation Notes for gemm](KernelsHowTo.md#implementation-notes-for-gemm) for a discussion of _PACKNR_.) - * `c11`: The address of `C11`, which is an _MR x NR_ submatrix of matrix `C`, stored according to `rsc` and `csc`. `C11` is the submatrix within `C` that corresponds to the elements which were packed into `B11`. Thus, `C` is the original input matrix `B` to the overall `trsm` operation. + * `c11`: The address of `C11`, which is an _m x n_ submatrix of matrix `C`, stored according to `rsc` and `csc`, where _m <= MR_ and _n <= NR_. `C11` is the submatrix within `C` that corresponds to the elements which were packed into `B11`. Thus, `C` is the original input matrix `B` to the overall `trsm` operation. * `rsc`: The row stride of matrix `C11` (ie: the distance to the next row, in units of matrix elements). * `csc`: The column stride of matrix `C11` (ie: the distance to the next column, in units of matrix elements). * `data`: The address of an `auxinfo_t` object that contains auxiliary information that may be useful when optimizing the `gemmtrsm` microkernel implementation. (See [Using the auxinfo\_t object](KernelsHowTo.md#Using_the_auxinfo_t_object) for a discussion of the kinds of values available via `auxinfo_t`, and also [Implementation Notes for gemmtrsm](KernelsHowTo.md#implementation-notes-for-gemmtrsm) for caveats.) @@ -690,7 +710,7 @@ This kernel performs the following operation: ``` y := y + alpha * conja(a) * conjy(x) ``` -where `a` is an _m_ x _b_ matrix, `x` is a vector of length _b_, and `y` is a vector of length _m_. Vectors `x` and `y` are stored with strides `incx` and `incy`, respectively. Matrix `a` is stored with row stride `inca` and column stride `lda`, though `inca` is most often (in practice) unit. This kernel is typically implemented as a fused series of _b_ `axpyv` operations updating the same vector `y` (with the elements of `x` serving as the scalars and the columns of `a` serving as the vectors to be scaled). +where `a` is an _m x b_ matrix, `x` is a vector of length _b_, and `y` is a vector of length _m_. Vectors `x` and `y` are stored with strides `incx` and `incy`, respectively. Matrix `a` is stored with row stride `inca` and column stride `lda`, though `inca` is most often (in practice) unit. This kernel is typically implemented as a fused series of _b_ `axpyv` operations updating the same vector `y` (with the elements of `x` serving as the scalars and the columns of `a` serving as the vectors to be scaled). --- @@ -714,7 +734,7 @@ This kernel performs the following operation: ``` y := beta * y + alpha * conjat(a)^T conjx(x) ``` -where `a` is an _m_ x _b_ matrix, where `w` is a vector of length _m_, `y` is a vector of length _b_, and `alpha` is a scalar. +where `a` is an _m x b_ matrix, where `w` is a vector of length _m_, `y` is a vector of length _b_, and `alpha` is a scalar. Vectors `x` and `y` are stored with strides `incx` and `incy`, respectively. Matrix `a` is stored with row stride `inca` and column stride `lda`, though `inca` is most often (in practice) unit. This kernel is typically implemented as a series of _b_ `dotxv` operations with the same right-hand operand vector `x` (contracted with the rows of `a^T` and accumulating to the corresponding elements of vector `y`). @@ -745,7 +765,7 @@ This kernel performs the following operation: y := beta * y + alpha * conjat(a)^T conjw(w) z := z + alpha * conja(a) conjx(x) ``` -where `a` is an _m_ x _b_ matrix, `w` and `z` are vectors of length _m_, `x` and `y` are vectors of length _b_, and `alpha` and `beta` are scalars. +where `a` is an _m x b_ matrix, `w` and `z` are vectors of length _m_, `x` and `y` are vectors of length _b_, and `alpha` and `beta` are scalars. Vectors `w`, `z`, `x` and `y` are stored with strides `incw`, `incz`, `incx`, and `incy`, respectively. Matrix `a` is stored with row stride `inca` and column stride `lda`, though `inca` is most often (in practice) unit. This kernel is typically implemented as a series of _b_ `dotxv` operations with the same right-hand operand vector `w` fused with a series of _b_ `axpyv` operations updating the same vector `z`. diff --git a/docs/MixedDatatypes.md b/docs/MixedDatatypes.md index ee109f5a10..7a67cfad8f 100644 --- a/docs/MixedDatatypes.md +++ b/docs/MixedDatatypes.md @@ -14,7 +14,12 @@ This document serves as a guide to users interested in taking advantage of BLIS's support for performing the `gemm` operation on operands of differing -datatypes (domain and/or precision). +datatypes (domain and/or precision). For further details on the implementation +present in BLIS, please see the latest draft of our paper +"Supporting Mixed-domain Mixed-precision Matrix Multiplication +within the BLIS Framework" available in the +[Citations section](https://github.com/flame/blis/#citations) +of the main [BLIS webpage](https://github.com/flame/blis). ## Categories of mixed datatypes diff --git a/docs/Multithreading.md b/docs/Multithreading.md index 7fff7357f3..48fbc8ca16 100644 --- a/docs/Multithreading.md +++ b/docs/Multithreading.md @@ -23,11 +23,17 @@ # Introduction -Our paper [Anatomy of High-Performance Many-Threaded Matrix Multiplication](https://github.com/flame/blis#citations), presented at IPDPS'14, identified 5 loops around the microkernel as opportunities for parallelization within level-3 operations such as `gemm`. Within BLIS, we have enabled parallelism for 4 of those loops and have extended it to the rest of the level-3 operations except for `trsm`. +Our paper [Anatomy of High-Performance Many-Threaded Matrix Multiplication](https://github.com/flame/blis#citations), presented at IPDPS'14, identified five loops around the microkernel as opportunities for parallelization within level-3 operations such as `gemm`. Within BLIS, we have enabled parallelism for four of those loops, with the fifth planned for future work. This software architecture extends naturally to all level-3 operations except for `trsm`, where its application is necessarily limited to three of the five loops due to inter-iteration dependencies. + +**IMPORTANT**: Multithreading in BLIS is disabled by default. Furthermore, even when multithreading is enabled, BLIS will default to single-threaded execution at runtime. In order to both *allow* and *invoke* parallelism from within BLIS operations, you must both *enable* multithreading at configure-time and *specify* multithreading at runtime. + +To summarize: In order to observe multithreaded parallelism within a BLIS operation, you must do *both* of the following: +1. Enable multithreading at configure-time. This is discussed in the [next section](docs/Multithreading.md#enabling-multithreading). +2. Specify multithreading at runtime. This is also discussed [later on](docs/Multithreading.md#specifying-multithreading). # Enabling multithreading -Note that BLIS disables multithreading by default. In order to extract multithreaded parallelism from BLIS, you must first enable multithreading explicitly at configure-time. +BLIS disables multithreading by default. In order to allow multithreaded parallelism from BLIS, you must first enable multithreading explicitly at configure-time. As of this writing, BLIS optionally supports multithreading via either OpenMP or POSIX threads. @@ -96,12 +102,19 @@ There are three broad methods of specifying multithreading in BLIS: * [Globally at runtime](Multithreading.md#globally-at-runtime) * [Locally at runtime](Multithreading.md#locally-at-runtime) (that is, on a per-call, thread-safe basis) -Within these three broad methods there are two specific ways of expressing a request for parallelism. First, the user may express a single number--the total number of threads, or ways of parallelism, to use within a single operation such as `gemm`. We call this the "automatic" way. Alternatively, the user may express the number of ways of parallelism to obtain within *each loop* of the level-3 operation. We call this the "manual" way. The latter way is actually what BLIS eventually needs before it can perform its multithreading; the former is viable only because we have a heuristic of determing a reasonable instance of the latter when given the former. +Within these three broad methods there are two specific ways of expressing a request for parallelism. First, the user may express a single number--the total number of threads, or ways of parallelism, to use within a single operation such as `gemm`. We call this the "automatic" way. Alternatively, the user may express the number of ways of parallelism to obtain within *each loop* of the level-3 operation. We call this the "manual" way. The latter way is actually what BLIS eventually needs before it can perform its multithreading; the former is viable only because we have a heuristic of determining a reasonable instance of the latter when given the former. This pattern--automatic or manual--holds regardless of which of the three methods is used. Regardless of which method is employed, and which specific way within each method, after setting the number of threads, the application may call the desired level-3 operation (via either the [typed API](docs/BLISTypedAPI.md) or the [object API](docs/BLISObjectAPI.md)) and the operation will execute in a multithreaded manner. (When calling BLIS via the BLAS API, only the first two (global) methods are available.) -NOTE: Please be aware of what happens if you try to specify both the automatic and manual ways, as it could otherwise confuse new users. Regardless of which broad method is used, **if multithreading is specified via both the automatic and manual ways, the manual way will always take precedence.** Also, specifying parallelism for even *one* loop counts as specifying the manual way (in which case the ways of parallelism for the remaining loops will be assumed to be 1). +**Note**: Please be aware of what happens if you try to specify both the automatic and manual ways, as it could otherwise confuse new users. Here are the important points: + * Regardless of which broad method is used, **if multithreading is specified via both the automatic and manual ways, the values set via the manual way will always take precedence.** + * Specifying parallelism for even *one* loop counts as specifying the manual way (in which case the ways of parallelism for the remaining loops will be assumed to be 1). And in the case of the environment variable method, setting the ways of parallelism for a loop to 1 counts as specifying parallelism! If you want to switch from using the manual way to automatic way, you must not only set (`export`) the `BLIS_NUM_THREADS` variable, but you must also `unset` all of the `BLIS_*_NT` variables. + * If you have specified multithreading via *both* the automatic and manual ways, BLIS will **not** complain if the values are inconsistent with one another. (For example, you may request 12 total threads be used while also specifying 2 and 4 ways of parallelism within the JC and IC loops, respectively, for a total of 8 ways.) Furthermore, you will be able to query these inconsistent values via the runtime API both before and after multithreading executes. + * If multithreading is disabled, you **may still** specify multithreading values via either the manual or automatic ways. However, BLIS will silently ignore **all** of these values. A BLIS library that is built with multithreading disabled at configure-time will always run sequentially (from the perspective of a single application thread). + +Furthermore: +* For small numbers of threads, the number requested will be honored faithfully. However, if you request a larger number of threads that happens to also be prime, BLIS will reduce the number by one in order to allow more more efficient thread factorizations. This behavior can be overridden by configuring BLIS with the `BLIS_ENABLE_AUTO_PRIME_NUM_THREADS` macro defined in the `bli_family_*.h` file of the relevant subconfiguration. Similarly, the threshold beyond which BLIS will reduce primes by one can be set via `BLIS_NT_MAX_PRIME`. (This latter value is ignored if the former macro is defined.) ## Globally via environment variables @@ -109,6 +122,8 @@ The most common method of specifying multithreading in BLIS is globally via envi Regardless of whether you end up using the automatic or manual way of expressing a request for multithreading, note that the environment variables are read (via `getenv()`) by BLIS **only once**, when the library is initialized. Subsequent to library initialization, the global settings for parallelization may only be changed via the [global runtime API](Multithreading.md#globally-at-runtime). If this constraint is not a problem, then environment variables may work fine for you. Otherwise, please consider [local settings](Multithreading.md#locally-at-runtime). (Local settings may used at any time, regardless of whether global settings were explicitly specified, and local settings always override global settings.) +**Note**: Regardless of which way ([automatic](Multithreading.md#environment-variables-the-automatic-way) or [manual](Multithreading.md#environment-variables-the-manual-way)) environment variables are used to specify multithreading, that specification will affect operation of BLIS through **both** the BLAS compatibility layer as well as the native ([typed](docs/BLISTypedAPI.md) and [object](docs/BLISObjectAPI.md)) APIs that are unique to BLIS. + ### Environment variables: the automatic way The automatic way of specifying parallelism entails simply setting the total number of threads you wish BLIS to employ in its parallelization. This total number of threads is captured by the `BLIS_NUM_THREADS` environment variable. You can set this variable prior to executing your BLIS-linked executable: @@ -119,7 +134,7 @@ $ ./my_blis_program ``` This causes BLIS to automatically determine a reasonable threading strategy based on what is known about the operation and problem size. If `BLIS_NUM_THREADS` is not set, BLIS will attempt to query the value of `OMP_NUM_THREADS`. If neither variable is set, the default number of threads is 1. -**Note:** We *highly* discourage use of the `OMP_NUM_THREADS` environment variable and may remove support for it in the future. If you wish to set parallelism globally via environment variables, please use `BLIS_NUM_THREADS`. +**Note**: We *highly* discourage use of the `OMP_NUM_THREADS` environment variable and may remove support for it in the future. If you wish to set parallelism globally via environment variables, please use `BLIS_NUM_THREADS`. ### Environment variables: the manual way @@ -127,13 +142,13 @@ The manual way of specifying parallelism involves communicating which loops with The below chart describes the five loops used in BLIS's matrix multiplication operations. -| Loop around microkernel | Environment variable | Direction | Notes | -|:-------------------------|:---------------------|:----------|:------------| -| 5th loop | `BLIS_JC_NT` | `n` | | -| 4th loop | _N/A_ | `k` | Not enabled | -| 3rd loop | `BLIS_IC_NT` | `m` | | -| 2nd loop | `BLIS_JR_NT` | `n` | | -| 1st loop | `BLIS_IR_NT` | `m` | | +| Loop around microkernel | Environment variable | Direction | Notes | +|:-------------------------|:---------------------|:----------|:---------------| +| 5th loop | `BLIS_JC_NT` | `n` | | +| 4th loop | _N/A_ | `k` | Not enabled | +| 3rd loop | `BLIS_IC_NT` | `m` | | +| 2nd loop | `BLIS_JR_NT` | `n` | Typically <= 4 | +| 1st loop | `BLIS_IR_NT` | `m` | Typically 1 | **Note**: Parallelization of the 4th loop is not currently enabled because each iteration of the loop updates the same part of the output matrix C. Thus, to safely parallelize it requires either a reduction or mutex locks when updating C. @@ -146,7 +161,7 @@ In general, the way to choose how to set these environment variables is as follo Next, which combinations of loops to parallelize depends on which caches are shared. Here are some of the more common scenarios: * When compute resources have private L3 caches (example: multi-socket systems), try parallelizing the `JC` loop. This means threads (or thread groups) will pack and compute with different row panels from matrix B. * For compute resources that have private L2 caches but that share an L3 cache (example: cores on a socket), try parallelizing the `IC` loop. In this situation, threads will share the same packed row panel from matrix B, but pack and compute with different blocks of matrix A. - * If compute resources share an L2 cache but have private L1 caches (example: pairs of cores), try parallelizing the `JR` loop. Here, threads share the same packed block of matrix A but read different packed micropanels of B into their private L1 caches. In some situations, parallelizing the `IR` loop may also be effective. + * If compute resources share an L2 cache but have private L1 caches (example: pairs of cores), try parallelizing the `JR` loop. Here, threads share the same packed block of matrix A but read different packed micropanels of B into their private L1 caches. In some situations, *lightly* parallelizing the `IR` loop may also be effective. ![The primary algorithm for level-3 operations in BLIS](http://www.cs.utexas.edu/users/field/mm_algorithm_color.png) @@ -154,6 +169,8 @@ Next, which combinations of loops to parallelize depends on which caches are sha If you still wish to set the parallelization scheme globally, but you want to do so at runtime, BLIS provides a thread-safe API for specifying multithreading. Think of these functions as a way to modify the same internal data structure into which the environment variables are read. (Recall that the environment variables are only read once, when BLIS is initialized). +**Note**: Regardless of which way ([automatic](Multithreading.md#globally-at-runtime-the-automatic-way) or [manual](Multithreading.md#globally-at-runtime-the-manual-way)) the global runtime API is used to specify multithreading, that specification will affect operation of BLIS through **both** the BLAS compatibility layer as well as the native ([typed](docs/BLISTypedAPI.md) and [object](docs/BLISObjectAPI.md)) APIs that are unique to BLIS. + ### Globally at runtime: the automatic way If you simply want to specify an overall number of threads and let BLIS choose a thread factorization automatically, use the following function: @@ -193,6 +210,8 @@ In addition to the global methods based on environment variables and runtime fun As with environment variables and the global runtime API, there are two ways to specify parallelism: the automatic way and the manual way. Both ways involve allocating a BLIS-specific object, initializing the object and encoding the desired parallelization, and then passing a pointer to the object into one of the expert interfaces of either the [typed](docs/BLISTypedAPI.md) or [object](docs/BLISObjectAPI) APIs. We provide examples of utilizing this threading object below. +**Note**: Neither way ([automatic](Multithreading.md#locally-at-runtime-the-automatic-way) nor [manual](Multithreading.md#locally-at-runtime-the-manual-way)) of specifying multithreading via the local runtime API can be used via the BLAS interfaces. The local runtime API may *only* be used via the native ([typed](docs/BLISTypedAPI.md) and [object](docs/BLISObjectAPI.md)) APIs, which are unique to BLIS. (Furthermore, the expert interfaces of each API must be used. This is demonstrated later on in this section.) + ### Initializing a rntm_t Before specifying the parallelism (automatically or manually), you must first allocate a special BLIS object called a `rntm_t` (runtime). The object is quite small (about 64 bytes), and so we recommend allocating it statically on the function stack: @@ -210,7 +229,7 @@ bli_rntm_init( &rntm ); ``` As of this writing, BLIS treats a default-initialized `rntm_t` as a request for single-threaded execution. -**Note**: If you choose to **not** initialize the `rntm_t` object, you **must** set its parallelism via either the automatic way or the manual way, described below. Passing a completely uninitialized `rntm_t` to a level-3 operation **will almost surely result in undefined behvaior!** +**Note**: If you choose to **not** initialize the `rntm_t` object, you **must** set its parallelism via either the automatic way or the manual way, described below. Passing a completely uninitialized `rntm_t` to a level-3 operation **will almost surely result in undefined behavior!** ### Locally at runtime: the automatic way @@ -273,6 +292,8 @@ Also, you may pass in `NULL` for the `rntm_t*` parameter of an expert interface. This situation could lead to unexpectedly low multithreaded performance. Suppose the user calls `gemm` on a problem with a large m dimension and small k and n dimensions, and explicitly requests parallelism only in the IC loop, but also suppose that the storage of C does not match that of the microkernel's preference. After BLIS transposes the operation internally, the *effective* m dimension will no longer be large; instead, it will be small (because the original m and n dimension will have been swapped). The multithreaded implementation will then proceed to parallelize this small m dimension. There are currently no good *and* easy solutions to this problem. Eventually, though, we plan to add support for two microkernels per datatype per configuration--one for use with matrices C that are row-stored, and one for those that are column-stored. This will obviate the logic within BLIS that sometimes induces the operation transposition, and the problem will go away. + +* **Thread affinity when BLIS and MKL are used together.** Some users have reported that when running a program that links both BLIS (configured with OpenMP) and MKL, **and** when OpenMP thread affinity has been specified (e.g. via `OMP_PROC_BIND` and `OMP_PLACES`), that very poor performance is observed. This may be due to incorrect thread masking in this case, causing all threads to run on one physical core. The exact circumstances leading to this behavior have not been identified, but unsetting the OpenMP thread affinity variables appears to be a solution. # Conclusion diff --git a/docs/Performance.md b/docs/Performance.md new file mode 100644 index 0000000000..f4992d1dee --- /dev/null +++ b/docs/Performance.md @@ -0,0 +1,664 @@ +# Contents + +* **[Contents](Performance.md#contents)** +* **[Introduction](Performance.md#introduction)** +* **[General information](Performance.md#general-information)** +* **[Interpretation](Performance.md#interpretation)** +* **[Reproduction](Performance.md#reproduction)** +* **[Level-3 performance](Performance.md#level-3-performance)** + * **[ThunderX2](Performance.md#thunderx2)** + * **[Experiment details](Performance.md#thunderx2-experiment-details)** + * **[Results](Performance.md#thunderx2-results)** + * **[SkylakeX](Performance.md#skylakex)** + * **[Experiment details](Performance.md#skylakex-experiment-details)** + * **[Results](Performance.md#skylakex-results)** + * **[Haswell](Performance.md#haswell)** + * **[Experiment details](Performance.md#haswell-experiment-details)** + * **[Results](Performance.md#haswell-results)** + * **[Zen](Performance.md#zen)** + * **[Experiment details](Performance.md#zen-experiment-details)** + * **[Results](Performance.md#zen-results)** + * **[Zen2](Performance.md#zen2)** + * **[Experiment details](Performance.md#zen2-experiment-details)** + * **[Results](Performance.md#zen2-results)** + * **[A64fx](Performance.md#a64fx)** + * **[Experiment details](Performance.md#a64fx-experiment-details)** + * **[Results](Performance.md#a64fx-results)** + * **[Neoverse N1](Performance.md#neoverse-n1)** + * **[Experiment details](Performance.md#neoverse-n1-experiment-details)** + * **[Results](Performance.md#neoverse-n1-results)** +* **[Feedback](Performance.md#feedback)** + +# Introduction + +This document showcases performance results for a representative sample of +level-3 operations on large matrices with BLIS and BLAS for several hardware +architectures. + +# General information + +Generally speaking, for level-3 operations on large matrices, we publish three +"panels" for each type of hardware, +each of which reports one of: single-threaded performance, multithreaded +performance on a single socket, or multithreaded performance on two sockets. +Each panel will consist of a 4x5 grid of graphs, with each row representing +a different datatype (single real, double real, single complex, and double +complex) and each column representing a different operation (`gemm`, +`hemm`/`symm`, `herk`/`syrk`, `trmm`, and `trsm`). +Each of the 20 graphs within a panel will contain an x-axis that reports +problem size, with all matrix dimensions equal to the problem size (e.g. +_m_ = _n_ = _k_), resulting in square matrices. +The y-axis will report in units GFLOPS (billions of floating-point operations +per second) in the case of single-threaded performance, or GFLOPS/core in the +case of single- or dual-socket multithreaded performance, where GFLOPS/core +is simply the total GFLOPS observed divided by the number of threads utilized. +This normalization is done intentionally in order to facilitate a visual +assessment of the drop in efficiency of multithreaded performance relative +to their single-threaded baselines. + +It's also worth pointing out that the top of each graph (e.g. the maximum +y-axis value depicted) _always_ corresponds to the theoretical peak performance +under the conditions associated with that graph. +Theoretical peak performance, in units of GFLOPS/core, is calculated as the +product of: +1. the maximum sustainable clock rate in GHz; and +2. the maximum number of floating-point operations (flops) that can be +executed per cycle (per core). + +Note that the maximum sustainable clock rate may change depending on the +conditions. +For example, on some systems the maximum clock rate is higher when only one +core is active (e.g. single-threaded performance) versus when all cores are +active (e.g. multithreaded performance). +The maximum number of flops executable per cycle (per core) is generally +computed as the product of: +1. the maximum number of fused multiply-add (FMA) vector instructions that +can be issued per cycle (per core); +2. the maximum number of elements that can be stored within a single vector +register (for the datatype in question); and +3. 2.0, since an FMA instruction fuses two operations (a multiply and an add). + +The problem size range, represented on the x-axis, is usually sampled with 50 +equally-spaced problem size. +For example, for single-threaded execution, we might choose to execute with +problem sizes of 48 to 2400 in increments of 48, or 56 to 2800 in increments +of 56. +These values are almost never chosen for any particular (read: sneaky) reason; +rather, we start with a "good" maximum problem size, such as 2400 or 2800, and +then divide it by 50 to obtain the appropriate starting point and increment. + +Finally, each point along each curve represents the best of three trials. + +# Interpretation + +In general, the the curves associated with higher-performing implementations +will appear higher in the graphs than lower-performing implementations. +Ideally, an implementation will climb in performance (as a function of problem +size) as quickly as possible and asymptotically approach some high fraction of +peak performance. + +Occasionally, we may publish graphs with incomplete curves--for example, +only the first 25 data points in a typical 50-point series--usually because +the implementation being tested was slow enough that it was not practical to +allow it to finish. + +Where along the x-axis you focus your attention will depend on the segment of +the problem size range that you care about most. Some people's applications +depend heavily on smaller problems, where "small" can mean anything from 10 +to 1000 or even higher. Some people consider 1000 to be quite large, while +others insist that 5000 is merely "medium." What each of us considers to be +small, medium, or large (naturally) depends heavily on the kinds of dense +linear algebra problems we tend to encounter. No one is "right" or "wrong" +about their characterization of matrix smallness or bigness since each person's +relative frame of reference can vary greatly. That said, the +[Science of High-Performance Computing](http://shpc.ices.utexas.edu/) group at +[The University of Texas at Austin](https://www.utexas.edu/) tends to target +matrices that it classifies as "medium-to-large", and so most of the graphs +presented in this document will reflect that targeting in their x-axis range. + +When corresponding with us, via email or when opening an +[issue](https://github.com/flame/blis/issues) on github, we kindly ask that +you specify as closely as possible (though a range is fine) your problem +size of interest so that we can better assist you. + +# Reproduction + +In general, we do not offer any step-by-step guide for how to reproduce the +performance graphs shown below. + +That said, if you are keenly interested in running your own performance +benchmarks, either in an attempt to reproduce the results shown here or to +measure performance of different hardware, of different implementations (or +versions), and/or for different problem sizes, you should begin by studying +the source code, `Makefile`, and scripts in +the [test/3](https://github.com/flame/blis/tree/master/test/3) directory +of the BLIS source distribution. Then, you'll need to take time to build +and/or install some (or all) of the implementations shown (e.g. +[OpenBLAS](https://github.com/xianyi/OpenBLAS), +[MKL](https://software.intel.com/en-us/mkl/), and +[Eigen](http://eigen.tuxfamily.org), including BLIS. Be sure to consult +the detailed notes provided below; they should be *very* helpful in successfully +building the libraries. The `runme.sh` script in `test/3` will help you run +some (or all) of the test drivers produced by the `Makefile`, and the +Matlab/Octave function `plot_panel_4x5()` defined in the `matlab` directory +will help you turn the output of those test drivers into a PDF file of graphs. +The `runthese.m` file will contain example invocations of the function. + +# Level-3 performance + +## ThunderX2 + +### ThunderX2 experiment details + +* Location: Unknown +* Processor model: Marvell ThunderX2 CN9975 +* Core topology: two sockets, 28 cores per socket, 56 cores total +* SMT status: disabled at boot-time +* Max clock rate: 2.2GHz (single-core and multicore) +* Max vector register length: 128 bits (NEON) +* Max FMA vector IPC: 2 +* Peak performance: + * single-core: 17.6 GFLOPS (double-precision), 35.2 GFLOPS (single-precision) + * multicore: 17.6 GFLOPS/core (double-precision), 35.2 GFLOPS/core (single-precision) +* Operating system: Ubuntu 16.04 (Linux kernel 4.15.0) +* Page size: unknown +* Compiler: gcc 7.3.0 +* Results gathered: 14 February 2019 +* Implementations tested: + * BLIS 075143df (0.5.1-39) + * configured with `./configure -t openmp thunderx2` (single- and multithreaded) + * sub-configuration exercised: `thunderx2` + * Single-threaded (1 core) execution requested via no change in environment variables + * Multithreaded (28 core) execution requested via `export BLIS_JC_NT=4 BLIS_IC_NT=7` + * Multithreaded (56 core) execution requested via `export BLIS_JC_NT=8 BLIS_IC_NT=7` + * OpenBLAS 52d3f7a + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=56` (multithreaded, 56 cores) + * Single-threaded (1 core) execution requested via `export OPENBLAS_NUM_THREADS=1` + * Multithreaded (28 core) execution requested via `export OPENBLAS_NUM_THREADS=28` + * Multithreaded (56 core) execution requested via `export OPENBLAS_NUM_THREADS=56` + * ARMPL 18.4 + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (28 core) execution requested via `export OMP_NUM_THREADS=28` + * Multithreaded (56 core) execution requested via `export OMP_NUM_THREADS=56` +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0 1 2 3 ... 55"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. +* Frequency throttling (via `cpupower`): + * No changes made. +* Comments: + * ARMPL performance is remarkably uneven across datatypes and operations, though it would appear their "base" consists of OpenBLAS, which they then optimize for select, targeted routines. Unfortunately, we were unable to test the absolute latest versions of OpenBLAS and ARMPL on this hardware before we lost access. We will rerun these experiments once we gain access to a similar system. + +### ThunderX2 results + +#### pdf + +* [ThunderX2 single-threaded](graphs/large/l3_perf_tx2_nt1.pdf) +* [ThunderX2 multithreaded (28 cores)](graphs/large/l3_perf_tx2_jc4ic7_nt28.pdf) +* [ThunderX2 multithreaded (56 cores)](graphs/large/l3_perf_tx2_jc8ic7_nt56.pdf) + +#### png (inline) + +* **ThunderX2 single-threaded** +![single-threaded](graphs/large/l3_perf_tx2_nt1.png) +* **ThunderX2 multithreaded (28 cores)** +![multithreaded (28 cores)](graphs/large/l3_perf_tx2_jc4ic7_nt28.png) +* **ThunderX2 multithreaded (56 cores)** +![multithreaded (56 cores)](graphs/large/l3_perf_tx2_jc8ic7_nt56.png) + +--- + +## SkylakeX + +### SkylakeX experiment details + +* Location: Oracle cloud +* Processor model: Intel Xeon Platinum 8167M (SkylakeX/AVX-512) +* Core topology: two sockets, 26 cores per socket, 52 cores total +* SMT status: enabled, but not utilized +* Max clock rate: 2.0GHz (single-core and multicore) +* Max vector register length: 512 bits (AVX-512) +* Max FMA vector IPC: 2 +* Peak performance: + * single-core: 64 GFLOPS (double-precision), 128 GFLOPS (single-precision) + * multicore: 64 GFLOPS/core (double-precision), 128 GFLOPS/core (single-precision) +* Operating system: Ubuntu 18.04 (Linux kernel 4.15.0) +* Page size: 4096 bytes +* Compiler: gcc 7.3.0 +* Results gathered: 6 March 2019, 27 March 2019 +* Implementations tested: + * BLIS 9f1dbe5 (0.5.1-54) + * configured with `./configure -t openmp auto` (single- and multithreaded) + * sub-configuration exercised: `skx` + * Single-threaded (1 core) execution requested via no change in environment variables + * Multithreaded (26 core) execution requested via `export BLIS_JC_NT=2 BLIS_IC_NT=13` + * Multithreaded (52 core) execution requested via `export BLIS_JC_NT=4 BLIS_IC_NT=13` + * OpenBLAS 0.3.5 + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=52` (multithreaded, 52 cores) + * Single-threaded (1 core) execution requested via `export OPENBLAS_NUM_THREADS=1` + * Multithreaded (26 core) execution requested via `export OPENBLAS_NUM_THREADS=26` + * Multithreaded (52 core) execution requested via `export OPENBLAS_NUM_THREADS=52` + * Eigen 3.3.90 + * Obtained via the [Eigen git mirror](https://github.com/eigenteam/eigen-git-mirror) (March 27, 2019) + * Prior to compilation, modified top-level `CMakeLists.txt` to ensure that `-march=native` was added to `CXX_FLAGS` variable (h/t Sameer Agarwal): + ``` + # These lines added after line 67. + check_cxx_compiler_flag("-march=native" COMPILER_SUPPORTS_MARCH_NATIVE) + if(COMPILER_SUPPORTS_MARCH_NATIVE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + endif() + ``` + * configured and built BLAS library via `mkdir build; cd build; cmake ..; make blas` + * installed headers via `cmake . -DCMAKE_INSTALL_PREFIX=$HOME/flame/eigen; make install` + * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library. + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (26 core) execution requested via `export OMP_NUM_THREADS=26` + * Multithreaded (52 core) execution requested via `export OMP_NUM_THREADS=52` + * **NOTE**: This version of Eigen does not provide multithreaded implementations of `symm`/`hemm`, `syrk`/`herk`, `trmm`, or `trsm`, and therefore those curves are omitted from the multithreaded graphs. + * MKL 2019 update 1 + * Single-threaded (1 core) execution requested via `export MKL_NUM_THREADS=1` + * Multithreaded (26 core) execution requested via `export MKL_NUM_THREADS=26` + * Multithreaded (52 core) execution requested via `export MKL_NUM_THREADS=52` +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0 1 2 3 ... 51"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. +* Frequency throttling (via `cpupower`): + * Driver: acpi-cpufreq + * Governor: performance + * Hardware limits: 1.0GHz - 2.0GHz + * Adjusted minimum: 2.0GHz +* Comments: + * MKL yields superb performance for most operations, though BLIS is not far behind except for `trsm`. (We understand the `trsm` underperformance and hope to address it in the future.) OpenBLAS lags far behind MKL and BLIS due to lack of full support for AVX-512, and possibly other reasons related to software architecture and register/cache blocksizes. + +### SkylakeX results + +#### pdf + +* [SkylakeX single-threaded](graphs/large/l3_perf_skx_nt1.pdf) +* [SkylakeX multithreaded (26 cores)](graphs/large/l3_perf_skx_jc2ic13_nt26.pdf) +* [SkylakeX multithreaded (52 cores)](graphs/large/l3_perf_skx_jc4ic13_nt52.pdf) + +#### png (inline) + +* **SkylakeX single-threaded** +![single-threaded](graphs/large/l3_perf_skx_nt1.png) +* **SkylakeX multithreaded (26 cores)** +![multithreaded (26 cores)](graphs/large/l3_perf_skx_jc2ic13_nt26.png) +* **SkylakeX multithreaded (52 cores)** +![multithreaded (52 cores)](graphs/large/l3_perf_skx_jc4ic13_nt52.png) + +--- + +## Haswell + +### Haswell experiment details + +* Location: TACC (Lonestar5) +* Processor model: Intel Xeon E5-2690 v3 (Haswell) +* Core topology: two sockets, 12 cores per socket, 24 cores total +* SMT status: enabled, but not utilized +* Max clock rate: 3.5GHz (single-core), 3.1GHz (multicore) +* Max vector register length: 256 bits (AVX2) +* Max FMA vector IPC: 2 +* Peak performance: + * single-core: 56 GFLOPS (double-precision), 112 GFLOPS (single-precision) + * multicore: 49.6 GFLOPS/core (double-precision), 99.2 GFLOPS/core (single-precision) +* Operating system: Cray Linux Environment 6 (Linux kernel 4.4.103) +* Page size: 4096 bytes +* Compiler: gcc 6.3.0 +* Results gathered: 25-26 February 2019, 27 March 2019 +* Implementations tested: + * BLIS 075143df (0.5.1-39) + * configured with `./configure -t openmp auto` (single- and multithreaded) + * sub-configuration exercised: `haswell` + * Single-threaded (1 core) execution requested via no change in environment variables + * Multithreaded (12 core) execution requested via `export BLIS_JC_NT=2 BLIS_IC_NT=3 BLIS_JR_NT=2` + * Multithreaded (24 core) execution requested via `export BLIS_JC_NT=4 BLIS_IC_NT=3 BLIS_JR_NT=2` + * OpenBLAS 0.3.5 + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=24` (multithreaded, 24 cores) + * Single-threaded (1 core) execution requested via `export OPENBLAS_NUM_THREADS=1` + * Multithreaded (12 core) execution requested via `export OPENBLAS_NUM_THREADS=12` + * Multithreaded (24 core) execution requested via `export OPENBLAS_NUM_THREADS=24` + * Eigen 3.3.90 + * Obtained via the [Eigen git mirror](https://github.com/eigenteam/eigen-git-mirror) (March 27, 2019) + * Prior to compilation, modified top-level `CMakeLists.txt` to ensure that `-march=native` was added to `CXX_FLAGS` variable (h/t Sameer Agarwal): + ``` + # These lines added after line 67. + check_cxx_compiler_flag("-march=native" COMPILER_SUPPORTS_MARCH_NATIVE) + if(COMPILER_SUPPORTS_MARCH_NATIVE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + endif() + ``` + * configured and built BLAS library via `mkdir build; cd build; cmake ..; make blas` + * installed headers via `cmake . -DCMAKE_INSTALL_PREFIX=$HOME/flame/eigen; make install` + * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library. + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (12 core) execution requested via `export OMP_NUM_THREADS=12` + * Multithreaded (24 core) execution requested via `export OMP_NUM_THREADS=24` + * **NOTE**: This version of Eigen does not provide multithreaded implementations of `symm`/`hemm`, `syrk`/`herk`, `trmm`, or `trsm`, and therefore those curves are omitted from the multithreaded graphs. + * MKL 2018 update 2 + * Single-threaded (1 core) execution requested via `export MKL_NUM_THREADS=1` + * Multithreaded (12 core) execution requested via `export MKL_NUM_THREADS=12` + * Multithreaded (24 core) execution requested via `export MKL_NUM_THREADS=24` +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0 1 2 3 ... 23"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. +* Frequency throttling (via `cpupower`): + * No changes made. +* Comments: + * We were pleasantly surprised by how competitive BLIS performs relative to MKL on this multicore Haswell system, which is a _very_ common microarchitecture, and _very_ similar to the more recent Broadwells, Skylakes (desktop), Kaby Lakes, and Coffee Lakes that succeeded it. + +### Haswell results + +#### pdf + +* [Haswell single-threaded](graphs/large/l3_perf_has_nt1.pdf) +* [Haswell multithreaded (12 cores)](graphs/large/l3_perf_has_jc2ic3jr2_nt12.pdf) +* [Haswell multithreaded (24 cores)](graphs/large/l3_perf_has_jc4ic3jr2_nt24.pdf) + +#### png (inline) + +* **Haswell single-threaded** +![single-threaded](graphs/large/l3_perf_has_nt1.png) +* **Haswell multithreaded (12 cores)** +![multithreaded (12 cores)](graphs/large/l3_perf_has_jc2ic3jr2_nt12.png) +* **Haswell multithreaded (24 cores)** +![multithreaded (24 cores)](graphs/large/l3_perf_has_jc4ic3jr2_nt24.png) + +--- + +## Zen + +### Zen experiment details + +* Location: Oracle cloud +* Processor model: AMD Epyc 7551 (Zen1 "Naples") +* Core topology: two sockets, 4 dies per socket, 2 core complexes (CCX) per die, 4 cores per CCX, 64 cores total +* SMT status: enabled, but not utilized +* Max clock rate: 3.0GHz (single-core), 2.55GHz (multicore) +* Max vector register length: 256 bits (AVX2) +* Max FMA vector IPC: 1 + * Alternatively, FMA vector IPC is 2 when vectors are limited to 128 bits each. +* Peak performance: + * single-core: 24 GFLOPS (double-precision), 48 GFLOPS (single-precision) + * multicore: 20.4 GFLOPS/core (double-precision), 40.8 GFLOPS/core (single-precision) +* Operating system: Ubuntu 18.04 (Linux kernel 4.15.0) +* Page size: 4096 bytes +* Compiler: gcc 7.3.0 +* Results gathered: 6 March 2019, 19 March 2019, 27 March 2019 +* Implementations tested: + * BLIS 9f1dbe5 (0.5.1-54) + * configured with `./configure -t openmp auto` (single- and multithreaded) + * sub-configuration exercised: `zen` + * Single-threaded (1 core) execution requested via no change in environment variables + * Multithreaded (32 core) execution requested via `export BLIS_JC_NT=1 BLIS_IC_NT=8 BLIS_JR_NT=4` + * Multithreaded (64 core) execution requested via `export BLIS_JC_NT=2 BLIS_IC_NT=8 BLIS_JR_NT=4` + * OpenBLAS 0.3.5 + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=64` (multithreaded, 64 cores) + * Single-threaded (1 core) execution requested via `export OPENBLAS_NUM_THREADS=1` + * Multithreaded (32 core) execution requested via `export OPENBLAS_NUM_THREADS=32` + * Multithreaded (64 core) execution requested via `export OPENBLAS_NUM_THREADS=64` + * Eigen 3.3.90 + * Obtained via the [Eigen git mirror](https://github.com/eigenteam/eigen-git-mirror) (March 27, 2019) + * Prior to compilation, modified top-level `CMakeLists.txt` to ensure that `-march=native` was added to `CXX_FLAGS` variable (h/t Sameer Agarwal): + ``` + # These lines added after line 67. + check_cxx_compiler_flag("-march=native" COMPILER_SUPPORTS_MARCH_NATIVE) + if(COMPILER_SUPPORTS_MARCH_NATIVE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + endif() + ``` + * configured and built BLAS library via `mkdir build; cd build; cmake ..; make blas` + * installed headers via `cmake . -DCMAKE_INSTALL_PREFIX=$HOME/flame/eigen; make install` + * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library. + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (32 core) execution requested via `export OMP_NUM_THREADS=32` + * Multithreaded (64 core) execution requested via `export OMP_NUM_THREADS=64` + * **NOTE**: This version of Eigen does not provide multithreaded implementations of `symm`/`hemm`, `syrk`/`herk`, `trmm`, or `trsm`, and therefore those curves are omitted from the multithreaded graphs. + * MKL 2019 update 1 + * Single-threaded (1 core) execution requested via `export MKL_NUM_THREADS=1` + * Multithreaded (32 core) execution requested via `export MKL_NUM_THREADS=32` + * Multithreaded (64 core) execution requested via `export MKL_NUM_THREADS=64` +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0 1 2 3 ... 63"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. +* Frequency throttling (via `cpupower`): + * Driver: acpi-cpufreq + * Governor: performance + * Hardware limits: 1.2GHz - 2.0GHz + * Adjusted minimum: 2.0GHz +* Comments: + * MKL performance is dismal, despite being linked in the same manner as on the Xeon Platinum. It's not clear what is causing the slowdown. It could be that MKL's runtime kernel/blocksize selection logic is falling back to some older, more basic implementation because CPUID is not returning Intel as the hardware vendor. Alternatively, it's possible that MKL is trying to use kernels for the closest Intel architectures--say, Haswell/Broadwell--but its implementations use Haswell-specific optimizations that, due to microarchitectural differences, degrade performance on Zen. + +### Zen results + +#### pdf + +* [Zen single-threaded](graphs/large/l3_perf_zen_nt1.pdf) +* [Zen multithreaded (32 cores)](graphs/large/l3_perf_zen_jc1ic8jr4_nt32.pdf) +* [Zen multithreaded (64 cores)](graphs/large/l3_perf_zen_jc2ic8jr4_nt64.pdf) + +#### png (inline) + +* **Zen single-threaded** +![single-threaded](graphs/large/l3_perf_zen_nt1.png) +* **Zen multithreaded (32 cores)** +![multithreaded (32 cores)](graphs/large/l3_perf_zen_jc1ic8jr4_nt32.png) +* **Zen multithreaded (64 cores)** +![multithreaded (64 cores)](graphs/large/l3_perf_zen_jc2ic8jr4_nt64.png) + +--- + +## Zen2 + +### Zen2 experiment details + +* Location: Oracle cloud +* Processor model: AMD Epyc 7742 (Zen2 "Rome") +* Core topology: two sockets, 8 Core Complex Dies (CCDs) per socket, 2 Core Complexes (CCX) per CCD, 4 cores per CCX, 128 cores total +* SMT status: enabled, but not utilized +* Max clock rate: 2.25GHz (base, documented); 3.4GHz boost (single-core, documented); 2.6GHz boost (multicore, estimated) +* Max vector register length: 256 bits (AVX2) +* Max FMA vector IPC: 2 + * Alternatively, FMA vector IPC is 4 when vectors are limited to 128 bits each. +* Peak performance: + * single-core: 54.4 GFLOPS (double-precision), 108.8 GFLOPS (single-precision) + * multicore (estimated): 41.6 GFLOPS/core (double-precision), 83.2 GFLOPS/core (single-precision) +* Operating system: Ubuntu 18.04 (Linux kernel 4.15.0) +* Page size: 4096 bytes +* Compiler: gcc 9.3.0 +* Results gathered: 24 September 2020, 29 September 2020 +* Implementations tested: + * BLIS 4fd8d9f (0.7.0-55) + * configured with `./configure -t openmp auto` (single- and multithreaded) + * sub-configuration exercised: `zen2` + * Single-threaded (1 core) execution requested via no change in environment variables + * Multithreaded (64 core) execution requested via `export BLIS_JC_NT=4 BLIS_IC_NT=4 BLIS_JR_NT=4` + * Multithreaded (128 core) execution requested via `export BLIS_JC_NT=8 BLIS_IC_NT=4 BLIS_JR_NT=4` + * OpenBLAS 0.3.10 + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=64` (multithreaded, 64 cores) + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=128` (multithreaded, 128 cores) + * Single-threaded (1 core) execution requested via `export OPENBLAS_NUM_THREADS=1` + * Multithreaded (64 core) execution requested via `export OPENBLAS_NUM_THREADS=64` + * Multithreaded (128 core) execution requested via `export OPENBLAS_NUM_THREADS=128` + * Eigen 3.3.90 + * Obtained via the [Eigen GitLab homepage](https://gitlab.com/libeigen/eigen) (24 September 2020) + * Prior to compilation, modified top-level `CMakeLists.txt` to ensure that `-march=native` was added to `CXX_FLAGS` variable (h/t Sameer Agarwal): + ``` + # These lines added after line 60. + check_cxx_compiler_flag("-march=native" COMPILER_SUPPORTS_MARCH_NATIVE) + if(COMPILER_SUPPORTS_MARCH_NATIVE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + endif() + ``` + * configured and built BLAS library via `mkdir build; cd build; cmake ..; make blas` + * installed headers via `cmake . -DCMAKE_INSTALL_PREFIX=$HOME/flame/eigen; make install` + * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library. + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (64 core) execution requested via `export OMP_NUM_THREADS=64` + * Multithreaded (128 core) execution requested via `export OMP_NUM_THREADS=128` + * **NOTE**: This version of Eigen does not provide multithreaded implementations of `symm`/`hemm`, `syrk`/`herk`, `trmm`, or `trsm`, and therefore those curves are omitted from the multithreaded graphs. + * MKL 2020 update 3 + * Single-threaded (1 core) execution requested via `export MKL_NUM_THREADS=1` + * Multithreaded (64 core) execution requested via `export MKL_NUM_THREADS=64` + * Multithreaded (128 core) execution requested via `export MKL_NUM_THREADS=128` +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0-127"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. + * All executables were run through `numactl --interleave=all`. +* Frequency throttling (via `cpupower`): + * Driver: acpi-cpufreq + * Governor: performance + * Hardware limits (steps): 1.5GHz, 2.0GHz, 2.25GHz + * Adjusted minimum: 2.25GHz +* Comments: + * MKL performance is once again underwhelming. This is likely because Intel has decided that it does not want to give users of MKL a reason to purchase AMD hardware. + +### Zen2 results + +#### pdf + +* [Zen2 single-threaded](graphs/large/l3_perf_zen2_nt1.pdf) +* [Zen2 multithreaded (64 cores)](graphs/large/l3_perf_zen2_jc4ic4jr4_nt64.pdf) +* [Zen2 multithreaded (128 cores)](graphs/large/l3_perf_zen2_jc8ic4jr4_nt128.pdf) + +#### png (inline) + +* **Zen2 single-threaded** +![single-threaded](graphs/large/l3_perf_zen2_nt1.png) +* **Zen2 multithreaded (64 cores)** +![multithreaded (64 cores)](graphs/large/l3_perf_zen2_jc4ic4jr4_nt64.png) +* **Zen2 multithreaded (128 cores)** +![multithreaded (128 cores)](graphs/large/l3_perf_zen2_jc8ic4jr4_nt128.png) + +--- + +## A64fx + +### A64fx experiment details + +* Location: RIKEN Center of Computational Science in Kobe, Japan + * These test results were gathered on the Fugaku supercomputer under project "é‡å­ç‰©è³ªã®å‰µç™ºã¨æ©Ÿèƒ½ã®ãŸã‚ã®åŸºç¤Žç§‘å­¦ ―「富岳ã€ã¨æœ€å…ˆç«¯å®Ÿé¨“ã®å¯†é€£æºã«ã‚ˆã‚‹é©æ–°çš„強相関電å­ç§‘å­¦" (hp200132) (Basic Science for Emergence and Functionality in Quantum Matter: Innovative Strongly-Correlated Electron Science by Integration of "Fugaku" and Frontier Experiments) +* Processor model: Fujitsu A64fx +* Core topology: one socket, 4 NUMA groups per socket, 13 cores per group (one reserved for the OS), 48 cores total +* SMT status: Unknown +* Max clock rate: 2.2GHz (single- and multicore, observed) +* Max vector register length: 512 bits (SVE) +* Max FMA vector IPC: 2 +* Peak performance: + * single-core: 70.4 GFLOPS (double-precision), 140.8 GFLOPS (single-precision) + * multicore: 70.4 GFLOPS/core (double-precision), 140.8 GFLOPS/core (single-precision) +* Operating system: RHEL 8.3 +* Page size: 256 bytes +* Compiler: gcc 10.1.0 +* Results gathered: 2 April 2021; BLIS and SSL2 updated on 21 Sept 2021 +* Implementations tested: + * BLIS b05279d (post-0.8.1) + * configured with: + * `../configure -t none CFLAGS="-DCACHE_SECTOR_SIZE_READONLY" a64fx` (single-threaded) + * `../configure -t openmp CFLAGS="-DCACHE_SECTOR_SIZE_READONLY" a64fx` (multithreaded) + * sub-configuration exercised: `a64fx` + * Single-threaded (1 core) execution requested via no change in environment variables + * Multithreaded (12 core) execution requested via `export BLIS_JC_NT=1 BLIS_IC_NT=1 BLIS_JR_NT=12` + * Multithreaded (48 core) execution requested via `export BLIS_JC_NT=1 BLIS_IC_NT=4 BLIS_JR_NT=12` + * Eigen 3.3.9 + * Obtained via the [Eigen GitLab homepage](https://gitlab.com/libeigen/eigen) + * configured and built BLAS library via `mkdir build; cd build; cmake ..; make blas` + * installed headers via `cmake . -DCMAKE_INSTALL_PREFIX=$HOME/flame/eigen; make install` + * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library. + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (12 core) execution requested via `export OMP_NUM_THREADS=12` + * Multithreaded (48 core) execution requested via `export OMP_NUM_THREADS=48` + * **NOTE**: This version of Eigen does not provide multithreaded implementations of `symm`/`hemm`, `syrk`/`herk`, `trmm`, or `trsm`, and therefore those curves are omitted from the multithreaded graphs. + * ARMPL (20.1.0 for A64fx) + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (12 core) execution requested via `export OMP_NUM_THREADS=12` + * Multithreaded (48 core) execution requested via `export OMP_NUM_THREADS=48` + * **NOTE**: While this version of ARMPL does provide multithreaded implementations of `symm`/`hemm`, `syrk`/`herk`, `trmm`, or `trsm` (with the exception `dtrsm`), but these implementations yield very low performance, and their long run times led us to skip collecting these data altogether. + * Fujitsu SSL2 (Fujitsu toolchain 1.2.33) + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1 NPARALLEL=1` + * Multithreaded (12 core) execution requested via `export OMP_NUM_THREADS=12 NPARALLEL=12` + * Multithreaded (48 core) execution requested via `export OMP_NUM_THREADS=48 NPARALLEL=48` +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="12-23 24-35 36-47 48-59"`. + * All executables were run through `numactl --interleave=all` (multithreaded only). +* Frequency throttling: No change made. No frequency lowering observed. +* Comments: + * Special thanks to Stepan Nassyr and RuQing G. Xu for their work in developing and optimizing A64fx support. Also, thanks to RuQing G. Xu for collecting the data that appear in these graphs. + +### A64fx results + +#### pdf + +* [A64fx single-threaded](graphs/large/l3_perf_a64fx_nt1.pdf) +* [A64fx multithreaded (12 cores)](graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.pdf) +* [A64fx multithreaded (48 cores)](graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.pdf) + +#### png (inline) + +* **A64fx single-threaded** +![single-threaded](graphs/large/l3_perf_a64fx_nt1.png) +* **A64fx multithreaded (12 cores)** +![multithreaded (12 cores)](graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.png) +* **A64fx multithreaded (48 cores)** +![multithreaded (48 cores)](graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.png) + +--- + +## Neoverse N1 + +### Neoverse N1 experiment details + +* Location: AWS cloud +* Processor model: Graviton2 Neoverse N1 +* Core topology: one socket, 64 cores per socket, 64 cores total +* SMT status: none +* Max clock rate: 2.5GHz (single-core and multicore) +* Max vector register length: 128 bits (NEON) +* Max FMA vector IPC: 2 +* Peak performance: + * single-core: 20.0 GFLOPS (double-precision), 40.0 GFLOPS (single-precision) + * multicore: 20.0 GFLOPS/core (double-precision), 40.0 GFLOPS/core (single-precision) +* Operating system: unknown +* Page size: unknown +* Compiler: gcc 10.3.0 +* Results gathered: 15 July 2021 +* Implementations tested: + * BLIS fab5c86d (0.8.1-67) + * configured with `./configure -t openmp thunderx2` (single- and multithreaded) + * sub-configuration exercised: `thunderx2` + * Single-threaded (1 core) execution requested via no change in environment variables + * Multithreaded (64 core) execution requested via `export BLIS_NUM_THREADS=64` + * OpenBLAS 0.3.17 + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=64` (multithreaded, 64 cores) + * Single-threaded (1 core) execution requested via `export OPENBLAS_NUM_THREADS=1` + * Multithreaded (64 core) execution requested via `export OPENBLAS_NUM_THREADS=64` +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0-63"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. +* Frequency throttling (via `cpupower`): + * No changes made. +* Comments: + * N/A + +### Neoverse N1 results + +#### pdf + +* [Neoverse N1 single-threaded](graphs/large/l3_perf_nn1_nt1.pdf) +* [Neoverse N1 multithreaded (64 cores)](graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.pdf) + +#### png (inline) + +* **Neoverse N1 single-threaded** +![single-threaded](graphs/large/l3_perf_nn1_nt1.png) +* **Neoverse N1 multithreaded (64 cores)** +![multithreaded (64 cores)](graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.png) + +--- + +# Feedback + +Please let us know what you think of these performance results! Similarly, if you have any questions or concerns, or are interested in reproducing these performance experiments on your own hardware, we invite you to [open an issue](https://github.com/flame/blis/issues) and start a conversation with BLIS developers. + +Thanks for your interest in BLIS! + diff --git a/docs/PerformanceSmall.md b/docs/PerformanceSmall.md new file mode 100644 index 0000000000..65b30b1364 --- /dev/null +++ b/docs/PerformanceSmall.md @@ -0,0 +1,487 @@ +# Contents + +* **[Contents](PerformanceSmall.md#contents)** +* **[Introduction](PerformanceSmall.md#introduction)** +* **[General information](PerformanceSmall.md#general-information)** +* **[Interpretation](PerformanceSmall.md#interpretation)** +* **[Reproduction](PerformanceSmall.md#reproduction)** +* **[Level-3 performance](PerformanceSmall.md#level-3-performance)** + * **[Kaby Lake](PerformanceSmall.md#kaby-lake)** + * **[Experiment details](PerformanceSmall.md#kaby-lake-experiment-details)** + * **[Results](PerformanceSmall.md#kaby-lake-results)** + * **[Haswell](PerformanceSmall.md#haswell)** + * **[Experiment details](PerformanceSmall.md#haswell-experiment-details)** + * **[Results](PerformanceSmall.md#haswell-results)** + * **[Zen](PerformanceSmall.md#zen)** + * **[Experiment details](PerformanceSmall.md#zen-experiment-details)** + * **[Results](PerformanceSmall.md#zen-results)** + * **[Zen2](PerformanceSmall.md#zen2)** + * **[Experiment details](PerformanceSmall.md#zen2-experiment-details)** + * **[Results](PerformanceSmall.md#zen2-results)** +* **[Feedback](PerformanceSmall.md#feedback)** + +# Introduction + +This document showcases performance results for the level-3 `gemm` operation +on small matrices with BLIS and BLAS for select hardware architectures. + +# General information + +Generally speaking, for level-3 operations on small matrices, we publish +two "panels" for each type of hardware, one that reflects performance on +row-stored matrices and another for column-stored matrices. +Each panel will consist of a 4x7 grid of graphs, with each row representing +a different transposition case (`nn`, `nt`, `tn`, `tt`) +complex) and each column representing a different shape scenario, usually +with one or two matrix dimensions bound to a fixed size for all problem +sizes tested. +Each of the 28 graphs within a panel will contain an x-axis that reports +problem size, with one, two, or all three matrix dimensions equal to the +problem size (e.g. _m_ = 6; _n_ = _k_, also encoded as `m6npkp`). +The y-axis will report in units GFLOPS (or billions of floating-point operations +per second) per core. + +It's also worth pointing out that the top of some graphs (e.g. the maximum +y-axis value depicted) correspond to the theoretical peak performance +under the conditions associated with that graph, while in other graphs the +y-axis has been adjusted to better show the difference between the various +curves. (We *strongly* prefer to always use peak performance as the top of +the graph; however, this is one of the few exceptions where we feel some +scaling is warranted.) +Theoretical peak performance on a single core, in units of GFLOPS, is +calculated as the product of: +1. the maximum sustainable clock rate in GHz; and +2. the maximum number of floating-point operations (flops) that can be +executed per cycle. + +Note that the maximum sustainable clock rate may change depending on the +conditions. +For example, on some systems the maximum clock rate is higher when only one +core is active (e.g. single-threaded performance) versus when all cores are +active (e.g. multithreaded performance). +The maximum number of flops executable per cycle (per core) is generally +computed as the product of: +1. the maximum number of fused multiply-add (FMA) vector instructions that +can be issued per cycle (per core); +2. the maximum number of elements that can be stored within a single vector +register (for the datatype in question); and +3. 2.0, since an FMA instruction fuses two operations (a multiply and an add). + +Typically, organizations and individuals publish performance with square +matrices, which can miss the problem sizes of interest to many applications. +Here, in addition to square matrices (shown in the seventh column), we also +show six other scenarios where one or two `gemm` dimensions (of _m,_ _n_, and +_k_) is small. In these six columns, the constant small matrix dimensions were +chosen to be _very_ small--in the neighborhood of 8--intentionally to showcase +what happens when at least one of the matrices is abnormally "skinny." + +The problem size range, represented on the x-axis, is sampled in +increments that vary. These increments (and the overall range) are generally +large for the cases where two dimensions are small (and constant), medium for +cases where one dimension is small (and constant), and small for cases where +all dimensions (e.g. _m_, _n_, and _k_) are variable and bound to the problem +size (i.e., square matrices). + +The legend in each graph contains two entries for BLIS, corresponding to the +two black lines, one solid and one dotted. The dotted line, **"BLIS conv"**, +represents the conventional implementation that targets large matrices. This +was the only implementation available in BLIS prior to the addition to the +small/skinny matrix support. The solid line, **"BLIS sup"**, makes use of the +new small/skinny matrix implementation. Sometimes, the performance of +**"BLIS sup"** drops below that of **"BLIS conv"** for somewhat larger problems. +However, in practice, we use a threshold to determine when to switch from the +former to the latter, and therefore the goal is for the performance of +**"BLIS conv"** to serve as an approximate floor below which BLIS performance +never drops. + +Finally, each point along each curve represents the best of three trials. + +# Interpretation + +In general, the the curves associated with higher-performing implementations +will appear higher in the graphs than lower-performing implementations. +Ideally, an implementation will climb in performance (as a function of problem +size) as quickly as possible and asymptotically approach some high fraction of +peak performance. + +When corresponding with us, via email or when opening an +[issue](https://github.com/flame/blis/issues) on github, we kindly ask that +you specify as closely as possible (though a range is fine) your problem +size of interest so that we can better assist you. + +# Reproduction + +In general, we do not offer any step-by-step guide for how to reproduce the +performance graphs shown below. + +That said, if you are keenly interested in running your own performance +benchmarks, either in an attempt to reproduce the results shown here or to +measure performance of different hardware, of different implementations (or +versions), and/or for different problem sizes, you should begin by studying +the source code, `Makefile`, and scripts in +the [test/sup](https://github.com/flame/blis/tree/master/test/sup) directory +of the BLIS source distribution. Then, you'll need to take time to build +and/or install some (or all) of the implementations shown (e.g. +[OpenBLAS](https://github.com/xianyi/OpenBLAS), +[MKL](https://software.intel.com/en-us/mkl/), +[Eigen](http://eigen.tuxfamily.org), +[BLASFEO](https://github.com/giaf/blasfeo), and +[libxsmm](https://github.com/hfp/libxsmm)), including BLIS. Be sure to consult +the detailed notes provided below; they should be *very* helpful in successfully +building the libraries. The `runme.sh` script in `test/sup` (or `test/supmt`) +will help you run +some (or all) of the test drivers produced by the `Makefile`, and the +Matlab/Octave function `plot_panel_trxsh()` defined in the `octave` directory +will help you turn the output of those test drivers into a PDF file of graphs. +The `runthese.m` file will contain example invocations of the function. + +# Level-3 performance + +## Kaby Lake + +### Kaby Lake experiment details + +* Location: undisclosed +* Processor model: Intel Core i5-7500 (Kaby Lake) +* Core topology: one socket, 4 cores total +* SMT status: unavailable +* Max clock rate: 3.8GHz (single-core) +* Max vector register length: 256 bits (AVX2) +* Max FMA vector IPC: 2 +* Peak performance: + * single-core: 57.6 GFLOPS (double-precision), 115.2 GFLOPS (single-precision) + * multicore: 57.6 GFLOPS/core (double-precision), 115.2 GFLOPS/core (single-precision) +* Operating system: Gentoo Linux (Linux kernel 5.2.4) +* Page size: 4096 bytes +* Compiler: gcc 8.3.0 +* Results gathered: 3 March 2020 +* Implementations tested: + * BLIS 90db88e (0.6.1-8) + * configured with `./configure --enable-cblas auto` (single-threaded) + * configured with `./configure --enable-cblas -t openmp auto` (multithreaded) + * sub-configuration exercised: `haswell` + * Multithreaded (4 cores) execution requested via `export BLIS_NUM_THREADS=4` + * OpenBLAS 0.3.8 + * configured `Makefile.rule` with `BINARY=64 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0 USE_LOCKING=1` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=4` (multithreaded) + * Multithreaded (4 cores) execution requested via `export OPENBLAS_NUM_THREADS=4` + * BLASFEO f9b78c6 + * configured `Makefile.rule` with: `BLAS_API=1 FORTRAN_BLAS_API=1 CBLAS_API=1`. + * Eigen 3.3.90 + * Obtained via the [Eigen git mirror](https://github.com/eigenteam/eigen-git-mirror) (36b9596) + * Prior to compilation, modified top-level `CMakeLists.txt` to ensure that `-march=native` was added to `CXX_FLAGS` variable (h/t Sameer Agarwal): + ``` + # These lines added after line 67. + check_cxx_compiler_flag("-march=native" COMPILER_SUPPORTS_MARCH_NATIVE) + if(COMPILER_SUPPORTS_MARCH_NATIVE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + endif() + ``` + * configured and built BLAS library via `mkdir build; cd build; CC=gcc cmake ..; make blas` + * installed headers via `cmake . -DCMAKE_INSTALL_PREFIX=$HOME/flame/eigen; make install` + * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library. + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (4 cores) execution requested via `export OMP_NUM_THREADS=4` + * MKL 2020 initial release + * Single-threaded (1 core) execution requested via `export MKL_NUM_THREADS=1` + * Multithreaded (4 cores) execution requested via `export MKL_NUM_THREADS=4` + * libxsmm a40a833 (post-1.14) + * compiled with `make AVX=2`; linked with [netlib BLAS](http://www.netlib.org/blas/) 3.6.0 as the fallback library to better show where libxsmm stops handling the computation internally. +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0-3"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. +* Frequency throttling (via `cpupower`): + * Driver: intel_pstate + * Governor: performance + * Hardware limits: 800MHz - 3.8GHz + * Adjusted minimum: 3.8GHz +* Comments: + * libxsmm is highly competitive for very small problems, but quickly gives up once the "large" dimension exceeds about 180-240 (or 64 in the case where all operands are square). Also, libxsmm's `gemm` cannot handle a transposition on matrix A and similarly dispatches the fallback implementation for those cases. libxsmm also does not export CBLAS interfaces, and therefore only appears on the graphs for column-stored matrices. + +### Kaby Lake results + +#### pdf + +* [Kaby Lake single-threaded row-stored](graphs/sup/dgemm_rrr_kbl_nt1.pdf) +* [Kaby Lake single-threaded column-stored](graphs/sup/dgemm_ccc_kbl_nt1.pdf) +* [Kaby Lake multithreaded (4 cores) row-stored](graphs/sup/dgemm_rrr_kbl_nt4.pdf) +* [Kaby Lake multithreaded (4 cores) column-stored](graphs/sup/dgemm_ccc_kbl_nt4.pdf) + +#### png (inline) + +* **Kaby Lake single-threaded row-stored** +![single-threaded row-stored](graphs/sup/dgemm_rrr_kbl_nt1.png) +* **Kaby Lake single-threaded column-stored** +![single-threaded column-stored](graphs/sup/dgemm_ccc_kbl_nt1.png) +* **Kaby Lake multithreaded (4 cores) row-stored** +![multithreaded row-stored](graphs/sup/dgemm_rrr_kbl_nt4.png) +* **Kaby Lake multithreaded (4 cores) column-stored** +![multithreaded column-stored](graphs/sup/dgemm_ccc_kbl_nt4.png) + +--- + +## Haswell + +### Haswell experiment details + +* Location: TACC (Lonestar5) +* Processor model: Intel Xeon E5-2690 v3 (Haswell) +* Core topology: two sockets, 12 cores per socket, 24 cores total +* SMT status: enabled, but not utilized +* Max clock rate: 3.5GHz (single-core), 3.1GHz (multicore) +* Max vector register length: 256 bits (AVX2) +* Max FMA vector IPC: 2 +* Peak performance: + * single-core: 56 GFLOPS (double-precision), 112 GFLOPS (single-precision) + * multicore: 49.6 GFLOPS/core (double-precision), 99.2 GFLOPS/core (single-precision) +* Operating system: Cray Linux Environment 6 (Linux kernel 4.4.103) +* Page size: 4096 bytes +* Compiler: gcc 7.3.0 +* Results gathered: 3 March 2020 +* Implementations tested: + * BLIS 90db88e (0.6.1-8) + * configured with `./configure --enable-cblas auto` (single-threaded) + * configured with `./configure --enable-cblas -t openmp auto` (multithreaded) + * sub-configuration exercised: `haswell` + * Multithreaded (12 cores) execution requested via `export BLIS_NUM_THREADS=12` + * OpenBLAS 0.3.8 + * configured `Makefile.rule` with `BINARY=64 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0 USE_LOCKING=1` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=12` (multithreaded) + * Multithreaded (12 cores) execution requested via `export OPENBLAS_NUM_THREADS=12` + * BLASFEO f9b78c6 + * configured `Makefile.rule` with: `BLAS_API=1 FORTRAN_BLAS_API=1 CBLAS_API=1`. + * Eigen 3.3.90 + * Obtained via the [Eigen git mirror](https://github.com/eigenteam/eigen-git-mirror) (36b9596) + * Prior to compilation, modified top-level `CMakeLists.txt` to ensure that `-march=native` was added to `CXX_FLAGS` variable (h/t Sameer Agarwal): + ``` + # These lines added after line 67. + check_cxx_compiler_flag("-march=native" COMPILER_SUPPORTS_MARCH_NATIVE) + if(COMPILER_SUPPORTS_MARCH_NATIVE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + endif() + ``` + * configured and built BLAS library via `mkdir build; cd build; CC=gcc cmake ..; make blas` + * installed headers via `cmake . -DCMAKE_INSTALL_PREFIX=$HOME/flame/eigen; make install` + * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library. + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (12 cores) execution requested via `export OMP_NUM_THREADS=12` + * MKL 2020 initial release + * Single-threaded (1 core) execution requested via `export MKL_NUM_THREADS=1` + * Multithreaded (12 cores) execution requested via `export MKL_NUM_THREADS=12` + * libxsmm a40a833 (post-1.14) + * compiled with `make AVX=2`; linked with [netlib BLAS](http://www.netlib.org/blas/) 3.6.0 as the fallback library to better show where libxsmm stops handling the computation internally. +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0-11"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. +* Frequency throttling (via `cpupower`): + * No changes made. +* Comments: + * libxsmm is highly competitive for very small problems, but quickly gives up once the "large" dimension exceeds about 180-240 (or 64 in the case where all operands are square). Also, libxsmm's `gemm` cannot handle a transposition on matrix A and similarly dispatches the fallback implementation for those cases. libxsmm also does not export CBLAS interfaces, and therefore only appears on the graphs for column-stored matrices. + +### Haswell results + +#### pdf + +* [Haswell single-threaded row-stored](graphs/sup/dgemm_rrr_has_nt1.pdf) +* [Haswell single-threaded column-stored](graphs/sup/dgemm_ccc_has_nt1.pdf) +* [Haswell multithreaded (12 cores) row-stored](graphs/sup/dgemm_rrr_has_nt12.pdf) +* [Haswell multithreaded (12 cores) column-stored](graphs/sup/dgemm_ccc_has_nt12.pdf) + +#### png (inline) + +* **Haswell single-threaded row-stored** +![single-threaded row-stored](graphs/sup/dgemm_rrr_has_nt1.png) +* **Haswell single-threaded column-stored** +![single-threaded column-stored](graphs/sup/dgemm_ccc_has_nt1.png) +* **Haswell multithreaded (12 cores) row-stored** +![multithreaded row-stored](graphs/sup/dgemm_rrr_has_nt12.png) +* **Haswell multithreaded (12 cores) column-stored** +![multithreaded column-stored](graphs/sup/dgemm_ccc_has_nt12.png) + +--- + +## Zen + +### Zen experiment details + +* Location: Oracle cloud +* Processor model: AMD Epyc 7551 (Zen1) +* Core topology: two sockets, 4 dies per socket, 2 core complexes (CCX) per die, 4 cores per CCX, 64 cores total +* SMT status: enabled, but not utilized +* Max clock rate: 3.0GHz (single-core), 2.55GHz (multicore) +* Max vector register length: 256 bits (AVX2) +* Max FMA vector IPC: 1 + * Alternatively, FMA vector IPC is 2 when vectors are limited to 128 bits each. +* Peak performance: + * single-core: 24 GFLOPS (double-precision), 48 GFLOPS (single-precision) + * multicore: 20.4 GFLOPS/core (double-precision), 40.8 GFLOPS/core (single-precision) +* Operating system: Ubuntu 18.04 (Linux kernel 4.15.0) +* Page size: 4096 bytes +* Compiler: gcc 7.4.0 +* Results gathered: 3 March 2020 +* Implementations tested: + * BLIS 90db88e (0.6.1-8) + * configured with `./configure --enable-cblas auto` (single-threaded) + * configured with `./configure --enable-cblas -t openmp auto` (multithreaded) + * sub-configuration exercised: `zen` + * Multithreaded (32 cores) execution requested via `export BLIS_NUM_THREADS=32` + * OpenBLAS 0.3.8 + * configured `Makefile.rule` with `BINARY=64 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0 USE_LOCKING=1` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=32` (multithreaded) + * Multithreaded (32 cores) execution requested via `export OPENBLAS_NUM_THREADS=32` + * BLASFEO f9b78c6 + * configured `Makefile.rule` with: `BLAS_API=1 FORTRAN_BLAS_API=1 CBLAS_API=1`. + * built BLAS library via `make CC=gcc` + * Eigen 3.3.90 + * Obtained via the [Eigen git mirror](https://github.com/eigenteam/eigen-git-mirror) (36b9596) + * Prior to compilation, modified top-level `CMakeLists.txt` to ensure that `-march=native` was added to `CXX_FLAGS` variable (h/t Sameer Agarwal): + ``` + # These lines added after line 67. + check_cxx_compiler_flag("-march=native" COMPILER_SUPPORTS_MARCH_NATIVE) + if(COMPILER_SUPPORTS_MARCH_NATIVE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + endif() + ``` + * configured and built BLAS library via `mkdir build; cd build; CC=gcc cmake ..; make blas` + * installed headers via `cmake . -DCMAKE_INSTALL_PREFIX=$HOME/flame/eigen; make install` + * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library. + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (32 cores) execution requested via `export OMP_NUM_THREADS=32` + * MKL 2020 initial release + * Single-threaded (1 core) execution requested via `export MKL_NUM_THREADS=1` + * Multithreaded (32 cores) execution requested via `export MKL_NUM_THREADS=32` + * libxsmm a40a833 (post-1.14) + * compiled with `make AVX=2`; linked with [netlib BLAS](http://www.netlib.org/blas/) 3.6.0 as the fallback library to better show where libxsmm stops handling the computation internally. +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0-31"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. +* Frequency throttling (via `cpupower`): + * Driver: acpi-cpufreq + * Governor: performance + * Hardware limits: 1.2GHz - 2.0GHz + * Adjusted minimum: 2.0GHz +* Comments: + * libxsmm is highly competitive for very small problems, but quickly gives up once the "large" dimension exceeds about 180-240 (or 64 in the case where all operands are square). Also, libxsmm's `gemm` cannot handle a transposition on matrix A and similarly dispatches the fallback implementation for those cases. libxsmm also does not export CBLAS interfaces, and therefore only appears on the graphs for column-stored matrices. + +### Zen results + +#### pdf + +* [Zen single-threaded row-stored](graphs/sup/dgemm_rrr_zen_nt1.pdf) +* [Zen single-threaded column-stored](graphs/sup/dgemm_ccc_zen_nt1.pdf) +* [Zen multithreaded (32 cores) row-stored](graphs/sup/dgemm_rrr_zen_nt32.pdf) +* [Zen multithreaded (32 cores) column-stored](graphs/sup/dgemm_ccc_zen_nt32.pdf) + +#### png (inline) + +* **Zen single-threaded row-stored** +![single-threaded row-stored](graphs/sup/dgemm_rrr_zen_nt1.png) +* **Zen single-threaded column-stored** +![single-threaded column-stored](graphs/sup/dgemm_ccc_zen_nt1.png) +* **Zen multithreaded (32 cores) row-stored** +![multithreaded row-stored](graphs/sup/dgemm_rrr_zen_nt32.png) +* **Zen multithreaded (32 cores) column-stored** +![multithreaded column-stored](graphs/sup/dgemm_ccc_zen_nt32.png) + +--- + +## Zen2 + +### Zen2 experiment details + +* Location: Oracle cloud +* Processor model: AMD Epyc 7742 (Zen2 "Rome") +* Core topology: two sockets, 8 Core Complex Dies (CCDs) per socket, 2 Core Complexes (CCX) per CCD, 4 cores per CCX, 128 cores total +* SMT status: enabled, but not utilized +* Max clock rate: 2.25GHz (base, documented); 3.4GHz boost (single-core, documented); 2.6GHz boost (multicore, estimated) +* Max vector register length: 256 bits (AVX2) +* Max FMA vector IPC: 2 + * Alternatively, FMA vector IPC is 4 when vectors are limited to 128 bits each. +* Peak performance: + * single-core: 54.4 GFLOPS (double-precision), 108.8 GFLOPS (single-precision) + * multicore (estimated): 41.6 GFLOPS/core (double-precision), 83.2 GFLOPS/core (single-precision) +* Operating system: Ubuntu 18.04 (Linux kernel 4.15.0) +* Page size: 4096 bytes +* Compiler: gcc 9.3.0 +* Results gathered: 8 October 2020 +* Implementations tested: + * BLIS a0849d3 (0.7.0-67) + * configured with `./configure --enable-cblas auto` (single-threaded) + * configured with `./configure --enable-cblas -t openmp auto` (multithreaded) + * sub-configuration exercised: `zen2` + * Multithreaded (32 cores) execution requested via `export BLIS_NUM_THREADS=32` + * OpenBLAS 0.3.10 + * configured `Makefile.rule` with `BINARY=64 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0 USE_LOCKING=1` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=32` (multithreaded) + * Multithreaded (32 cores) execution requested via `export OPENBLAS_NUM_THREADS=32` + * BLASFEO 5b26d40 + * configured `Makefile.rule` with: `BLAS_API=1 FORTRAN_BLAS_API=1 CBLAS_API=1`. + * built BLAS library via `make CC=gcc` + * Eigen 3.3.90 + * Obtained via the [Eigen GitLab homepage](https://gitlab.com/libeigen/eigen) (24 September 2020) + * Prior to compilation, modified top-level `CMakeLists.txt` to ensure that `-march=native` was added to `CXX_FLAGS` variable (h/t Sameer Agarwal): + ``` + # These lines added after line 60. + check_cxx_compiler_flag("-march=native" COMPILER_SUPPORTS_MARCH_NATIVE) + if(COMPILER_SUPPORTS_MARCH_NATIVE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + endif() + ``` + * configured and built BLAS library via `mkdir build; cd build; CC=gcc cmake ..; make blas` + * installed headers via `cmake . -DCMAKE_INSTALL_PREFIX=$HOME/flame/eigen; make install` + * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library. + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (32 cores) execution requested via `export OMP_NUM_THREADS=32` + * MKL 2020 update 3 + * Single-threaded (1 core) execution requested via `export MKL_NUM_THREADS=1` + * Multithreaded (32 cores) execution requested via `export MKL_NUM_THREADS=32` + * libxsmm f0ab9cb (post-1.16.1) + * compiled with `make AVX=2`; linked with [netlib BLAS](http://www.netlib.org/blas/) 3.6.0 as the fallback library to better show where libxsmm stops handling the computation internally. +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0-31"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. + * All executables were run through `numactl --interleave=all`. +* Frequency throttling (via `cpupower`): + * Driver: acpi-cpufreq + * Governor: performance + * Hardware limits (steps): 1.5GHz, 2.0GHz, 2.25GHz + * Adjusted minimum: 2.25GHz +* Comments: + * None. + +### Zen2 results + +#### pdf + +* [Zen2 sgemm single-threaded row-stored](graphs/sup/sgemm_rrr_zen2_nt1.pdf) +* [Zen2 sgemm single-threaded column-stored](graphs/sup/sgemm_ccc_zen2_nt1.pdf) +* [Zen2 dgemm single-threaded row-stored](graphs/sup/dgemm_rrr_zen2_nt1.pdf) +* [Zen2 dgemm single-threaded column-stored](graphs/sup/dgemm_ccc_zen2_nt1.pdf) +* [Zen2 sgemm multithreaded (32 cores) row-stored](graphs/sup/sgemm_rrr_zen2_nt32.pdf) +* [Zen2 sgemm multithreaded (32 cores) column-stored](graphs/sup/sgemm_ccc_zen2_nt32.pdf) +* [Zen2 dgemm multithreaded (32 cores) row-stored](graphs/sup/dgemm_rrr_zen2_nt32.pdf) +* [Zen2 dgemm multithreaded (32 cores) column-stored](graphs/sup/dgemm_ccc_zen2_nt32.pdf) + +#### png (inline) + +* **Zen2 sgemm single-threaded row-stored** +![sgemm single-threaded row-stored](graphs/sup/sgemm_rrr_zen2_nt1.png) +* **Zen2 sgemm single-threaded column-stored** +![sgemm single-threaded column-stored](graphs/sup/sgemm_ccc_zen2_nt1.png) +* **Zen2 dgemm single-threaded row-stored** +![dgemm single-threaded row-stored](graphs/sup/dgemm_rrr_zen2_nt1.png) +* **Zen2 dgemm single-threaded column-stored** +![dgemm single-threaded column-stored](graphs/sup/dgemm_ccc_zen2_nt1.png) +* **Zen2 sgemm multithreaded (32 cores) row-stored** +![sgemm multithreaded row-stored](graphs/sup/sgemm_rrr_zen2_nt32.png) +* **Zen2 sgemm multithreaded (32 cores) column-stored** +![sgemm multithreaded column-stored](graphs/sup/sgemm_ccc_zen2_nt32.png) +* **Zen2 dgemm multithreaded (32 cores) row-stored** +![dgemm multithreaded row-stored](graphs/sup/dgemm_rrr_zen2_nt32.png) +* **Zen2 dgemm multithreaded (32 cores) column-stored** +![dgemm multithreaded column-stored](graphs/sup/dgemm_ccc_zen2_nt32.png) + +--- + +# Feedback + +Please let us know what you think of these performance results! Similarly, if you have any questions or concerns, or are interested in reproducing these performance experiments on your own hardware, we invite you to [open an issue](https://github.com/flame/blis/issues) and start a conversation with BLIS developers. + +Thanks for your interest in BLIS! + diff --git a/docs/ReleaseNotes.md b/docs/ReleaseNotes.md index 193de93420..ccb4d9f0ed 100644 --- a/docs/ReleaseNotes.md +++ b/docs/ReleaseNotes.md @@ -4,6 +4,13 @@ ## Contents +* [Changes in 0.9.0](ReleaseNotes.md#changes-in-090) +* [Changes in 0.8.1](ReleaseNotes.md#changes-in-081) +* [Changes in 0.8.0](ReleaseNotes.md#changes-in-080) +* [Changes in 0.7.0](ReleaseNotes.md#changes-in-070) +* [Changes in 0.6.1](ReleaseNotes.md#changes-in-061) +* [Changes in 0.6.0](ReleaseNotes.md#changes-in-060) +* [Changes in 0.5.2](ReleaseNotes.md#changes-in-052) * [Changes in 0.5.1](ReleaseNotes.md#changes-in-051) * [Changes in 0.5.0](ReleaseNotes.md#changes-in-050) * [Changes in 0.4.1](ReleaseNotes.md#changes-in-041) @@ -33,6 +40,362 @@ * [Changes in 0.0.2](ReleaseNotes.md#changes-in-002) * [Changes in 0.0.1](ReleaseNotes.md#changes-in-001) +## Changes in 0.9.0 +April 1, 2022 + +Improvements present in 0.9.0: + +Framework: +- Added various fields to `obj_t` that relate to storing function pointers to custom `packm` kernels, microkernels, etc as well as accessor functions to set and query those fields. (Devin Matthews) +- Enabled user-customized `packm` microkernels and variants via the aforementioned new `obj_t` fields. (Devin Matthews) +- Moved edge-case handling out of the macrokernel and into the `gemm` and `gemmtrsm` microkernels. This also required updating of APIs and definitions of all existing microkernels in `kernels` directory. Edge-case handling functionality is now facilitated via new preprocessor macros found in `bli_edge_case_macro_defs.h`. (Devin Matthews) +- Avoid `gemmsup` thread barriers when not packing A or B. This boosts performance for many small multithreaded problems. (Field Van Zee, AMD) +- Allow the 1m method to operate normally when single and double real-domain microkernels mix row and column I/O preference. (Field Van Zee, Devin Matthews, RuQing Xu) +- Removed support for execution of complex-domain level-3 operations via the 3m and 4m methods. +- Refactored `herk`, `her2k`, `syrk`, `syr2k` in terms of `gemmt`. (Devin Matthews) +- Defined `setijv` and `getijv` to set/get vector elements. +- Defined `eqsc`, `eqv`, and `eqm` operations to test equality between two scalars, vectors, or matrices. +- Added new bounds checking to `setijm` and `getijm` to prevent use of negative indices. +- Renamed `membrk` files/variables/functions to `pba`. +- Store error-checking level as a thread-local variable. (Devin Matthews) +- Add `err_t*` "return" parameter to `bli_malloc_*()` and friends. +- Switched internal mutexes of the `sba` and `pba` to static initialization. +- Changed return value method of `bli_pack_get_pack_a()`, `bli_pack_get_pack_b()`. +- Fixed a bug that allows `bli_init()` to be called more than once (without segfaulting). (@lschork2, Minh Quan Ho, Devin Matthews) +- Removed a sanity check in `bli_pool_finalize()` that prevented BLIS from being re-initialized. (AMD) +- Fixed insufficient `pool_t`-growing logic in `bli_pool.c`, and always allocate at least one element in `.block_ptrs` array. (Minh Quan Ho) +- Cleanups related to the error message array in `bli_error.c`. (Minh Quan Ho) +- Moved language-related definitions from `bli_macro_defs.h` to a new header, `bli_lang_defs.h`. +- Renamed `BLIS_SIMD_NUM_REGISTERS` to `BLIS_SIMD_MAX_NUM_REGISTERS` and `BLIS_SIMD_SIZE` to `BLIS_SIMD_MAX_SIZE` for improved clarity. (Devin Matthews) +- Many minor bugfixes. +- Many cleanups, including removal of old and commented-out code. + +Compatibility: +- Expanded BLAS layer to include support for `?axpby_()` and `?gemm_batch_()`. (Meghana Vankadari, AMD) +- Added `gemm3m` APIs to BLAS and CBLAS layers. (Bhaskar Nallani, AMD) +- Handle `?gemm_()` invocations where m or n is unit by calling `?gemv_()`. (Dipal M Zambare, AMD) +- Removed option to finalize BLIS after every BLAS call. +- Updated default definitions of `bli_slamch()` and `bli_dlamch()` to use constants from standard C library rather than values computed at runtime. (Devin Matthews) + +Kernels: +- Added 512-bit SVE-based `a64fx` subconfiguration that uses empirically-tuned blocksizes (Stepan Nassyr, RuQing Xu) +- Added a vector-length agnostic `armsve` subconfig that computes blocksizes via an analytical model. (Stepan Nassyr) +- Added vector-length agnostic d/s/sh `gemm` kernels for Arm SVE. (Stepan Nassyr) +- Added `gemmsup` kernels to the `armv8a` kernel set for use in new Apple Firestorm subconfiguration. (RuQing Xu) +- Added 512-bit SVE `dpackm` kernels (16xk and 10xk) with in-register transpose. (RuQing Xu) +- Extended 256-bit SVE `dpackm` kernels by Linaro Ltd. to 512-bit for size 12xk. (RuQing Xu) +- Reorganized register usage in `bli_gemm_armv8a_asm_d6x8.c` to accommodate clang. (RuQing Xu) +- Added `saxpyf`/`daxpyf`/`caxpyf` kernels to `zen` kernel set. (Dipal M Zambare, AMD) +- Added `vzeroupper` instruction to `haswell` microkernels. (Devin Matthews) +- Added explicit `beta == 0` handling in s/d `armsve` and `armv7a` `gemm` microkernels. (Devin Matthews) +- Added a unique tag to branch labels to accommodate clang. (Devin Matthews, Jeff Hammond) +- Fixed a copy-paste bug in the loading of `kappa_i` in the two assembly `cpackm` kernels in `haswell` kernel set. (Devin Matthews) +- Fixed a bug in Mx1 `gemmsup` `haswell` kernels whereby the `vhaddpd` instruction is used with uninitialized registers. (Devin Matthews) +- Fixed a bug in the `power10` microkernel I/O. (Nicholai Tukanov) +- Many other Arm kernel updates and fixes. (RuQing Xu) + +Extras: +- Added support for addons, which are similar to sandboxes but do not require the user to implement any particular operation. +- Added a new `gemmlike` sandbox to allow rapid prototyping of `gemm`-like operations. +- Various updates and improvements to the `power10` sandbox, including a new testsuite. (Nicholai Tukanov) + +Build system: +- Added explicit support for AMD's Zen3 microarchitecture. (Dipal M Zambare, AMD, Field Van Zee) +- Added runtime microarchitecture detection for Arm. (Dave Love, RuQing Xu, Devin Matthews) +- Added a new `configure` option `--[en|dis]able-amd-frame-tweaks` that allows BLIS to compile certain framework files (each with the `_amd` suffix) that have been customized by AMD for improved performance (provided that the targeted configuration is eligible). By default, the more portable counterparts to these files are compiled. (Field Van Zee, AMD) +- Added an explicit compiler predicate (`is_win`) for Windows in `configure`. (Devin Matthews) +- Use `-march=haswell` instead of `-march=skylake-avx512` on Windows. (Devin Matthews, @h-vetinari) +- Fixed `configure` breakage on MacOSX by accepting either `clang` or `LLVM` in vendor string. (Devin Matthews) +- Blacklist clang10/gcc9 and older for `armsve` subconfig. +- Added a `configure` option to control whether or not to use `@rpath`. (Devin Matthews) +- Added armclang detection to `configure`. (Devin Matthews) +- Use `@path`-based install name on MacOSX and use relocatable `RPATH` entries for testsuite binaries. (Devin Matthews) +- For environment variables `CC`, `CXX`, `FC`, `PYTHON`, `AR`, and `RANLIB`, `configure` will now print an error message and abort if a user specifies a specific tool and that tool is not found. (Field Van Zee, Devin Matthews) +- Added symlink to `blis.pc.in` for out-of-tree builds. (Andrew Wildman) +- Register optimized real-domain `copyv`, `setv`, and `swapv` kernels in `zen` subconfig. (Dipal M Zambare, AMD) +- Added Apple Firestorm (A14/M1) subconfiguration, `firestorm`. (RuQing Xu) +- Added `armsve` subconfig to `arm64` configuration family. (RuQing Xu) +- Allow using clang with the `thunderx2` subconfiguration. (Devin Matthews) +- Fixed a subtle substitution bug in `configure`. (Chengguo Sun) +- Updated top-level Makefile to reflect a dependency on the "flat" `blis.h` file for the BLIS and BLAS testsuite objects. (Devin Matthews) +- Mark `xerbla_()` as a "weak" symbol on MacOSX. (Devin Matthews) +- Fixed a long-standing bug in `common.mk` whereby the header path to `cblas.h` was omitted from the compiler flags when compiling CBLAS files within BLIS. +- Added a custom-made recursive `sed` script to `build` directory. +- Minor cleanups and fixes to `configure`, `common.mk`, and others. + +Testing: +- Fixed a race condition in the testsuite when the SALT option (simulate application-level threading) is enabled. (Devin Matthews) +- Test 1m method execution during `make check`. (Devin Matthews) +- Test `make install` in Travis CI. (Devin Matthews) +- Test C++ in Travis CI to make sure `blis.h` is C++-compatible. (Devin Matthews) +- Disabled SDE testing of pre-Zen microarchitectures via Travis CI. +- Added Travis CI support for testing Arm SVE. (RuQing Xu) +- Updated SDE usage so that it is downloaded from a separate repository (ci-utils) in our GitHub organization. (Field Van Zee, Devin Matthews) +- Updated octave scripts in `test/3` to be robust against missing datasets as well as to fixed a few minor issues. +- Added `test_axpbyv.c` and `test_gemm_batch.c` test driver files to `test` directory. (Meghana Vankadari, AMD) +- Support all four datatypes in `her`, `her2`, `herk`, and `her2k` drivers in `test` directory. (Madan mohan Manokar, AMD) + +Documentation: +- Added documentation for: `setijv`, `getijv`, `eqsc`, `eqv`, `eqm`. +- Added `docs/Addons.md`. +- Added dedicated "Performance" and "Example Code" sections to `README.md`. +- Updated `README.md`. +- Updated `docs/Sandboxes.md`. +- Updated `docs/Multithreading.md`. (Devin Matthews) +- Updated `docs/KernelHowTo.md`. +- Updated `docs/Performance.md` to report Fujitsu A64fx (512-bit SVE) results. (RuQing Xu) +- Updated `docs/Performance.md` to report Graviton2 Neoverse N1 results. (Nicholai Tukanov) +- Updated `docs/FAQ.md` with new questions. +- Fixed typos in `docs/FAQ.md`. (Gaëtan Cassiers) +- Various other minor fixes. + +## Changes in 0.8.1 +March 22, 2021 + +Improvements present in 0.8.1: + +Framework: +- Implemented an automatic reduction in the number of threads when the user requests parallelism via a single number (ie: the automatic way) and (a) that number of threads is prime, and (b) that number exceeds a minimum threshold defined by the macro `BLIS_NT_MAX_PRIME`, which defaults to 11. If prime numbers are really desired, this feature may be suppressed by defining the macro `BLIS_ENABLE_AUTO_PRIME_NUM_THREADS` in the appropriate configuration family's `bli_family_*.h`. (Jeff Diamond) +- Changed default value of `BLIS_THREAD_RATIO_M` from 2 to 1, which leads to slightly different automatic thread factorizations. +- Enable the 1m method only if the real domain microkernel is not a reference kernel. BLIS now forgoes use of 1m if both the real and complex domain kernels are reference implementations. +- Relocated the general stride handling for `gemmsup`. This fixed an issue whereby `gemm` would fail to trigger to conventional code path for cases that use general stride even after `gemmsup` rejected the problem. (RuQing Xu) +- Disabled AMD's small matrix handling entry points for `syrk` and `trsm` due to lack of testing on our side. +- Fixed an incorrect function signature (and prototype) of `bli_?gemmt()`. (RuQing Xu) +- Redefined `BLIS_NUM_ARCHS` to be part of the `arch_t` enum, which means it will be updated automatically when defining future subconfigs. +- Minor code consolidation in all level-3 `_front()` functions. +- Reorganized Windows cpp branch of `bli_pthreads.c`. +- Implemented `bli_pthread_self()` and `_equals()`, but left them commented out (via cpp guards) due to issues with getting the Windows versions working. Thankfully, these functions aren't yet needed by BLIS. + +Kernels: +- Added low-precision POWER10 `gemm` kernels via a `power10` sandbox. This sandbox also provides an API for implementations that use these kernels. See the `sandbox/power10/POWER10.md` document for more info. (Nicholai Tukanov) +- Added assembly `packm` kernels for the `haswell` kernel set and registered to `haswell`, `zen`, and `zen2` subconfigs accordingly. The `s`, `c`, and `z` kernels were modeled on the `d` kernel, which was contributed by AMD. +- Reduced KC in the `skx` subconfig from 384 to 256. (Tze Meng Low) +- Fixed bugs in two `haswell` dgemmsup kernels, which involved extraneous assembly instructions left over from when the kernels were first written. (Kiran Varaganti, Bhaskar Nallani) +- Minor updates to all of the `gemmtrsm` kernels to allow division by diagonal elements rather that scaling by pre-inverted elements. This change was applied to `haswell` and `penryn` kernel sets as well as reference kernels, 1m kernels, and the pre-broadcast B (bb) format kernels used by the `power9` subconfig. (Bhaskar Nallani) +- Fixed incorrect return type on `bli_diag_offset_with_trans()`. (Devin Matthews) + +Build system: +- Output a pkgconfig file so that CMake users that use BLIS can find and incorporate BLIS build products. (Ajay Panyala) +- Fixed an issue in the the configure script's kernel-to-config map that caused `skx` kernel flags to be used when compiling kernels from the `zen` kernel set. This issue wasn't really fixed, but rather tweaked in such a way that it happens to now work. A more proper fix would require a serious rethinking of the configuration system. (Devin Matthews) +- Fixed the shared library build rule in top-level Makefile. The previous rule was incorrectly only linking prerequisites that were newer than the target (`$?`) rather than correctly linking all prerequisites (`$^`). (Devin Matthews) +- Fixed `cc_vendor` for crosstool-ng toolchains. (Isuru Fernando) +- Allow disabling of `trsm` diagonal pre-inversion at compile time via `--disable-trsm-preinversion`. + +Testing: +- Fixed obscure testsuite bug for the `gemmt` test module that relates to its dependency on `gemv`. +- Allow the `amaxv` testsuite module to run with a dimension of 0. (Meghana Vankadari) + +Documentation: +- Documented auto-reduction for prime numbers of threads in `docs/Multithreading.md`. +- Fixed a missing `trans_t` argument in the API documentation for `her2k`/`syr2k` in `docs/BLISTypedAPI.md`. (RuQing Xu) +- Removed an extra call to `free()` in the level-1v typed API example code. (Ilknur Mustafazade) + +## Changes in 0.8.0 +November 19, 2020 + +Improvements present in 0.8.0: + +Framework: +- Implemented support for the level-3 operation `gemmt`, which performs a `gemm` on only the lower or only the upper triangle of a square matrix C. For now, only the conventional/large code path (and not the sup code path) is provided. This support also includes `gemmt` APIs in the BLAS and CBLAS compatibility layers. (AMD) +- Added a C++ template header, `blis.hh`, containing a BLAS-inspired wrapper to a set of polymorphic CBLAS-like function wrappers defined in another header, `cblas.hh`. These headers are installed only when running the `install` target with `INSTALL_HH` set to `yes`. (AMD) +- Disallow `randv`, `randm`, `randnv`, and `randnm` from producing vectors and matrices with 1-norms of zero. +- Changed the behavior of user-initialized `rntm_t` objects so that packing of A and B is disabled by default. (Kiran Varaganti) +- Transitioned to using `bool` keyword instead of the previous integer-based `bool_t` typedef. (RuQing Xu) +- Updated all inline function definitions to use the cpp macro `BLIS_INLINE` instead of the `static` keyword. (Giorgos Margaritis, Devin Matthews) +- Relocated `#include "cpuid.h"` directive from `bli_cpuid.h` to `bli_cpuid.c` so that applications can `#include` both `blis.h` and `cpuid.h`. (Bhaskar Nallani, Devin Matthews) +- Defined `xerbla_array_()` to complement the netlib routine `xerbla_array()`. (Isuru Fernando) +- Replaced the previously broken `ref99` sandbox with a simpler, functioning alternative. (Francisco Igual) +- Fixed a harmless bug whereby `herk` was calling `trmm`-related code for determining the blocksize of KC in the 4th loop. + +Kernels: +- Implemented a full set of `sgemmsup` assembly millikernels and microkernels for `haswell` kernel set. +- Implemented POWER10 `sgemm` and `dgemm` microkernels. (Nicholai Tukanov) +- Added two kernels (`dgemm` and `dpackm`) that employ ARM SVE vector extensions. (Guodong Xu) +- Implemented explicit beta = 0 handling in the `sgemm` microkernel in `bli_gemm_armv7a_int_d4x4.c`. This omission was causing testsuite failures in the new `gemmt` testsuite module for `cortexa15` builds given that the `gemmt` correctness check relies on `gemm` with beta = 0. +- Updated `void*` function arguments in reference `packm` kernels to use the native pointer type, and fixed a related dormant type bug in `bli_kernels_knl.h`. +- Fixed missing `restrict` qualifier in `sgemm` microkernel prototype for `knl` kernel set header. +- Added some missing n = 6 edge cases to `dgemmsup` kernels. +- Fixed an erroneously disabled edge case optimization in `gemmsup` variant code. +- Various bugfixes and cleanups to `dgemmsup` kernels. + +Build system: +- Implemented runtime subconfiguration selection override via `BLIS_ARCH_TYPE`. (decandia50) +- Output the python found during `configure` into the `PYTHON` variable set in `build/config.mk`. (AMD) +- Added configure support for Intel oneAPI via the `CC` environment variable. (Ajay Panyala, Devin Matthews) +- Use `-O2` for all framework code, potentially avoiding intermitten issues with `f2c`'ed packed and banded code. (Devin Matthews) +- Tweaked `zen2` subconfiguration's cache blocksizes and registered full suite of `sgemm` and `dgemm` millikernels. +- Use the `-fomit-frame-pointer` compiler optimization option in the `haswell` and `skx` subconfigurations. (Jeff Diamond, Devin Matthews) +- Tweaked Makefiles in `test`, `test/3`, and `test/sup` so that running any of the usual targets without having first built BLIS results in a helpful error message. +- Add support for `--complex-return=[gnu|intel]` to `configure`, which allows the user to toggle between the GNU and Intel return value conventions for functions such as `cdotc`, `cdotu`, `zdotc`, and `zdotu`. +- Updates to `cortexa9`, `cortexa53` compilation flags. (Dave Love) + +Testing: +- Added a `gemmt` module to the testsuite and a standalone test driver to the `test` directory, both of which exercise the new `gemmt` functionality. (AMD) +- Support creating matrices with small or large leading dimensions in `test/sup` test drivers. +- Support executing `test/sup` drivers with unpacked or packed matrices. +- Added optional `numactl` usage to `test/3/runme.sh`. +- Updated and/or consolidated octave scripts in `test/3` and `test/sup`. +- Increased `dotxaxpyf` testsuite thresholds to avoid false `MARGINAL` results during normal execution. (nagsingh) + +Documentation: +- Added Epyc 7742 Zen2 ("Rome") performance results (single- and multithreaded) to `Performance.md` and `PerformanceSmall.md`. (Jeff Diamond) +- Documented `gemmt` APIs in `BLISObjectAPI.md` and `BLISTypedAPI.md`. (AMD) +- Documented commonly-used object mutator functions in `BLISObjectAPI.md`. (Jeff Diamond) +- Relocated the operation indices of `BLISObjectAPI.md` and `BLISTypedAPI.md` to appear immediately after their respective tables of contents. (Jeff Diamond) +- Added missing perl prerequisite to `BuildSystem.md`. (pkubaj, Dilyn Corner) +- Fixed missing `conjy` parameter in `BLISTypedAPI.md` documentation for `her2` and `syr2`. (Robert van de Geijn) +- Fixed incorrect link to `shiftd` in `BLISTypedAPI.md`. (Jeff Diamond) +- Mention example code at the top of `BLISObjectAPI.md` and `BLISTypedAPI.md`. +- Minor updates to `README.md`, `FAQ.md`, `Multithreading.md`, and `Sandboxes.md` documents. + +## Changes in 0.7.0 +April 7, 2020 + +Improvements present in 0.7.0: + +Framework: +- Implemented support for multithreading within the sup (skinny/small/unpacked) framework, which previously was single-threaded only. Note that this feature works harmoniously with the selective packing introduced into the sup framework in 0.6.1. (AMD) +- Renamed `bli_thread_obarrier()` and `bli_thread_obroadcast()` functions to drop the 'o', which was left over from when `thrcomm_t` objects tracked both "inner" and "outer" communicators. +- Fixed an obscure `int`-to-`packbuf_t` type conversion error that only affects certain C++ compilers (including g++) when compiling application code that includes the BLIS header file `blis.h`. (Ajay Panyala) +- Added a missing early `return` statement in `bli_thread_partition_2x2()`, which provides a slight optimization. (Kiran Varaganti) + +Kernels: +- Fixed the semantics of the `bli_amaxv()` kernels ('s' and 'd') within the `zen` kernel set. Previously, the kernels (incorrectly) returned the index of the last element whose absolute value was largest (in the event there were multiple of equal value); now, it (correclty) returns the index of the first of such elements. The kernels also now return the index of the first NaN, if one is encountered. (Mat Cross, Devin Matthews) + +Build system: +- Warn the user at configure-time when hardware auto-detection returns the `generic` subconfiguration since this is probably not what they were expecting. (Devin Matthews) +- Removed unnecessary sorting (and duplicate removal) on `LDFLAGS` in `common.mk`. (Isuru Fernando) +- Specify the full path to the location of the dynamic library on OSX so that other dynamic libraries that depend on BLIS know where to find the library. (Satish Balay, Jed Brown) + +Testing: +- Updated and reorganized test drivers in `test/sup` so that they work for either single-threaded or multithreaded purposes. (AMD) +- Updated/optimized octave scripts in `test/sup` for use with octave 5.2.0. +- Minor updates/tweaks to `test/1m4m`. + +Documentation: +- Updated existing single-threaded sup performance graphs with new data and added multithreaded sup graphs to `docs/PerformanceSmall.md`. +- Added mention of Gentoo support under the external packages section of the `README.md`. +- Tweaks to `docs/Multithreading.md` that clarify that setting any `BLIS_*_NT` variable to 1 will be considered manual specification for the purposes of determining whether to auto-factorize via `BLIS_NUM_THREADS`. (AMD) + +## Changes in 0.6.1 +January 14, 2020 + +Improvements present in 0.6.1: + +Framework: +- Added support for pre-broadcast when packing B. This causes elements of B to be repeated (broadcast) in the packed copy of B so that subsequent vector loads will result in the element already being pre-broadcast into the vector register. +- Added support for selective packing to `gemmsup` (controlled via environment variables and/or the `rntm_t` object). (AMD) +- Fixed a bug in `sdsdot_sub()` that redundantly added the "alpha" scalar and a separate bug in the order of typecasting intermediate products in `sdsdot_()`. (Simon Lukas Märtens, Devin Matthews) +- Fixed an obscure bug in `bli_acquire_mpart_mdim()`/`bli_acquire_mpart_ndim()`. (Minh Quan Ho) +- Fixed a subtle and complicated bug that only manifested via the BLAS test drivers in the `generic` subconfiguration, and possibly any other subconfiguration that did not register complex-domain `gemm` ukernels, or registered ONLY real-domain ukernels as row-preferential. (Dave Love) +- Always use `sumsqv` to compute `normfv` instead of the "dot product trick" that was previously employed for performance reasons. (Roman Yurchak, Devin Matthews, and Isuru Fernando) +- Fixed bug in `thrinfo_t` debugging/printing code. + +Kernels: +- Implemented and registered an optimized `dgemm` microkernel for the `power9` kernel set. (Nicholai Tukanov) +- Pacify a `restrict` warning in the `gemmtrsm4m1` reference ukernel. (Dave Love, Devin Matthews) + +Build system: +- Fixed parsing in `vpu_count()` on some SkylakeX workstations. (Dave Love) +- Reimplemented `bli_cpuid_query()` for ARM to use `stdio`-based functions instead of `popen()`. (Dave Love) +- Use `-march=znver1` for clang on `zen2` subconfig. +- Updated `-march` flags for `sandybridge`, `haswell` subconfigurations to use newer syntax (e.g. `haswell` instead of `core-avx2` and `sandybridge` instead of `corei7-avx`. +- Correctly use `-qopenmp-simd` for reference kernels when compiling with icc. (Victor Eikjhout) +- Added `-march` support for select gcc version ranges where flag syntax changes or new flags are added. The ranges we identify are: versions older than 4.9.0; versions older than 6.1.0 (but newer than 4.9.0); versions older than 9.1.0 (but newer than 6.1.0). +- Use `-funsafe-math-optimizations` and `-ffp-contract=fast` for all reference kernels when using gcc or clang. +- Updated MC cache blocksizes used by `haswell` subconfig. +- Updated NC cache blocksizes used by `zen` subconfig. +- Fixed a typo in the context registration of the `cortexa53` subconfiguration in `bli_gks.c`. (Francisco Igual) +- Output a more informative error when the user manually targets a subconfiguration that configure places in the configuration blacklist. (Tze Meng Low) +- Set execute bits of shared library at install-time. (Adam J. Stewart) +- Added missing thread-related symbols for export to shared libraries. (Kyungmin Lee) +- Removed (finally) the `attic/windows` directory since we offer Windows DLL support via AppVeyor's build artifacts, and thus that directory was only likely confusing people. + +Testing: +- Fixed latent testsuite microkernel module bug for `power9` subconfig. (Jeff Hammond) +- Added `test/1m4m` driver directory for test drivers related to the 1m paper. +- Added libxsmm support to `test/sup drivers`. (Robert van de Geijn) +- Updated `.travis.yml` and `do_sde.sh` to automatically accept SDE license and download SDE directly from Intel. (Devin Matthews, Jeff Hammond) +- Updated standalone test drivers to iterate backwards through the specified problem space. This often helps avoid the situation whereby the CPU doesn't immediately throttle up to its maximum clock frequency, which can produce strange discontinuities (sharply rising "cliffs") in performance graphs. +- Pacify an unused variable warning in `blastest/f2c/lread.c`. (Jeff Hammond) +- Various other minor fixes/tweaks to test drivers. + +Documentation: +- Added libxsmm results to `docs/PerformanceSmall.md`. +- Added BLASFEO results to `docs/PerformanceSmall.md`. +- Added the page size and location of the performance drivers to `docs/Performance.md` and `docs/PerformanceSmall.md`. (Dave Love) +- Added notes to `docs/Multithreading.md` regarding the nuances of setting multithreading parameters the manual way vs. the automatic way. (Jérémie du Boisberranger) +- Added a section on reproduction to `docs/Performance.md` and `docs/PerformanceSmall.md`. (Dave Love) +- Documented Eigen `-march=native` hack in `docs/Performance.md` and `docs/PerformanceSmall.md`. (Sameer Agarwal) +- Inserted multithreading links and disclaimers to `BuildSystem.md`. (Jeff Diamond) +- Fixed typo in description for `bli_?axpy2v()` in `docs/BLISTypedAPI.md`. (Shmuel Levine) +- Added "How to Download BLIS" section to `README.md`. (Jeff Diamond) +- Various other minor documentation fixes. + +## Changes in 0.6.0 +June 3, 2019 + +Improvements present in 0.6.0: + +Framework: +- Implemented small/skinny/unpacked (sup) framework for accelerated level-3 performance when at least one matrix dimension is small (or very small). For now, only `dgemm` is optimized, and this new implementation currently only targets Intel Haswell through Coffee Lake, and AMD Zen-based Ryzen/Epyc. (The existing kernels should extend without significant modification to Zen2-based Ryzen/Epyc once they are available.) Also, multithreaded parallelism is not yet implemented, though application-level threading should be fine. (AMD) +- Changed function pointer usages of `void*` to new, typedef'ed type `void_fp`. +- Allow compile-time disabling of BLAS prototypes in BLIS, in case the application already has access to prototypes. +- In `bli_system.h`, define `_POSIX_C_SOURCE` to `200809L` if the macro is not already defined. This ensures that things such as pthreads are properly defined by an application that has `#include "blis.h"` but omits the definition of `_POSIX_C_SOURCE` from the command-line compiler options. (Christos Psarras) + +Kernels: +- None. + +Build system: +- Updated the way configure and the top-level Makefile handle installation prefixes (`prefix`, `exec_prefix`, `libdir`, `includedir`, `sharedir`) to better conform with GNU conventions. +- Improved clang version detection. (Isuru Fernando) +- Use pthreads on MinGW and Cygwin. (Isuru Fernando) + +Testing: +- Added Eigen support to test drivers in `test/3`. +- Fix inadvertently hidden `xerbla_()` in blastest drivers when building only shared libraries. (Isuru Fernando, M. Zhou) + +Documentation: +- Added `docs/PerformanceSmall.md` to showcase new BLIS small/skinny `dgemm` performance on Kaby Lake and Epyc. +- Added Eigen results (3.3.90) to performance graphs showcased in `docs/Performance.md`. +- Added BLIS thread factorization info to `docs/Performance.md`. + +## Changes in 0.5.2 +March 19, 2019 + +Improvements present in 0.5.2: + +Framework: +- Added support for IC loop parallelism to the `trsm` operation. +- Implemented a pool-based small block allocator and a corresponding `configure` option (enabled by default), which minimizes the number of calls to `malloc()` and `free()` for the purposes of allocating small blocks (on the order of 100 bytes). These small blocks are used by internal data structures, and the repeated allocation and freeing of these structures could, perhaps, cause memory fragmentation issues in certain application circumstances. This was never reproduced and observed, however, and remains entirely theoretical. Still, the sba should be no slower, and perhaps a little faster, than repeatedly calling `malloc()` and `free()` for these internal data structures. Also, the sba was designed to be thread-safe. (AMD) +- Refined and extended the output enabled by `--enable-mem-tracing`, which allows a developer to follow memory allocation and release performed by BLIS. +- Initialize error messages at compile-time rather than at runtime. (Minh Quan Ho) +- Fixed a potential situation whereby the multithreading parameters in a `rntm_t` object that is passed into an expert interface is ignored. +- Prevent a redefinition of `ftnlen` in the `f2c_types.h` in blastest. (Jeff Diamond) + +Kernels: +- Adjusted the cache blocksizes in the `zen` sub-configuration for `float`, `scomplex`, and `dcomplex` datatypes. The previous values, taken directly from the `haswell` subconfig, were merely meant to be reasonable placeholders until more suitable values were determined, as had already taken place for the `double` datatype. (AMD) +- Rewrote reference kernels in terms of simplified indexing annotated by the `#pragma omp simd` directive, which a compiler can use to vectorize certain constant-bounded loops. The `#pragma` is disabled via a preprocessor macro layer if the compiler is found by `configure` to not support `-fopenmp-simd`. (Devin Matthews, Jeff Hammond) + +Build system: +- Added symbol-export annotation macros to all of the function prototypes and global variable declarations for public symbols, and created a new `configure` option, `--export-shared=[public|all]`, that controls which symbols--only those that are meant to be public, or all symbols--are exported to the shared library. (Isuru Fernando) +- Standardized to using `-O3` in various subconfigs, and also `-funsafe-math-optimizations` for reference kernels. (Dave Love, Jeff Hammond) +- Disabled TBM, XOP, LWP instructions in all AMD subconfigs. (Devin Matthews) +- Fixed issues that prevented using BLIS on GNU Hurd. (M. Zhou) +- Relaxed python3 requirements to allow python 3.4 or later. Previously, python 3.5 or later was required if python3 was being used. (Dave Love) +- Added `thunderx2` sub-configuration. (Devangi Parikh) +- Added `power9` sub-configuration. For now, this subconfig only uses reference kernels. (Nicholai Tukanov) +- Fixed an issue with `configure` failing on OSes--including certain flavors of BSD--that contain a slash '/' character in the output of `uname -s`. (Isuru Fernando, M. Zhou) + +Testing: +- Renamed `test/3m4m` directory to `test/3`. +- Lots of updates and improvements to Makefiles, shell scripts, and matlab scripts in `test/3`. + +Documentation: +- Added a new `docs/Performance.md` document that showcases single-threaded, single-socket, and dual-socket performance results of `single`, `double`, `scomplex`, and `dcomplex` level-3 operations in BLIS, OpenBLAS, and MKL/ARMPL for Haswell, SkylakeX, ThunderX2, and Epyc hardware architectures. (Note: Other implementations such as Eigen and ATLAS may be added to these graphs in the future.) +- Updated `README.md` to include new language on external packages. (Dave Love) +- Updated `docs/Multithreading.md` to be more explicit about the fact that multithreading is disabled by default at configure-time, and the fact that BLIS will run executed single-threaded at runtime by default if no multithreaded specification is given. (M. Zhou) + ## Changes in 0.5.1 December 18, 2018 @@ -88,7 +451,7 @@ Kernels: Build system: - Added support for building Windows DLLs via AppVeyor [2], complete with a built-in implementation of pthreads for Windows, as well as an implementation of the `pthread_barrier_*()` APIs for use on OS X. (Isuru Fernando, Devin Matthews, Mathieu Poumeyrol, Matthew Honnibal) - Defined a `cortexa53` sub-configuration, which is similar to `cortexa57` except that it uses slightly different compiler flags. (Mathieu Poumeyrol) -- Added python version checking to configure script. +- Added python version checking to `configure` script. - Added a script to automate the regeneration of the symbols list file (now located in `build/libblis-symbols.def`). - Various tweaks in preparation for BLIS's inclusion within Debian. (M. Zhou) - Various fixes and cleanups. @@ -246,16 +609,16 @@ May 2, 2017 - Implemented the 1m method for inducing complex matrix multiplication. (Please see ACM TOMS publication ["Implementing high-performance complex matrix multiplication via the 1m method"](https://github.com/flame/blis#citations) for more details.) - Switched to simpler `trsm_r` implementation. - Relaxed constraints that `MC % NR = 0` and `NC % MR = 0`, as this was only needed for the more sophisticated `trsm_r` implementation. -- Automatic loop thread assignment. (Devin Matthews) -- Updates to `.travis.yml` configuration file. (Devin Matthews) +- Automatic loop thread assignment. (Devin Matthews) +- Updates to `.travis.yml` configuration file. (Devin Matthews) - Updates to non-default haswell microkernels. - Match storage format of the temporary micro-tiles in macrokernels to that of the microkernel storage preference for edge cases. -- Added support for Intel's Knight's Landing. (Devin Matthews) -- Added more flexible options to specify multithreading via the configure script. (Devin Matthews) -- OS X compatibility fixes. (Devin Matthews) -- Other small changes and fixes. +- Added support for Intel's Knight's Landing. (Devin Matthews) +- Added more flexible options to specify multithreading via the configure script. (Devin Matthews) +- OS X compatibility fixes. (Devin Matthews) +- Other small changes and fixes. -Also, thanks to Elmar Peise, Krzysztof Drewniak, and Francisco Igual for their contributions in reporting/fixing certain bugs that were addressed in this version. +Also, thanks to Elmar Peise, Krzysztof Drewniak, and Francisco Igual for their contributions in reporting/fixing certain bugs that were addressed in this version. ## Changes in 0.2.1 October 5, 2016 @@ -439,7 +802,7 @@ While neither `bli_config.h` nor `bli_kernel.h` has changed formats since 0.0.7, ## Changes in 0.0.7 April 30, 2013 -This version incorporates many small fixes and feature enhancements made during our SC13 collaboration. +This version incorporates many small fixes and feature enhancements made during our SC13 collaboration. ## Changes in 0.0.6 April 13, 2013 @@ -478,7 +841,7 @@ The compatibility layer is enabled via a configuration option in `bl2_config.h`. ## Changes in 0.0.2 February 11, 2013 -Most notably, this version contains the new test suite I've been working on for the last month. +Most notably, this version contains the new test suite I've been working on for the last month. What is the test suite? It is a highly configurable test driver that allows one to test an arbitrary set of BLIS operations, with an arbitrary set of parameter combinations, and matrix/vector storage formats, as well as whichever datatypes you are interested in. (For now, only homogeneous datatyping is supported, which is what most people want.) You can also specify an arbitrary problem size range with arbitrary increments, and arbitrary ratios between dimensions (or anchor a dimension to a single value), and you can output directly to files which store the output in matlab syntax, which makes it easy to generate performance graphs. diff --git a/docs/Sandboxes.md b/docs/Sandboxes.md index 6968b43bf9..cbc0add53e 100644 --- a/docs/Sandboxes.md +++ b/docs/Sandboxes.md @@ -17,13 +17,9 @@ Simply put, a sandbox in BLIS provides an alternative implementation to the `gemm` operation. To get a little more specific, a sandbox provides an alternative implementation -to the function `bli_gemmnat()`, which is the object-based API call for -computing the `gemm` operation via native execution. - -**Note**: Native execution simply means that an induced method will not be used. -It's what you probably already think of when you think of implementing the -`gemm` operation: a series of loops around an optimized (usually assembly-based) -microkernel with some packing functions thrown in at various levels. +to the function `bli_gemm_ex()`, which is the +[expert interface](BLISObjectAPI.md##basic-vs-expert-interfaces) for calling the +[object-based API](BLISObjectAPI.md#gemm) for the `gemm` operation. Why sandboxes? Sometimes you want to experiment with tweaks or changes to the `gemm` operation, but you want to do so in a simple environment rather than @@ -37,29 +33,35 @@ utility functions. To enable a sandbox at configure-time, you simply specify it as an option to `configure`. Either of the following usages are accepted: ``` -$ ./configure --enable-sandbox=ref99 auto -$ ./configure -s ref99 auto +$ ./configure --enable-sandbox=gemmlike auto +$ ./configure -s gemmlike auto ``` -Here, we tell `configure` that we want to use the `ref99` sandbox, which -corresponds to a sub-directory of `sandbox` named `ref99`. (Reminder: the +Here, we tell `configure` that we want to use the `gemmlike` sandbox, which +corresponds to a sub-directory of `sandbox` named `gemmlike`. (Reminder: the `auto` argument is the configuration target and thus unrelated to -sandboxes.) As `configure` runs, you should get output that includes lines +sandboxes.) + +NOTE: Using your own sandbox implementation means that BLIS will call your +sandbox for *all* problem sizes and shapes, for *all* datatypes supported +by BLIS. If you intend to only implement a subset of this functionality +within your sandbox, you should be sure to redirect execution back into +the core framework for the parts that you don't wish to reimplement yourself. + +As `configure` runs, you should get output that includes lines similar to: ``` configure: configuring for alternate gemm implementation: -configure: sandbox/ref99 +configure: sandbox/gemmlike ``` And when you build BLIS, the last files to be compiled will be the source code in the specified sandbox: ``` -Compiling obj/haswell/sandbox/ref99/blx_gemm_front.o ('haswell' CFLAGS for sandboxes) -Compiling obj/haswell/sandbox/ref99/blx_gemm_int.o ('haswell' CFLAGS for sandboxes) -Compiling obj/haswell/sandbox/ref99/base/blx_blksz.o ('haswell' CFLAGS for sandboxes) -Compiling obj/haswell/sandbox/ref99/cntl/blx_gemm_cntl.o ('haswell' CFLAGS for sandboxes) +Compiling obj/haswell/sandbox/gemmlike/bls_gemm.o ('haswell' CFLAGS for sandboxes) +Compiling obj/haswell/sandbox/gemmlike/bls_gemm_bp_var1.o ('haswell' CFLAGS for sandboxes) ... ``` That's it! After the BLIS library is built, it will contain your chosen -sandbox's implementation of `bli_gemmnat()` instead of the default +sandbox's implementation of `bli_gemm_ex()` instead of the default BLIS implementation. ## Sandbox rules @@ -79,16 +81,19 @@ will be found! 2. Your sandbox must be written in C99 or C++11. If you write your sandbox in C++11, you must use one of the BLIS-approved file extensions for your source files (`.cc`, `.cpp`, `.cxx`) and your header files (`.hh`, `.hpp`, `.hxx`). -Note that `blis.h` -already contains all of its definitions inside of an `extern "C"` block, so -you should be able to `#include "blis.h"` from your C++11 source code without -any issues. +Note that `blis.h` already contains all of its definitions inside of an +`extern "C"` block, so you should be able to `#include "blis.h"` from your +C++11 source code without any issues. -3. All of your code to replace BLIS's default implementation of `bli_gemmnat()` +3. All of your code to replace BLIS's default implementation of `bli_gemm_ex()` should reside in the named sandbox directory, or some directory therein. -(Obviously.) For example, the "reference" sandbox is located in -`sandbox/ref99`. All of the code associated with this sandbox will be -contained within `sandbox/ref99`. +(Obviously.) For example, the "gemmlike" sandbox is located in +`sandbox/gemmlike`. All of the code associated with this sandbox will be +contained within `sandbox/gemmlike`. Note that you absolutely *may* include +additional code and interfaces within the sandbox, if you wish -- code and +interfaces that are not directly or indirectly needed for satisfying the +the "contract" set forth by the sandbox (i.e., including a local definition +of`bli_gemm_ex()`). 4. The *only* header file that is required of your sandbox is `bli_sandbox.h`. It must be named `bli_sandbox.h` because `blis.h` will `#include` this file @@ -102,17 +107,18 @@ you should only place things (e.g. prototypes or type definitions) in (b) an *application* that calls your sandbox-enabled BLIS library. Usually, neither of these situations will require any of your local definitions since those local definitions are only needed to define your sandbox -implementation of `bli_gemmnat()`, and this function is already prototyped by -BLIS. +implementation of `bli_gemm_ex()`, and this function is already prototyped by +BLIS. *But if you are adding additional APIs and/or operations to the sandbox +that are unrelated to `bli_gemm_ex()`, then you'll want to `#include` those +function prototypes from within `bli_sandbox.h`* -5. Your definition of `bli_gemmnat()` should be the **only function you define** +5. Your definition of `bli_gemm_ex()` should be the **only function you define** in your sandbox that begins with `bli_`. If you define other functions that begin with `bli_`, you risk a namespace collision with existing framework functions. To guarantee safety, please prefix your locally-defined sandbox -functions with another prefix. Here, in the `ref99` sandbox, we use the prefix -`blx_`. (The `x` is for sandbox. Or experimental.) Also, please avoid the -prefix `bla_` since that prefix is also used in BLIS for BLAS compatibility -functions. +functions with another prefix. Here, in the `gemmlike` sandbox, we use the prefix +`bls_`. (The `s` is for sandbox.) Also, please avoid the prefix `bla_` since that +prefix is also used in BLIS for BLAS compatibility functions. If you follow these rules, you will be much more likely to have a pleasant experience integrating your BLIS sandbox into the larger framework. @@ -129,9 +135,9 @@ For example, with a BLIS sandbox you **can** do the following kinds of things: kernels, which can already be customized within each sub-configuration); - try inlining your functions manually; - pivot away from using `obj_t` objects at higher algorithmic level (such as - immediately after calling `bli_gemmnat()`) to try to avoid some overhead; + immediately after calling `bli_gemm_ex()`) to try to avoid some overhead; - create experimental implementations of new BLAS-like operations (provided - that you also provide an implementation of `bli_gemmnat()`). + that you also provide an implementation of `bli_gemm_ex()`). You **cannot**, however, use a sandbox to do the following kinds of things: - define new datatypes (half-precision, quad-precision, short integer, etc.) @@ -149,8 +155,8 @@ Another important limitation is the fact that the build system currently uses # Example framework CFLAGS used by 'haswell' sub-configuration -O3 -Wall -Wno-unused-function -Wfatal-errors -fPIC -std=c99 -D_POSIX_C_SOURCE=200112L -I./include/haswell -I./frame/3/ --I./frame/ind/ukernels/ -I./frame/1m/ -I./frame/1f/ -I./frame/1/ --I./frame/include -DBLIS_VERSION_STRING=\"0.3.2-51\" +-I./frame/1m/ -I./frame/1f/ -I./frame/1/ -I./frame/include +-DBLIS_VERSION_STRING=\"0.3.2-51\" ``` which are likely more general-purpose than the `CFLAGS` used for, say, optimized kernels or even reference kernels. @@ -158,8 +164,8 @@ optimized kernels or even reference kernels. # Example optimized kernel CFLAGS used by 'haswell' sub-configuration -O3 -mavx2 -mfma -mfpmath=sse -march=core-avx2 -Wall -Wno-unused-function -Wfatal-errors -fPIC -std=c99 -D_POSIX_C_SOURCE=200112L -I./include/haswell --I./frame/3/ -I./frame/ind/ukernels/ -I./frame/1m/ -I./frame/1f/ -I./frame/1/ --I./frame/include -DBLIS_VERSION_STRING=\"0.3.2-51\" +-I./frame/3/ -I./frame/1m/ -I./frame/1f/ -I./frame/1/ -I./frame/include +-DBLIS_VERSION_STRING=\"0.3.2-51\" ``` (To see precisely which flags are being employed for any given file, enable verbosity at compile-time via `make V=1`.) Compiling sandboxes with these more @@ -194,7 +200,7 @@ enabled in `input.general`. However, if those options *are* enabled and BLIS was built with mixed datatype support, then BLIS assumes that the implementation of `gemm` will support mixing of datatypes. BLIS *must* assume this, because there's no way for it to confirm at runtime that an implementation was written -to support mixing datatypes. Note that even the `ref99` sandbox included with +to support mixing datatypes. Note that even the `gemmlike` sandbox included with BLIS does not support mixed-datatype computation. ## Conclusion diff --git a/docs/Testsuite.md b/docs/Testsuite.md index 917a7e4a7c..7c4893d04f 100644 --- a/docs/Testsuite.md +++ b/docs/Testsuite.md @@ -128,11 +128,6 @@ sdcz # Datatype(s) to test: 300 # Problem size: maximum to test 100 # Problem size: increment between experiments # Complex level-3 implementations to test -1 # 3mh ('1' = enable; '0' = disable) -1 # 3m1 ('1' = enable; '0' = disable) -1 # 4mh ('1' = enable; '0' = disable) -1 # 4m1b ('1' = enable; '0' = disable) -1 # 4m1a ('1' = enable; '0' = disable) 1 # 1m ('1' = enable; '0' = disable) 1 # native ('1' = enable; '0' = disable) 1 # Simulate application-level threading: @@ -155,7 +150,7 @@ _**Vector storage scheme.**_ Similar to the matrix storage scheme string, this s _**Test all combinations of storage schemes?**_ Enabling this option causes all combinations of storage schemes to be tested. For example, if the option is disabled, a matrix storage scheme string of `cr` would cause the `gemm` test module to test execution where all matrix operands are column-stored, and then where all matrix operands are row-stored. Enabling this option with the same matrix storage string (`cr`) would cause the test suite to test `gemm` under all eight scenarios where the three `gemm` matrix operands are either column-stored or row-stored. -_**Perform all tests with alignment?**_ Disabling this option causes the leading dimension (row or column stride) of test matrices to **not** be aligned according to `BLIS_HEAP_STRIDE_ALIGN_SIZE`, which defaults to `BLIS_SIMD_ALIGN_SIZE`, which defaults to `BLIS_SIMD_SIZE`, which defaults to 64 (bytes). (If any of these values is set to a non-default value, it would be in `bli_family_.h` where `` is the configuration family.) Sometimes it's useful to disable leading dimension alignment in order to test certain aspects of BLIS that need to handle computing with unaligned user data, such as level-1v and level-1f kernels. +_**Perform all tests with alignment?**_ Disabling this option causes the leading dimension (row or column stride) of test matrices to **not** be aligned according to `BLIS_HEAP_STRIDE_ALIGN_SIZE`, which defaults to `BLIS_SIMD_ALIGN_SIZE`, which defaults to `BLIS_SIMD_MAX_SIZE`, which defaults to 64 (bytes). (If any of these values is set to a non-default value, it would be in `bli_family_.h` where `` is the configuration family.) Sometimes it's useful to disable leading dimension alignment in order to test certain aspects of BLIS that need to handle computing with unaligned user data, such as level-1v and level-1f kernels. _**Randomize vectors and matrices.**_ The default randomization method uses real values on the interval [-1,1]. However, we offer an alternate randomization using powers of two in a narrow precision range, which is more likely to result in test residuals exactly equal to zero. This method is somewhat niche/experimental and most people should use random values on the [-1,1] interval. @@ -169,7 +164,7 @@ _**Test gemm with mixed-precision operands?**_ This boolean determines whether ` _**Problem size.**_ These values determine the first problem size to test, the maximum problem size to test, and the increment between problem sizes. Note that the maximum problem size only bounds the range of problem sizes; it is not guaranteed to be tested. Example: If the initial problem size is 128, the maximum is 1000, and the increment is 64, then the last problem size to be tested will be 960. -_**Complex level-3 implementations to test.**_ With the exception of the switch marked `native`, these switches control whether experimental complex domain implementations are tested (when applicable). These implementations employ induced methods complex matrix multiplication and apply to some (though not all) of the level-3 operations. If you don't know what these are, you can ignore them. The `native` switch corresponds to native execution of complex domain level-3 operations, which we test by default. We also test the `1m` method, since it is the induced method of choice when complex microkernels are not available. Note that all of these induced method tests (including `native`) are automatically disabled if the `c` and `z` datatypes are disabled. +_**Complex level-3 implementations to test.**_ This section lists which complex domain implementations of level-3 operations are tested. If you don't know what these are, you can ignore them. The `native` switch corresponds to native execution of complex domain level-3 operations, which we test by default. We also test the `1m` method, since it is the induced method of choice when optimized complex microkernels are not available. Note that all of these induced method tests (including `native`) are automatically disabled if the `c` and `z` datatypes are disabled. _**Simulate application-level threading.**_ This setting specifies the number of threads the testsuite will spawn, and is meant to allow the user to exercise BLIS as a multithreaded application might if it were to make multiple concurrent calls to BLIS operations. (Note that the threading controlled by this option is orthogonal to, and has no effect on, whatever multithreading may be employed _within_ BLIS, as specified by the environment variables described in the [Multithreading](Multithreading.md) documentation.) When this option is set to 1, the testsuite is run with only one thread. When set to n > 1 threads, the spawned threads will parallelize (in round-robin fashion) the total set of tests specified by the testsuite input files, executing them in roughly the same order as that of a sequential execution. diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.pdf b/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.pdf new file mode 100644 index 0000000000..4d27944170 Binary files /dev/null and b/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.pdf differ diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.png b/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.png new file mode 100644 index 0000000000..f51548effb Binary files /dev/null and b/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.png differ diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.pdf b/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.pdf new file mode 100644 index 0000000000..845dfaf862 Binary files /dev/null and b/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.pdf differ diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.png b/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.png new file mode 100644 index 0000000000..08e46c6723 Binary files /dev/null and b/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.png differ diff --git a/docs/graphs/large/l3_perf_a64fx_nt1.pdf b/docs/graphs/large/l3_perf_a64fx_nt1.pdf new file mode 100644 index 0000000000..97a31560a1 Binary files /dev/null and b/docs/graphs/large/l3_perf_a64fx_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_a64fx_nt1.png b/docs/graphs/large/l3_perf_a64fx_nt1.png new file mode 100644 index 0000000000..0b7c2d72aa Binary files /dev/null and b/docs/graphs/large/l3_perf_a64fx_nt1.png differ diff --git a/docs/graphs/large/l3_perf_has_jc2ic3jr2_nt12.pdf b/docs/graphs/large/l3_perf_has_jc2ic3jr2_nt12.pdf new file mode 100644 index 0000000000..3b80889a9a Binary files /dev/null and b/docs/graphs/large/l3_perf_has_jc2ic3jr2_nt12.pdf differ diff --git a/docs/graphs/large/l3_perf_has_jc2ic3jr2_nt12.png b/docs/graphs/large/l3_perf_has_jc2ic3jr2_nt12.png new file mode 100644 index 0000000000..08e28fd0df Binary files /dev/null and b/docs/graphs/large/l3_perf_has_jc2ic3jr2_nt12.png differ diff --git a/docs/graphs/large/l3_perf_has_jc4ic3jr2_nt24.pdf b/docs/graphs/large/l3_perf_has_jc4ic3jr2_nt24.pdf new file mode 100644 index 0000000000..d55f37bdc7 Binary files /dev/null and b/docs/graphs/large/l3_perf_has_jc4ic3jr2_nt24.pdf differ diff --git a/docs/graphs/large/l3_perf_has_jc4ic3jr2_nt24.png b/docs/graphs/large/l3_perf_has_jc4ic3jr2_nt24.png new file mode 100644 index 0000000000..e3fb023af8 Binary files /dev/null and b/docs/graphs/large/l3_perf_has_jc4ic3jr2_nt24.png differ diff --git a/docs/graphs/large/l3_perf_has_nt1.pdf b/docs/graphs/large/l3_perf_has_nt1.pdf new file mode 100644 index 0000000000..64da9d1603 Binary files /dev/null and b/docs/graphs/large/l3_perf_has_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_has_nt1.png b/docs/graphs/large/l3_perf_has_nt1.png new file mode 100644 index 0000000000..12651513ff Binary files /dev/null and b/docs/graphs/large/l3_perf_has_nt1.png differ diff --git a/docs/graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.pdf b/docs/graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.pdf new file mode 100644 index 0000000000..517aee9ed1 Binary files /dev/null and b/docs/graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.pdf differ diff --git a/docs/graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.png b/docs/graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.png new file mode 100644 index 0000000000..c77159dd5a Binary files /dev/null and b/docs/graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.png differ diff --git a/docs/graphs/large/l3_perf_nn1_nt1.pdf b/docs/graphs/large/l3_perf_nn1_nt1.pdf new file mode 100644 index 0000000000..6c5ff9f063 Binary files /dev/null and b/docs/graphs/large/l3_perf_nn1_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_nn1_nt1.png b/docs/graphs/large/l3_perf_nn1_nt1.png new file mode 100644 index 0000000000..750ccf0997 Binary files /dev/null and b/docs/graphs/large/l3_perf_nn1_nt1.png differ diff --git a/docs/graphs/large/l3_perf_skx_jc2ic13_nt26.pdf b/docs/graphs/large/l3_perf_skx_jc2ic13_nt26.pdf new file mode 100644 index 0000000000..0896251578 Binary files /dev/null and b/docs/graphs/large/l3_perf_skx_jc2ic13_nt26.pdf differ diff --git a/docs/graphs/large/l3_perf_skx_jc2ic13_nt26.png b/docs/graphs/large/l3_perf_skx_jc2ic13_nt26.png new file mode 100644 index 0000000000..cf970de368 Binary files /dev/null and b/docs/graphs/large/l3_perf_skx_jc2ic13_nt26.png differ diff --git a/docs/graphs/large/l3_perf_skx_jc4ic13_nt52.pdf b/docs/graphs/large/l3_perf_skx_jc4ic13_nt52.pdf new file mode 100644 index 0000000000..eca573ccf2 Binary files /dev/null and b/docs/graphs/large/l3_perf_skx_jc4ic13_nt52.pdf differ diff --git a/docs/graphs/large/l3_perf_skx_jc4ic13_nt52.png b/docs/graphs/large/l3_perf_skx_jc4ic13_nt52.png new file mode 100644 index 0000000000..561357a718 Binary files /dev/null and b/docs/graphs/large/l3_perf_skx_jc4ic13_nt52.png differ diff --git a/docs/graphs/large/l3_perf_skx_nt1.pdf b/docs/graphs/large/l3_perf_skx_nt1.pdf new file mode 100644 index 0000000000..e0e4c74b6b Binary files /dev/null and b/docs/graphs/large/l3_perf_skx_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_skx_nt1.png b/docs/graphs/large/l3_perf_skx_nt1.png new file mode 100644 index 0000000000..02ad841c06 Binary files /dev/null and b/docs/graphs/large/l3_perf_skx_nt1.png differ diff --git a/docs/graphs/large/l3_perf_tx2_jc4ic7_nt28.pdf b/docs/graphs/large/l3_perf_tx2_jc4ic7_nt28.pdf new file mode 100644 index 0000000000..352d0556ca Binary files /dev/null and b/docs/graphs/large/l3_perf_tx2_jc4ic7_nt28.pdf differ diff --git a/docs/graphs/large/l3_perf_tx2_jc4ic7_nt28.png b/docs/graphs/large/l3_perf_tx2_jc4ic7_nt28.png new file mode 100644 index 0000000000..1b8f231922 Binary files /dev/null and b/docs/graphs/large/l3_perf_tx2_jc4ic7_nt28.png differ diff --git a/docs/graphs/large/l3_perf_tx2_jc8ic7_nt56.pdf b/docs/graphs/large/l3_perf_tx2_jc8ic7_nt56.pdf new file mode 100644 index 0000000000..c25ea9eee8 Binary files /dev/null and b/docs/graphs/large/l3_perf_tx2_jc8ic7_nt56.pdf differ diff --git a/docs/graphs/large/l3_perf_tx2_jc8ic7_nt56.png b/docs/graphs/large/l3_perf_tx2_jc8ic7_nt56.png new file mode 100644 index 0000000000..87b039886e Binary files /dev/null and b/docs/graphs/large/l3_perf_tx2_jc8ic7_nt56.png differ diff --git a/docs/graphs/large/l3_perf_tx2_nt1.pdf b/docs/graphs/large/l3_perf_tx2_nt1.pdf new file mode 100644 index 0000000000..66c808c9c9 Binary files /dev/null and b/docs/graphs/large/l3_perf_tx2_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_tx2_nt1.png b/docs/graphs/large/l3_perf_tx2_nt1.png new file mode 100644 index 0000000000..058bef36bf Binary files /dev/null and b/docs/graphs/large/l3_perf_tx2_nt1.png differ diff --git a/docs/graphs/large/l3_perf_zen2_jc4ic4jr4_nt64.pdf b/docs/graphs/large/l3_perf_zen2_jc4ic4jr4_nt64.pdf new file mode 100644 index 0000000000..4e7c5698aa Binary files /dev/null and b/docs/graphs/large/l3_perf_zen2_jc4ic4jr4_nt64.pdf differ diff --git a/docs/graphs/large/l3_perf_zen2_jc4ic4jr4_nt64.png b/docs/graphs/large/l3_perf_zen2_jc4ic4jr4_nt64.png new file mode 100644 index 0000000000..fb0ce614f3 Binary files /dev/null and b/docs/graphs/large/l3_perf_zen2_jc4ic4jr4_nt64.png differ diff --git a/docs/graphs/large/l3_perf_zen2_jc8ic4jr4_nt128.pdf b/docs/graphs/large/l3_perf_zen2_jc8ic4jr4_nt128.pdf new file mode 100644 index 0000000000..d1d53bc73f Binary files /dev/null and b/docs/graphs/large/l3_perf_zen2_jc8ic4jr4_nt128.pdf differ diff --git a/docs/graphs/large/l3_perf_zen2_jc8ic4jr4_nt128.png b/docs/graphs/large/l3_perf_zen2_jc8ic4jr4_nt128.png new file mode 100644 index 0000000000..8fd68d089c Binary files /dev/null and b/docs/graphs/large/l3_perf_zen2_jc8ic4jr4_nt128.png differ diff --git a/docs/graphs/large/l3_perf_zen2_nt1.pdf b/docs/graphs/large/l3_perf_zen2_nt1.pdf new file mode 100644 index 0000000000..33f270cc81 Binary files /dev/null and b/docs/graphs/large/l3_perf_zen2_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_zen2_nt1.png b/docs/graphs/large/l3_perf_zen2_nt1.png new file mode 100644 index 0000000000..fe87356e58 Binary files /dev/null and b/docs/graphs/large/l3_perf_zen2_nt1.png differ diff --git a/docs/graphs/large/l3_perf_zen_jc1ic8jr4_nt32.pdf b/docs/graphs/large/l3_perf_zen_jc1ic8jr4_nt32.pdf new file mode 100644 index 0000000000..7fbf4abda4 Binary files /dev/null and b/docs/graphs/large/l3_perf_zen_jc1ic8jr4_nt32.pdf differ diff --git a/docs/graphs/large/l3_perf_zen_jc1ic8jr4_nt32.png b/docs/graphs/large/l3_perf_zen_jc1ic8jr4_nt32.png new file mode 100644 index 0000000000..aa12be210b Binary files /dev/null and b/docs/graphs/large/l3_perf_zen_jc1ic8jr4_nt32.png differ diff --git a/docs/graphs/large/l3_perf_zen_jc2ic8jr4_nt64.pdf b/docs/graphs/large/l3_perf_zen_jc2ic8jr4_nt64.pdf new file mode 100644 index 0000000000..d7250eff6f Binary files /dev/null and b/docs/graphs/large/l3_perf_zen_jc2ic8jr4_nt64.pdf differ diff --git a/docs/graphs/large/l3_perf_zen_jc2ic8jr4_nt64.png b/docs/graphs/large/l3_perf_zen_jc2ic8jr4_nt64.png new file mode 100644 index 0000000000..168de35383 Binary files /dev/null and b/docs/graphs/large/l3_perf_zen_jc2ic8jr4_nt64.png differ diff --git a/docs/graphs/large/l3_perf_zen_nt1.pdf b/docs/graphs/large/l3_perf_zen_nt1.pdf new file mode 100644 index 0000000000..4b34f4d274 Binary files /dev/null and b/docs/graphs/large/l3_perf_zen_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_zen_nt1.png b/docs/graphs/large/l3_perf_zen_nt1.png new file mode 100644 index 0000000000..f1a2ef5a6e Binary files /dev/null and b/docs/graphs/large/l3_perf_zen_nt1.png differ diff --git a/docs/graphs/sup/dgemm_ccc_has_nt1.pdf b/docs/graphs/sup/dgemm_ccc_has_nt1.pdf new file mode 100644 index 0000000000..75a7502abf Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_has_nt1.pdf differ diff --git a/docs/graphs/sup/dgemm_ccc_has_nt1.png b/docs/graphs/sup/dgemm_ccc_has_nt1.png new file mode 100644 index 0000000000..527cca0c76 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_has_nt1.png differ diff --git a/docs/graphs/sup/dgemm_ccc_has_nt12.pdf b/docs/graphs/sup/dgemm_ccc_has_nt12.pdf new file mode 100644 index 0000000000..b598c83f9d Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_has_nt12.pdf differ diff --git a/docs/graphs/sup/dgemm_ccc_has_nt12.png b/docs/graphs/sup/dgemm_ccc_has_nt12.png new file mode 100644 index 0000000000..e50a72753a Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_has_nt12.png differ diff --git a/docs/graphs/sup/dgemm_ccc_kbl_nt1.pdf b/docs/graphs/sup/dgemm_ccc_kbl_nt1.pdf new file mode 100644 index 0000000000..f30c4fac93 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_kbl_nt1.pdf differ diff --git a/docs/graphs/sup/dgemm_ccc_kbl_nt1.png b/docs/graphs/sup/dgemm_ccc_kbl_nt1.png new file mode 100644 index 0000000000..fc86fb0c24 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_kbl_nt1.png differ diff --git a/docs/graphs/sup/dgemm_ccc_kbl_nt4.pdf b/docs/graphs/sup/dgemm_ccc_kbl_nt4.pdf new file mode 100644 index 0000000000..d022b70764 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_kbl_nt4.pdf differ diff --git a/docs/graphs/sup/dgemm_ccc_kbl_nt4.png b/docs/graphs/sup/dgemm_ccc_kbl_nt4.png new file mode 100644 index 0000000000..3adefb6534 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_kbl_nt4.png differ diff --git a/docs/graphs/sup/dgemm_ccc_zen2_nt1.pdf b/docs/graphs/sup/dgemm_ccc_zen2_nt1.pdf new file mode 100644 index 0000000000..f4c67930b0 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_zen2_nt1.pdf differ diff --git a/docs/graphs/sup/dgemm_ccc_zen2_nt1.png b/docs/graphs/sup/dgemm_ccc_zen2_nt1.png new file mode 100644 index 0000000000..175fccd133 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_zen2_nt1.png differ diff --git a/docs/graphs/sup/dgemm_ccc_zen2_nt32.pdf b/docs/graphs/sup/dgemm_ccc_zen2_nt32.pdf new file mode 100644 index 0000000000..6cd263e930 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_zen2_nt32.pdf differ diff --git a/docs/graphs/sup/dgemm_ccc_zen2_nt32.png b/docs/graphs/sup/dgemm_ccc_zen2_nt32.png new file mode 100644 index 0000000000..426286783f Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_zen2_nt32.png differ diff --git a/docs/graphs/sup/dgemm_ccc_zen_nt1.pdf b/docs/graphs/sup/dgemm_ccc_zen_nt1.pdf new file mode 100644 index 0000000000..eafba82d46 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_zen_nt1.pdf differ diff --git a/docs/graphs/sup/dgemm_ccc_zen_nt1.png b/docs/graphs/sup/dgemm_ccc_zen_nt1.png new file mode 100644 index 0000000000..ceeb084268 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_zen_nt1.png differ diff --git a/docs/graphs/sup/dgemm_ccc_zen_nt32.pdf b/docs/graphs/sup/dgemm_ccc_zen_nt32.pdf new file mode 100644 index 0000000000..f2137eaba9 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_zen_nt32.pdf differ diff --git a/docs/graphs/sup/dgemm_ccc_zen_nt32.png b/docs/graphs/sup/dgemm_ccc_zen_nt32.png new file mode 100644 index 0000000000..09958337f4 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_zen_nt32.png differ diff --git a/docs/graphs/sup/dgemm_rrr_has_nt1.pdf b/docs/graphs/sup/dgemm_rrr_has_nt1.pdf new file mode 100644 index 0000000000..eff4795794 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_has_nt1.pdf differ diff --git a/docs/graphs/sup/dgemm_rrr_has_nt1.png b/docs/graphs/sup/dgemm_rrr_has_nt1.png new file mode 100644 index 0000000000..084369e0eb Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_has_nt1.png differ diff --git a/docs/graphs/sup/dgemm_rrr_has_nt12.pdf b/docs/graphs/sup/dgemm_rrr_has_nt12.pdf new file mode 100644 index 0000000000..bd9ad99b29 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_has_nt12.pdf differ diff --git a/docs/graphs/sup/dgemm_rrr_has_nt12.png b/docs/graphs/sup/dgemm_rrr_has_nt12.png new file mode 100644 index 0000000000..a404a2eda3 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_has_nt12.png differ diff --git a/docs/graphs/sup/dgemm_rrr_kbl_nt1.pdf b/docs/graphs/sup/dgemm_rrr_kbl_nt1.pdf new file mode 100644 index 0000000000..ba3c87d886 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_kbl_nt1.pdf differ diff --git a/docs/graphs/sup/dgemm_rrr_kbl_nt1.png b/docs/graphs/sup/dgemm_rrr_kbl_nt1.png new file mode 100644 index 0000000000..96386bf800 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_kbl_nt1.png differ diff --git a/docs/graphs/sup/dgemm_rrr_kbl_nt4.pdf b/docs/graphs/sup/dgemm_rrr_kbl_nt4.pdf new file mode 100644 index 0000000000..2fe3fddf0e Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_kbl_nt4.pdf differ diff --git a/docs/graphs/sup/dgemm_rrr_kbl_nt4.png b/docs/graphs/sup/dgemm_rrr_kbl_nt4.png new file mode 100644 index 0000000000..535ce244dc Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_kbl_nt4.png differ diff --git a/docs/graphs/sup/dgemm_rrr_zen2_nt1.pdf b/docs/graphs/sup/dgemm_rrr_zen2_nt1.pdf new file mode 100644 index 0000000000..b8105e6b01 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_zen2_nt1.pdf differ diff --git a/docs/graphs/sup/dgemm_rrr_zen2_nt1.png b/docs/graphs/sup/dgemm_rrr_zen2_nt1.png new file mode 100644 index 0000000000..f1e094247b Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_zen2_nt1.png differ diff --git a/docs/graphs/sup/dgemm_rrr_zen2_nt32.pdf b/docs/graphs/sup/dgemm_rrr_zen2_nt32.pdf new file mode 100644 index 0000000000..9a6ea85e27 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_zen2_nt32.pdf differ diff --git a/docs/graphs/sup/dgemm_rrr_zen2_nt32.png b/docs/graphs/sup/dgemm_rrr_zen2_nt32.png new file mode 100644 index 0000000000..f2ea0835e5 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_zen2_nt32.png differ diff --git a/docs/graphs/sup/dgemm_rrr_zen_nt1.pdf b/docs/graphs/sup/dgemm_rrr_zen_nt1.pdf new file mode 100644 index 0000000000..896844d3cf Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_zen_nt1.pdf differ diff --git a/docs/graphs/sup/dgemm_rrr_zen_nt1.png b/docs/graphs/sup/dgemm_rrr_zen_nt1.png new file mode 100644 index 0000000000..ada4c17695 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_zen_nt1.png differ diff --git a/docs/graphs/sup/dgemm_rrr_zen_nt32.pdf b/docs/graphs/sup/dgemm_rrr_zen_nt32.pdf new file mode 100644 index 0000000000..75f8eb62a2 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_zen_nt32.pdf differ diff --git a/docs/graphs/sup/dgemm_rrr_zen_nt32.png b/docs/graphs/sup/dgemm_rrr_zen_nt32.png new file mode 100644 index 0000000000..7607c91cb8 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_zen_nt32.png differ diff --git a/docs/graphs/sup/sgemm_ccc_zen2_nt1.pdf b/docs/graphs/sup/sgemm_ccc_zen2_nt1.pdf new file mode 100644 index 0000000000..7651cd8a62 Binary files /dev/null and b/docs/graphs/sup/sgemm_ccc_zen2_nt1.pdf differ diff --git a/docs/graphs/sup/sgemm_ccc_zen2_nt1.png b/docs/graphs/sup/sgemm_ccc_zen2_nt1.png new file mode 100644 index 0000000000..f4995fa57c Binary files /dev/null and b/docs/graphs/sup/sgemm_ccc_zen2_nt1.png differ diff --git a/docs/graphs/sup/sgemm_ccc_zen2_nt32.pdf b/docs/graphs/sup/sgemm_ccc_zen2_nt32.pdf new file mode 100644 index 0000000000..1e44d9f796 Binary files /dev/null and b/docs/graphs/sup/sgemm_ccc_zen2_nt32.pdf differ diff --git a/docs/graphs/sup/sgemm_ccc_zen2_nt32.png b/docs/graphs/sup/sgemm_ccc_zen2_nt32.png new file mode 100644 index 0000000000..a072a2f2be Binary files /dev/null and b/docs/graphs/sup/sgemm_ccc_zen2_nt32.png differ diff --git a/docs/graphs/sup/sgemm_rrr_zen2_nt1.pdf b/docs/graphs/sup/sgemm_rrr_zen2_nt1.pdf new file mode 100644 index 0000000000..762944e3c5 Binary files /dev/null and b/docs/graphs/sup/sgemm_rrr_zen2_nt1.pdf differ diff --git a/docs/graphs/sup/sgemm_rrr_zen2_nt1.png b/docs/graphs/sup/sgemm_rrr_zen2_nt1.png new file mode 100644 index 0000000000..fadd4e39f4 Binary files /dev/null and b/docs/graphs/sup/sgemm_rrr_zen2_nt1.png differ diff --git a/docs/graphs/sup/sgemm_rrr_zen2_nt32.pdf b/docs/graphs/sup/sgemm_rrr_zen2_nt32.pdf new file mode 100644 index 0000000000..ea9f9f59e9 Binary files /dev/null and b/docs/graphs/sup/sgemm_rrr_zen2_nt32.pdf differ diff --git a/docs/graphs/sup/sgemm_rrr_zen2_nt32.png b/docs/graphs/sup/sgemm_rrr_zen2_nt32.png new file mode 100644 index 0000000000..9b4743d3ef Binary files /dev/null and b/docs/graphs/sup/sgemm_rrr_zen2_nt32.png differ diff --git a/examples/oapi/08level2.c b/examples/oapi/08level2.c index 24b1402887..09e61722d4 100644 --- a/examples/oapi/08level2.c +++ b/examples/oapi/08level2.c @@ -246,10 +246,11 @@ int main( int argc, char** argv ) // displaying junk values in the unstored triangle. bli_setm( &BLIS_ZERO, &a ); - // Mark matrix 'a' as triangular and stored in the lower triangle, and - // then randomize that lower triangle. + // Mark matrix 'a' as triangular, stored in the lower triangle, and + // having a non-unit diagonal. Then randomize that lower triangle. bli_obj_set_struc( BLIS_TRIANGULAR, &a ); bli_obj_set_uplo( BLIS_LOWER, &a ); + bli_obj_set_diag( BLIS_NONUNIT_DIAG, &a ); bli_randm( &a ); bli_printm( "a: randomized (zeros in upper triangle)", &a, "%4.1f", "" ); @@ -288,10 +289,11 @@ int main( int argc, char** argv ) // displaying junk values in the unstored triangle. bli_setm( &BLIS_ZERO, &a ); - // Mark matrix 'a' as triangular and stored in the lower triangle, and - // then randomize that lower triangle. + // Mark matrix 'a' as triangular, stored in the lower triangle, and + // having a non-unit diagonal. Then randomize that lower triangle. bli_obj_set_struc( BLIS_TRIANGULAR, &a ); bli_obj_set_uplo( BLIS_LOWER, &a ); + bli_obj_set_diag( BLIS_NONUNIT_DIAG, &a ); bli_randm( &a ); // Load the diagonal. By setting the diagonal to something of greater diff --git a/examples/oapi/09level3.c b/examples/oapi/09level3.c index 70839fadb0..27ec78c52b 100644 --- a/examples/oapi/09level3.c +++ b/examples/oapi/09level3.c @@ -244,10 +244,11 @@ int main( int argc, char** argv ) // displaying junk values in the unstored triangle. bli_setm( &BLIS_ZERO, &a ); - // Mark matrix 'a' as triangular and stored in the lower triangle, and - // then randomize that lower triangle. + // Mark matrix 'a' as triangular, stored in the lower triangle, and + // having a non-unit diagonal. Then randomize that lower triangle. bli_obj_set_struc( BLIS_TRIANGULAR, &a ); bli_obj_set_uplo( BLIS_LOWER, &a ); + bli_obj_set_diag( BLIS_NONUNIT_DIAG, &a ); bli_randm( &a ); bli_printm( "a: randomized (zeros in upper triangle)", &a, "%4.1f", "" ); @@ -290,10 +291,11 @@ int main( int argc, char** argv ) // displaying junk values in the unstored triangle. bli_setm( &BLIS_ZERO, &a ); - // Mark matrix 'a' as triangular and stored in the lower triangle, and - // then randomize that lower triangle. + // Mark matrix 'a' as triangular, stored in the lower triangle, and + // having a non-unit diagonal. Then randomize that lower triangle. bli_obj_set_struc( BLIS_TRIANGULAR, &a ); bli_obj_set_uplo( BLIS_LOWER, &a ); + bli_obj_set_diag( BLIS_NONUNIT_DIAG, &a ); bli_randm( &a ); // Load the diagonal. By setting the diagonal to something of greater diff --git a/examples/oapi/Makefile b/examples/oapi/Makefile index 64dbf20dd6..f12ca227be 100644 --- a/examples/oapi/Makefile +++ b/examples/oapi/Makefile @@ -114,7 +114,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Binary executable name. TEST_BINS := 00obj_basic.x \ diff --git a/examples/oapi/README b/examples/oapi/README index 5823b29456..991bbf28c4 100644 --- a/examples/oapi/README +++ b/examples/oapi/README @@ -22,8 +22,8 @@ or by setting the same variable as part of the make command: make BLIS_INSTALL_PATH=/usr/local Once the executable files have been built, we recommend reading the code in -one terminal window alongside the executable output in another. This will -help you see the effects of each section of code. +one terminal window alongside the executable output in another terminal. +This will help you see the effects of each section of code. This tutorial is not exhaustive or complete; several object API functions were omitted (mostly for brevity's sake) and thus more examples could be diff --git a/examples/tapi/00level1v.c b/examples/tapi/00level1v.c index 31efcd8233..e27450714e 100644 --- a/examples/tapi/00level1v.c +++ b/examples/tapi/00level1v.c @@ -175,7 +175,7 @@ int main( int argc, char** argv ) free( y ); free( z ); free( w ); - free( z ); + free( a ); return 0; } diff --git a/examples/tapi/Makefile b/examples/tapi/Makefile index 1de4acc132..83330d38bf 100644 --- a/examples/tapi/Makefile +++ b/examples/tapi/Makefile @@ -102,7 +102,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Binary executable name. TEST_BINS := 00level1v.x \ diff --git a/examples/tapi/README b/examples/tapi/README index ad09787820..4460d3a8d6 100644 --- a/examples/tapi/README +++ b/examples/tapi/README @@ -22,8 +22,8 @@ or by setting the same variable as part of the make command: make BLIS_INSTALL_PATH=/usr/local Once the executable files have been built, we recommend reading the code in -one terminal window alongside the executable output in another. This will -help you see the effects of each section of code. +one terminal window alongside the executable output in another terminal. +This will help you see the effects of each section of code. This tutorial is not exhaustive or complete; many typed API functions were omitted (mostly for brevity's sake) and thus more examples could be diff --git a/frame/0/bli_l0_check.c b/frame/0/bli_l0_check.c index 65eeda1b7f..966f0c6aaa 100644 --- a/frame/0/bli_l0_check.c +++ b/frame/0/bli_l0_check.c @@ -87,6 +87,7 @@ void PASTEMAC(opname,_check) \ GENFRONT( absqsc ) GENFRONT( normfsc ) +// ----------------------------------------------------------------------------- void bli_getsc_check ( @@ -352,3 +353,37 @@ void bli_l0_xx2sc_check bli_check_error_code( e_val ); } +void bli_l0_xxbsc_check + ( + obj_t* chi, + obj_t* psi, + bool* is_eq + ) +{ + err_t e_val; + + // Check object datatypes. + + e_val = bli_check_noninteger_object( chi ); + bli_check_error_code( e_val ); + + e_val = bli_check_noninteger_object( psi ); + bli_check_error_code( e_val ); + + // Check object dimensions. + + e_val = bli_check_scalar_object( chi ); + bli_check_error_code( e_val ); + + e_val = bli_check_scalar_object( psi ); + bli_check_error_code( e_val ); + + // Check object buffers (for non-NULLness). + + e_val = bli_check_object_buffer( chi ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( psi ); + bli_check_error_code( e_val ); +} + diff --git a/frame/0/bli_l0_check.h b/frame/0/bli_l0_check.h index 262679aeb6..f495866c62 100644 --- a/frame/0/bli_l0_check.h +++ b/frame/0/bli_l0_check.h @@ -129,7 +129,6 @@ void PASTEMAC(opname,_check) \ GENTPROT( zipsc ) - // ----------------------------------------------------------------------------- void bli_l0_xsc_check @@ -148,3 +147,10 @@ void bli_l0_xx2sc_check obj_t* chi, obj_t* norm ); + +void bli_l0_xxbsc_check + ( + obj_t* chi, + obj_t* psi, + bool* is_eq + ); diff --git a/frame/0/bli_l0_ft.h b/frame/0/bli_l0_ft.h index 47d47276aa..b90e35eb59 100644 --- a/frame/0/bli_l0_ft.h +++ b/frame/0/bli_l0_ft.h @@ -175,4 +175,3 @@ typedef void (*PASTECH2(ch,opname,tsuf)) \ INSERT_GENTDEFR( zipsc ) - diff --git a/frame/0/bli_l0_oapi.c b/frame/0/bli_l0_oapi.c index 9a54929715..ac62530dbc 100644 --- a/frame/0/bli_l0_oapi.c +++ b/frame/0/bli_l0_oapi.c @@ -64,13 +64,13 @@ void PASTEMAC0(opname) \ bli_obj_scalar_set_dt_buffer( chi, dt_absq_c, &dt_chi, &buf_chi ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt_chi ); \ \ f \ ( \ - buf_chi, \ - buf_absq \ + buf_chi, \ + buf_absq \ ); \ } @@ -100,14 +100,14 @@ void PASTEMAC0(opname) \ PASTEMAC(opname,_check)( chi, psi ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt ); \ \ f \ ( \ - conjchi, \ - buf_chi, \ - buf_psi \ + conjchi, \ + buf_chi, \ + buf_psi \ ); \ } @@ -137,13 +137,13 @@ void PASTEMAC0(opname) \ PASTEMAC(opname,_check)( chi ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt ); \ \ f \ ( \ - conjchi, \ - buf_chi \ + conjchi, \ + buf_chi \ ); \ } @@ -170,13 +170,13 @@ void PASTEMAC0(opname) \ PASTEMAC(opname,_check)( chi, psi ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt ); \ \ f \ ( \ - buf_chi, \ - buf_psi \ + buf_chi, \ + buf_psi \ ); \ } @@ -213,14 +213,14 @@ void PASTEMAC0(opname) \ else dt_use = dt_chi; \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt_use ); \ \ f \ ( \ - buf_chi, \ - zeta_r, \ - zeta_i \ + buf_chi, \ + zeta_r, \ + zeta_i \ ); \ } @@ -247,14 +247,14 @@ void PASTEMAC0(opname) \ PASTEMAC(opname,_check)( zeta_r, zeta_i, chi ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt_chi ); \ \ f \ ( \ - zeta_r, \ - zeta_i, \ - buf_chi \ + zeta_r, \ + zeta_i, \ + buf_chi \ ); \ } @@ -290,14 +290,14 @@ void PASTEMAC0(opname) \ bli_obj_scalar_set_dt_buffer( chi, dt_zeta_c, &dt_chi, &buf_chi ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt_chi ); \ \ f \ ( \ - buf_chi, \ - buf_zeta_r, \ - buf_zeta_i \ + buf_chi, \ + buf_zeta_r, \ + buf_zeta_i \ ); \ } @@ -327,14 +327,14 @@ void PASTEMAC0(opname) \ PASTEMAC(opname,_check)( chi, zeta_r, zeta_i ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt_chi ); \ \ f \ ( \ - buf_zeta_i, \ - buf_zeta_r, \ - buf_chi \ + buf_zeta_i, \ + buf_zeta_r, \ + buf_chi \ ); \ } diff --git a/frame/0/bli_l0_oapi.h b/frame/0/bli_l0_oapi.h index d0b05606f8..702bb40eaa 100644 --- a/frame/0/bli_l0_oapi.h +++ b/frame/0/bli_l0_oapi.h @@ -128,9 +128,3 @@ BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ GENPROT( zipsc ) - - - - - - diff --git a/frame/1/bli_l1v.h b/frame/1/bli_l1v.h index c64ed99126..99ceb3a3fe 100644 --- a/frame/1/bli_l1v.h +++ b/frame/1/bli_l1v.h @@ -41,18 +41,22 @@ // Prototype object APIs (expert and non-expert). #include "bli_oapi_ex.h" #include "bli_l1v_oapi.h" +#include "bli_xapi_undef.h" #include "bli_oapi_ba.h" #include "bli_l1v_oapi.h" +#include "bli_xapi_undef.h" // Prototype typed APIs (expert and non-expert). #include "bli_tapi_ex.h" #include "bli_l1v_tapi.h" #include "bli_l1v_ft.h" +#include "bli_xapi_undef.h" #include "bli_tapi_ba.h" #include "bli_l1v_tapi.h" #include "bli_l1v_ft.h" +#include "bli_xapi_undef.h" // Generate function pointer arrays for tapi functions (expert only). #include "bli_l1v_fpa.h" diff --git a/frame/1/bli_l1v_oapi.c b/frame/1/bli_l1v_oapi.c index 19e61bb7aa..201af2e091 100644 --- a/frame/1/bli_l1v_oapi.c +++ b/frame/1/bli_l1v_oapi.c @@ -67,7 +67,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, y ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -113,7 +113,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, index ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -174,7 +174,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -232,7 +232,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -282,7 +282,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, y, rho ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -349,7 +349,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -394,7 +394,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -445,7 +445,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -490,7 +490,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, y ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -545,7 +545,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ diff --git a/frame/1/other/packv/bli_packv_cntl.c b/frame/1/other/packv/bli_packv_cntl.c index 05f1472d08..ce4586f3f9 100644 --- a/frame/1/other/packv/bli_packv_cntl.c +++ b/frame/1/other/packv/bli_packv_cntl.c @@ -36,8 +36,8 @@ cntl_t* bli_packv_cntl_obj_create ( - void* var_func, - void* packv_var_func, + void_fp var_func, + void_fp packv_var_func, bszid_t bmid, pack_t pack_schema, cntl_t* sub_node diff --git a/frame/1/other/packv/bli_packv_cntl.h b/frame/1/other/packv/bli_packv_cntl.h index 87f33524b7..f1ba76a865 100644 --- a/frame/1/other/packv/bli_packv_cntl.h +++ b/frame/1/other/packv/bli_packv_cntl.h @@ -58,8 +58,8 @@ typedef struct packv_params_s packv_params_t; cntl_t* bli_packv_cntl_obj_create ( - void* var_func, - void* packv_var_func, + void_fp var_func, + void_fp packv_var_func, bszid_t bmid, pack_t pack_schema, cntl_t* sub_node diff --git a/frame/1/other/packv/bli_packv_init.c b/frame/1/other/packv/bli_packv_init.c index 31fbda27d5..ba424996f5 100644 --- a/frame/1/other/packv/bli_packv_init.c +++ b/frame/1/other/packv/bli_packv_init.c @@ -117,7 +117,7 @@ siz_t bli_packv_init_pack dim_t dim_a = bli_obj_vector_dim( a ); dim_t bmult = bli_cntx_get_blksz_def_dt( dt, bmult_id, cntx ); - membrk_t* membrk = bli_cntx_membrk( cntx ); + pba_t* pba = bli_cntx_pba( cntx ); #if 0 mem_t* mem_p; @@ -156,9 +156,7 @@ siz_t bli_packv_init_pack { // If the mem_t object of p has not yet been allocated, then acquire // a memory block suitable for a vector. - bli_membrk_acquire_v( membrk, - size_p, - mem_p ); + bli_pba_acquire_v( pba, size_p, mem_p ); } else { @@ -166,11 +164,9 @@ siz_t bli_packv_init_pack // re-acquire the memory so there is sufficient space. if ( bli_mem_size( mem_p ) < size_p ) { - bli_membrk_release( mem_p ); + bli_pba_release( mem_p ); - bli_membrk_acquire_v( membrk, - size_p, - mem_p ); + bli_pba_acquire_v( pba, size_p, mem_p ); } } diff --git a/frame/1d/bli_l1d.h b/frame/1d/bli_l1d.h index c0eeb133fe..aa42eeb44d 100644 --- a/frame/1d/bli_l1d.h +++ b/frame/1d/bli_l1d.h @@ -37,18 +37,22 @@ // Prototype object APIs (expert and non-expert). #include "bli_oapi_ex.h" #include "bli_l1d_oapi.h" +#include "bli_xapi_undef.h" #include "bli_oapi_ba.h" #include "bli_l1d_oapi.h" +#include "bli_xapi_undef.h" // Prototype typed APIs (expert and non-expert). #include "bli_tapi_ex.h" #include "bli_l1d_tapi.h" #include "bli_l1d_ft.h" +#include "bli_xapi_undef.h" #include "bli_tapi_ba.h" #include "bli_l1d_tapi.h" #include "bli_l1d_ft.h" +#include "bli_xapi_undef.h" // Generate function pointer arrays for tapi functions (expert only). #include "bli_l1d_fpa.h" diff --git a/frame/1d/bli_l1d_oapi.c b/frame/1d/bli_l1d_oapi.c index 1a8b8f124b..15e68cf50f 100644 --- a/frame/1d/bli_l1d_oapi.c +++ b/frame/1d/bli_l1d_oapi.c @@ -72,7 +72,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, y ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -138,7 +138,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -187,7 +187,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -243,7 +243,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -293,7 +293,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( alpha, x ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -349,7 +349,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -411,7 +411,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ diff --git a/frame/1f/bli_l1f.h b/frame/1f/bli_l1f.h index 370b3c9a7c..43676ec4ef 100644 --- a/frame/1f/bli_l1f.h +++ b/frame/1f/bli_l1f.h @@ -40,18 +40,22 @@ // Prototype object APIs (expert and non-expert). #include "bli_oapi_ex.h" #include "bli_l1f_oapi.h" +#include "bli_xapi_undef.h" #include "bli_oapi_ba.h" #include "bli_l1f_oapi.h" +#include "bli_xapi_undef.h" // Prototype typed APIs (expert and non-expert). #include "bli_tapi_ex.h" #include "bli_l1f_tapi.h" #include "bli_l1f_ft.h" +#include "bli_xapi_undef.h" #include "bli_tapi_ba.h" #include "bli_l1f_tapi.h" #include "bli_l1f_ft.h" +#include "bli_xapi_undef.h" // Generate function pointer arrays for tapi functions (expert only). #include "bli_l1f_fpa.h" diff --git a/frame/1f/bli_l1f_oapi.c b/frame/1f/bli_l1f_oapi.c index d1e7f0dbe4..db8fdfb68c 100644 --- a/frame/1f/bli_l1f_oapi.c +++ b/frame/1f/bli_l1f_oapi.c @@ -88,7 +88,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alphay = bli_obj_buffer_for_1x1( dt, &alphay_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -157,7 +157,7 @@ void PASTEMAC(opname,EX_SUF) \ if ( bli_obj_has_trans( a ) ) { bli_swap_incs( &rs_a, &cs_a ); } \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -225,7 +225,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -310,7 +310,7 @@ void PASTEMAC(opname,EX_SUF) \ if ( bli_obj_has_trans( a ) ) { bli_swap_incs( &rs_a, &cs_a ); } \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -390,7 +390,7 @@ void PASTEMAC(opname,EX_SUF) \ if ( bli_obj_has_trans( a ) ) { bli_swap_incs( &rs_a, &cs_a ); } \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ diff --git a/frame/1m/bli_l1m.h b/frame/1m/bli_l1m.h index 1e782cc682..925b9b376f 100644 --- a/frame/1m/bli_l1m.h +++ b/frame/1m/bli_l1m.h @@ -43,18 +43,22 @@ // Prototype object APIs (expert and non-expert). #include "bli_oapi_ex.h" #include "bli_l1m_oapi.h" +#include "bli_xapi_undef.h" #include "bli_oapi_ba.h" #include "bli_l1m_oapi.h" +#include "bli_xapi_undef.h" // Prototype typed APIs (expert and non-expert). #include "bli_tapi_ex.h" #include "bli_l1m_tapi.h" #include "bli_l1m_ft.h" +#include "bli_xapi_undef.h" #include "bli_tapi_ba.h" #include "bli_l1m_tapi.h" #include "bli_l1m_ft.h" +#include "bli_xapi_undef.h" // Generate function pointer arrays for tapi functions (expert only). #include "bli_l1m_fpa.h" diff --git a/frame/1m/bli_l1m_ft.h b/frame/1m/bli_l1m_ft.h index 152915df4c..af6c384e53 100644 --- a/frame/1m/bli_l1m_ft.h +++ b/frame/1m/bli_l1m_ft.h @@ -57,25 +57,6 @@ typedef void (*PASTECH3(ch,opname,EX_SUF,tsuf)) \ INSERT_GENTDEF( addm ) INSERT_GENTDEF( subm ) - -// copym - -#undef GENTDEF -#define GENTDEF( ctype, ch, opname, tsuf ) \ -\ -typedef void (*PASTECH3(ch,opname,EX_SUF,tsuf)) \ - ( \ - doff_t diagoffx, \ - diag_t diagx, \ - uplo_t uplox, \ - trans_t transx, \ - dim_t m, \ - dim_t n, \ - ctype* x, inc_t rs_x, inc_t cs_x, \ - ctype* y, inc_t rs_y, inc_t cs_y \ - BLIS_TAPI_EX_PARAMS \ - ); - INSERT_GENTDEF( copym ) // axpym diff --git a/frame/1m/bli_l1m_ft_ker.h b/frame/1m/bli_l1m_ft_ker.h index cf1c1088c3..2e813cf4a6 100644 --- a/frame/1m/bli_l1m_ft_ker.h +++ b/frame/1m/bli_l1m_ft_ker.h @@ -50,21 +50,23 @@ typedef void (*PASTECH3(ch,opname,_ker,tsuf)) \ ( \ struc_t strucc, \ - doff_t diagoffc, \ diag_t diagc, \ uplo_t uploc, \ conj_t conjc, \ pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ + bool invdiag, \ + dim_t panel_dim, \ + dim_t panel_len, \ + dim_t panel_dim_max, \ + dim_t panel_len_max, \ + dim_t panel_dim_off, \ + dim_t panel_len_off, \ ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + ctype* restrict c, inc_t incc, inc_t ldc, \ + ctype* restrict p, inc_t ldp, \ inc_t is_p, \ - cntx_t* cntx \ + cntx_t* cntx, \ + void* params \ ); INSERT_GENTDEF( packm ) @@ -72,11 +74,6 @@ INSERT_GENTDEF( packm ) // NOTE: the following macros generate packm kernel function type definitions // that are "ctyped" and void-typed, for each of the floating-point datatypes. -// However, we will only make use of the void-typed definitions because the -// functions such as bli_?packm_cxk() (currently) use arrays of function -// pointers to store and access the function pointers for various unrolling -// (register blocksize) values, and therefore they must all be of the same -// type (hence the use of void* for kappa, a, and p). // packm_ker @@ -86,6 +83,7 @@ INSERT_GENTDEF( packm ) typedef void (*PASTECH3(ch,opname,_ker,tsuf)) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ @@ -114,28 +112,6 @@ typedef void (*PASTECH3(ch,opname,_ker,tsuf)) \ INSERT_GENTDEF( unpackm_cxk ) -// packm_3mis_ker -// packm_4mi_ker - -#undef GENTDEF -#define GENTDEF( ctype, ch, opname, tsuf ) \ -\ -typedef void (*PASTECH3(ch,opname,_ker,tsuf)) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ); - -INSERT_GENTDEF( packm_cxk_3mis ) -INSERT_GENTDEF( packm_cxk_4mi ) - -// packm_rih_ker // packm_1er_ker #undef GENTDEF @@ -154,12 +130,8 @@ typedef void (*PASTECH3(ch,opname,_ker,tsuf)) \ cntx_t* restrict cntx \ ); -INSERT_GENTDEF( packm_cxk_rih ) INSERT_GENTDEF( packm_cxk_1er ) - - - #endif diff --git a/frame/1m/bli_l1m_ker.h b/frame/1m/bli_l1m_ker.h index f79a292d33..76d51af2b0 100644 --- a/frame/1m/bli_l1m_ker.h +++ b/frame/1m/bli_l1m_ker.h @@ -74,51 +74,6 @@ INSERT_GENTPROT_BASIC0( unpackm_14xk_ker_name ) INSERT_GENTPROT_BASIC0( unpackm_16xk_ker_name ) -// 3mis packm kernels - -#undef GENTPROT -#define GENTPROT PACKM_3MIS_KER_PROT - -INSERT_GENTPROT_BASIC0( packm_2xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_4xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_6xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_8xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_10xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_12xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_14xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_16xk_3mis_ker_name ) - - -// 4mi packm kernels - -#undef GENTPROT -#define GENTPROT PACKM_4MI_KER_PROT - -INSERT_GENTPROT_BASIC0( packm_2xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_4xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_6xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_8xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_10xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_12xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_14xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_16xk_4mi_ker_name ) - - -// rih packm kernels - -#undef GENTPROT -#define GENTPROT PACKM_RIH_KER_PROT - -INSERT_GENTPROT_BASIC0( packm_2xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_4xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_6xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_8xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_10xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_12xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_14xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_16xk_rih_ker_name ) - - // 1e/1r packm kernels #undef GENTPROT diff --git a/frame/1m/bli_l1m_ker_prot.h b/frame/1m/bli_l1m_ker_prot.h index ada520b566..02d3296220 100644 --- a/frame/1m/bli_l1m_ker_prot.h +++ b/frame/1m/bli_l1m_ker_prot.h @@ -44,12 +44,13 @@ void PASTEMAC(ch,varname) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ); @@ -62,61 +63,9 @@ void PASTEMAC(ch,varname) \ ( \ conj_t conja, \ dim_t n, \ - void* restrict kappa, \ - void* restrict p, inc_t ldp, \ - void* restrict a, inc_t inca, inc_t lda, \ - cntx_t* restrict cntx \ - ); - - -// 3mis packm kernels - -#define PACKM_3MIS_KER_PROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ); - - -// 4mi packm kernels - -#define PACKM_4MI_KER_PROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ); - - -// rih packm kernels - -#define PACKM_RIH_KER_PROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict p, inc_t ldp, \ + ctype* restrict a, inc_t inca, inc_t lda, \ cntx_t* restrict cntx \ ); @@ -132,9 +81,9 @@ void PASTEMAC(ch,varname) \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ); diff --git a/frame/1m/bli_l1m_oapi.c b/frame/1m/bli_l1m_oapi.c index 4bb0de7849..840b058d4a 100644 --- a/frame/1m/bli_l1m_oapi.c +++ b/frame/1m/bli_l1m_oapi.c @@ -73,22 +73,22 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, y ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ f \ - ( \ - diagoffx, \ - diagx, \ - uplox, \ - transx, \ - m, \ - n, \ - buf_x, rs_x, cs_x, \ - buf_y, rs_y, cs_y, \ - cntx, \ - rntm \ + ( \ + diagoffx, \ + diagx, \ + uplox, \ + transx, \ + m, \ + n, \ + buf_x, rs_x, cs_x, \ + buf_y, rs_y, cs_y, \ + cntx, \ + rntm \ ); \ } @@ -141,23 +141,23 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ f \ - ( \ - diagoffx, \ - diagx, \ - uplox, \ - transx, \ - m, \ - n, \ - buf_alpha, \ - buf_x, rs_x, cs_x, \ - buf_y, rs_y, cs_y, \ - cntx, \ - rntm \ + ( \ + diagoffx, \ + diagx, \ + uplox, \ + transx, \ + m, \ + n, \ + buf_alpha, \ + buf_x, rs_x, cs_x, \ + buf_y, rs_y, cs_y, \ + cntx, \ + rntm \ ); \ } @@ -218,22 +218,22 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_internal_scalar_buffer( &x_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ f \ - ( \ - BLIS_NO_CONJUGATE, /* internal conjugation applied during copy-cast. */ \ - diagoffx, \ - diagx, \ - uplox, \ - m, \ - n, \ - buf_alpha, \ - buf_x, rs_x, cs_x, \ - cntx, \ - rntm \ + ( \ + BLIS_NO_CONJUGATE, /* internal conjugation applied during copy-cast. */ \ + diagoffx, \ + diagx, \ + uplox, \ + m, \ + n, \ + buf_alpha, \ + buf_x, rs_x, cs_x, \ + cntx, \ + rntm \ ); \ } @@ -280,22 +280,22 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ f \ - ( \ - BLIS_NO_CONJUGATE, /* internal conjugation applied during copy-cast. */ \ - diagoffx, \ - diagx, \ - uplox, \ - m, \ - n, \ - buf_alpha, \ - buf_x, rs_x, cs_x, \ - cntx, \ - rntm \ + ( \ + BLIS_NO_CONJUGATE, /* internal conjugation applied during copy-cast. */ \ + diagoffx, \ + diagx, \ + uplox, \ + m, \ + n, \ + buf_alpha, \ + buf_x, rs_x, cs_x, \ + cntx, \ + rntm \ ); \ } @@ -349,23 +349,23 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ f \ - ( \ - diagoffx, \ - diagx, \ - uplox, \ - transx, \ - m, \ - n, \ - buf_x, rs_x, cs_x, \ - buf_beta, \ - buf_y, rs_y, cs_y, \ - cntx, \ - rntm \ + ( \ + diagoffx, \ + diagx, \ + uplox, \ + transx, \ + m, \ + n, \ + buf_x, rs_x, cs_x, \ + buf_beta, \ + buf_y, rs_y, cs_y, \ + cntx, \ + rntm \ ); \ } @@ -414,23 +414,23 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dty, &beta_local ); \ \ /* Query a (multi) type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp2)( dtx, dty ); \ \ f \ ( \ - diagoffx, \ - diagx, \ - uplox, \ - transx, \ - m, \ - n, \ - buf_x, rs_x, cs_x, \ - buf_beta, \ - buf_y, rs_y, cs_y, \ - cntx, \ - rntm \ + diagoffx, \ + diagx, \ + uplox, \ + transx, \ + m, \ + n, \ + buf_x, rs_x, cs_x, \ + buf_beta, \ + buf_y, rs_y, cs_y, \ + cntx, \ + rntm \ ); \ } diff --git a/frame/1m/bli_l1m_oft_var.h b/frame/1m/bli_l1m_oft_var.h index 15e9dae6f5..0b60d4e2f6 100644 --- a/frame/1m/bli_l1m_oft_var.h +++ b/frame/1m/bli_l1m_oft_var.h @@ -48,6 +48,7 @@ typedef void (*PASTECH(opname,_var_oft)) \ obj_t* a, \ obj_t* p, \ cntx_t* cntx, \ + rntm_t* rntm, \ cntl_t* cntl, \ thrinfo_t* thread \ ); diff --git a/frame/1m/bli_l1m_unb_var1.c b/frame/1m/bli_l1m_unb_var1.c index cb6098e3f0..f2ce3c8d7e 100644 --- a/frame/1m/bli_l1m_unb_var1.c +++ b/frame/1m/bli_l1m_unb_var1.c @@ -57,15 +57,12 @@ void PASTEMAC(ch,opname) \ { \ const num_t dt = PASTEMAC(ch,type); \ \ - ctype* x1; \ - ctype* y1; \ uplo_t uplox_eff; \ conj_t conjx; \ dim_t n_iter; \ - dim_t n_elem, n_elem_max; \ + dim_t n_elem_max; \ inc_t ldx, incx; \ inc_t ldy, incy; \ - dim_t j, i; \ dim_t ij0, n_shift; \ \ /* Set various loop parameters. */ \ @@ -88,62 +85,65 @@ void PASTEMAC(ch,opname) \ /* Handle dense and upper/lower storage cases separately. */ \ if ( bli_is_dense( uplox_eff ) ) \ { \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - n_elem = n_elem_max; \ + const dim_t n_elem = n_elem_max; \ \ - x1 = x + (j )*ldx + (0 )*incx; \ - y1 = y + (j )*ldy + (0 )*incy; \ + ctype* x1 = x + (j )*ldx + (0 )*incx; \ + ctype* y1 = y + (j )*ldy + (0 )*incy; \ \ /* Invoke the kernel with the appropriate parameters. */ \ - f( \ - conjx, \ - n_elem, \ - x1, incx, \ - y1, incy, \ - cntx \ - ); \ + f \ + ( \ + conjx, \ + n_elem, \ + x1, incx, \ + y1, incy, \ + cntx \ + ); \ } \ } \ else \ { \ if ( bli_is_upper( uplox_eff ) ) \ { \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ + const dim_t n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ \ - x1 = x + (ij0+j )*ldx + (0 )*incx; \ - y1 = y + (ij0+j )*ldy + (0 )*incy; \ + ctype* x1 = x + (ij0+j )*ldx + (0 )*incx; \ + ctype* y1 = y + (ij0+j )*ldy + (0 )*incy; \ \ /* Invoke the kernel with the appropriate parameters. */ \ - f( \ - conjx, \ - n_elem, \ - x1, incx, \ - y1, incy, \ - cntx \ - ); \ + f \ + ( \ + conjx, \ + n_elem, \ + x1, incx, \ + y1, incy, \ + cntx \ + ); \ } \ } \ else if ( bli_is_lower( uplox_eff ) ) \ { \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - i = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ - n_elem = n_elem_max - i; \ + const dim_t offi = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ + const dim_t n_elem = n_elem_max - offi; \ \ - x1 = x + (j )*ldx + (ij0+i )*incx; \ - y1 = y + (j )*ldy + (ij0+i )*incy; \ + ctype* x1 = x + (j )*ldx + (ij0+offi )*incx; \ + ctype* y1 = y + (j )*ldy + (ij0+offi )*incy; \ \ /* Invoke the kernel with the appropriate parameters. */ \ - f( \ - conjx, \ - n_elem, \ - x1, incx, \ - y1, incy, \ - cntx \ - ); \ + f \ + ( \ + conjx, \ + n_elem, \ + x1, incx, \ + y1, incy, \ + cntx \ + ); \ } \ } \ } \ @@ -174,15 +174,12 @@ void PASTEMAC(ch,opname) \ { \ const num_t dt = PASTEMAC(ch,type); \ \ - ctype* x1; \ - ctype* y1; \ uplo_t uplox_eff; \ conj_t conjx; \ dim_t n_iter; \ - dim_t n_elem, n_elem_max; \ + dim_t n_elem_max; \ inc_t ldx, incx; \ inc_t ldy, incy; \ - dim_t j, i; \ dim_t ij0, n_shift; \ \ /* Set various loop parameters. */ \ @@ -205,65 +202,68 @@ void PASTEMAC(ch,opname) \ /* Handle dense and upper/lower storage cases separately. */ \ if ( bli_is_dense( uplox_eff ) ) \ { \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - n_elem = n_elem_max; \ + const dim_t n_elem = n_elem_max; \ \ - x1 = x + (j )*ldx + (0 )*incx; \ - y1 = y + (j )*ldy + (0 )*incy; \ + ctype* x1 = x + (j )*ldx + (0 )*incx; \ + ctype* y1 = y + (j )*ldy + (0 )*incy; \ \ /* Invoke the kernel with the appropriate parameters. */ \ - f( \ - conjx, \ - n_elem, \ - alpha, \ - x1, incx, \ - y1, incy, \ - cntx \ - ); \ + f \ + ( \ + conjx, \ + n_elem, \ + alpha, \ + x1, incx, \ + y1, incy, \ + cntx \ + ); \ } \ } \ else \ { \ if ( bli_is_upper( uplox_eff ) ) \ { \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ + const dim_t n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ \ - x1 = x + (ij0+j )*ldx + (0 )*incx; \ - y1 = y + (ij0+j )*ldy + (0 )*incy; \ + ctype* x1 = x + (ij0+j )*ldx + (0 )*incx; \ + ctype* y1 = y + (ij0+j )*ldy + (0 )*incy; \ \ /* Invoke the kernel with the appropriate parameters. */ \ - f( \ - conjx, \ - n_elem, \ - alpha, \ - x1, incx, \ - y1, incy, \ - cntx \ - ); \ + f \ + ( \ + conjx, \ + n_elem, \ + alpha, \ + x1, incx, \ + y1, incy, \ + cntx \ + ); \ } \ } \ else if ( bli_is_lower( uplox_eff ) ) \ { \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - i = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ - n_elem = n_elem_max - i; \ + const dim_t offi = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ + const dim_t n_elem = n_elem_max - offi; \ \ - x1 = x + (j )*ldx + (ij0+i )*incx; \ - y1 = y + (j )*ldy + (ij0+i )*incy; \ + ctype* x1 = x + (j )*ldx + (ij0+offi )*incx; \ + ctype* y1 = y + (j )*ldy + (ij0+offi )*incy; \ \ /* Invoke the kernel with the appropriate parameters. */ \ - f( \ - conjx, \ - n_elem, \ - alpha, \ - x1, incx, \ - y1, incy, \ - cntx \ - ); \ + f \ + ( \ + conjx, \ + n_elem, \ + alpha, \ + x1, incx, \ + y1, incy, \ + cntx \ + ); \ } \ } \ } \ @@ -292,12 +292,10 @@ void PASTEMAC(ch,opname) \ { \ const num_t dt = PASTEMAC(ch,type); \ \ - ctype* x1; \ uplo_t uplox_eff; \ dim_t n_iter; \ - dim_t n_elem, n_elem_max; \ + dim_t n_elem_max; \ inc_t ldx, incx; \ - dim_t j, i; \ dim_t ij0, n_shift; \ \ /* Set various loop parameters. */ \ @@ -317,59 +315,62 @@ void PASTEMAC(ch,opname) \ /* Handle dense and upper/lower storage cases separately. */ \ if ( bli_is_dense( uplox_eff ) ) \ { \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - n_elem = n_elem_max; \ + const dim_t n_elem = n_elem_max; \ \ - x1 = x + (j )*ldx + (0 )*incx; \ + ctype* x1 = x + (j )*ldx + (0 )*incx; \ \ /* Invoke the kernel with the appropriate parameters. */ \ - f( \ - conjalpha, \ - n_elem, \ - alpha, \ - x1, incx, \ - cntx \ - ); \ + f \ + ( \ + conjalpha, \ + n_elem, \ + alpha, \ + x1, incx, \ + cntx \ + ); \ } \ } \ else \ { \ if ( bli_is_upper( uplox_eff ) ) \ { \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ + const dim_t n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ \ - x1 = x + (ij0+j )*ldx + (0 )*incx; \ + ctype* x1 = x + (ij0+j )*ldx + (0 )*incx; \ \ /* Invoke the kernel with the appropriate parameters. */ \ - f( \ - conjalpha, \ - n_elem, \ - alpha, \ - x1, incx, \ - cntx \ - ); \ + f \ + ( \ + conjalpha, \ + n_elem, \ + alpha, \ + x1, incx, \ + cntx \ + ); \ } \ } \ else if ( bli_is_lower( uplox_eff ) ) \ { \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - i = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ - n_elem = n_elem_max - i; \ + const dim_t offi = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ + const dim_t n_elem = n_elem_max - offi; \ \ - x1 = x + (j )*ldx + (ij0+i )*incx; \ + ctype* x1 = x + (j )*ldx + (ij0+offi )*incx; \ \ /* Invoke the kernel with the appropriate parameters. */ \ - f( \ - conjalpha, \ - n_elem, \ - alpha, \ - x1, incx, \ - cntx \ - ); \ + f \ + ( \ + conjalpha, \ + n_elem, \ + alpha, \ + x1, incx, \ + cntx \ + ); \ } \ } \ } \ @@ -399,15 +400,12 @@ void PASTEMAC(ch,opname) \ { \ const num_t dt = PASTEMAC(ch,type); \ \ - ctype* x1; \ - ctype* y1; \ uplo_t uplox_eff; \ conj_t conjx; \ dim_t n_iter; \ - dim_t n_elem, n_elem_max; \ + dim_t n_elem_max; \ inc_t ldx, incx; \ inc_t ldy, incy; \ - dim_t j, i; \ dim_t ij0, n_shift; \ \ /* Set various loop parameters. */ \ @@ -430,65 +428,68 @@ void PASTEMAC(ch,opname) \ /* Handle dense and upper/lower storage cases separately. */ \ if ( bli_is_dense( uplox_eff ) ) \ { \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - n_elem = n_elem_max; \ + const dim_t n_elem = n_elem_max; \ \ - x1 = x + (j )*ldx + (0 )*incx; \ - y1 = y + (j )*ldy + (0 )*incy; \ + ctype* x1 = x + (j )*ldx + (0 )*incx; \ + ctype* y1 = y + (j )*ldy + (0 )*incy; \ \ /* Invoke the kernel with the appropriate parameters. */ \ - f( \ - conjx, \ - n_elem, \ - x1, incx, \ - beta, \ - y1, incy, \ - cntx \ - ); \ + f \ + ( \ + conjx, \ + n_elem, \ + x1, incx, \ + beta, \ + y1, incy, \ + cntx \ + ); \ } \ } \ else \ { \ if ( bli_is_upper( uplox_eff ) ) \ { \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ + const dim_t n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ \ - x1 = x + (ij0+j )*ldx + (0 )*incx; \ - y1 = y + (ij0+j )*ldy + (0 )*incy; \ + ctype* x1 = x + (ij0+j )*ldx + (0 )*incx; \ + ctype* y1 = y + (ij0+j )*ldy + (0 )*incy; \ \ /* Invoke the kernel with the appropriate parameters. */ \ - f( \ - conjx, \ - n_elem, \ - x1, incx, \ - beta, \ - y1, incy, \ - cntx \ - ); \ + f \ + ( \ + conjx, \ + n_elem, \ + x1, incx, \ + beta, \ + y1, incy, \ + cntx \ + ); \ } \ } \ else if ( bli_is_lower( uplox_eff ) ) \ { \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - i = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ - n_elem = n_elem_max - i; \ + const dim_t offi = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ + const dim_t n_elem = n_elem_max - offi; \ \ - x1 = x + (j )*ldx + (ij0+i )*incx; \ - y1 = y + (j )*ldy + (ij0+i )*incy; \ + ctype* x1 = x + (j )*ldx + (ij0+offi )*incx; \ + ctype* y1 = y + (j )*ldy + (ij0+offi )*incy; \ \ /* Invoke the kernel with the appropriate parameters. */ \ - f( \ - conjx, \ - n_elem, \ - x1, incx, \ - beta, \ - y1, incy, \ - cntx \ - ); \ + f \ + ( \ + conjx, \ + n_elem, \ + x1, incx, \ + beta, \ + y1, incy, \ + cntx \ + ); \ } \ } \ } \ @@ -515,15 +516,12 @@ void PASTEMAC2(chx,chy,opname) \ rntm_t* rntm \ ) \ { \ - ctype_x* restrict x1; \ - ctype_y* restrict y1; \ - uplo_t uplox_eff; \ - dim_t n_iter; \ - dim_t n_elem, n_elem_max; \ - inc_t ldx, incx; \ - inc_t ldy, incy; \ - dim_t j, i; \ - dim_t ij0, n_shift; \ + uplo_t uplox_eff; \ + dim_t n_iter; \ + dim_t n_elem_max; \ + inc_t ldx, incx; \ + inc_t ldy, incy; \ + dim_t ij0, n_shift; \ \ /* Set various loop parameters. */ \ bli_set_dims_incs_uplo_2m \ @@ -542,35 +540,32 @@ void PASTEMAC2(chx,chy,opname) \ { \ if ( incx == 1 && incy == 1 ) \ { \ - n_elem = n_elem_max; \ + const dim_t n_elem = n_elem_max; \ \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - x1 = x + (j )*ldx + (0 )*incx; \ - y1 = y + (j )*ldy + (0 )*incy; \ + ctype_x* restrict x1 = x + (j )*ldx + (0 )*incx; \ + ctype_y* restrict y1 = y + (j )*ldy + (0 )*incy; \ \ - ctype_x* restrict chi1 = x1; \ - ctype_y* restrict psi1 = y1; \ -\ - for ( i = 0; i < n_elem; ++i ) \ + for ( dim_t i = 0; i < n_elem; ++i ) \ { \ - PASTEMAC2(chx,chy,adds)( chi1[i], psi1[i] ); \ + PASTEMAC2(chx,chy,adds)( x1[i], y1[i] ); \ } \ } \ } \ else \ { \ - n_elem = n_elem_max; \ + const dim_t n_elem = n_elem_max; \ \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - x1 = x + (j )*ldx + (0 )*incx; \ - y1 = y + (j )*ldy + (0 )*incy; \ + ctype_x* restrict x1 = x + (j )*ldx + (0 )*incx; \ + ctype_y* restrict y1 = y + (j )*ldy + (0 )*incy; \ \ ctype_x* restrict chi1 = x1; \ ctype_y* restrict psi1 = y1; \ \ - for ( i = 0; i < n_elem; ++i ) \ + for ( dim_t i = 0; i < n_elem; ++i ) \ { \ PASTEMAC2(chx,chy,adds)( *chi1, *psi1 ); \ \ @@ -584,35 +579,32 @@ void PASTEMAC2(chx,chy,opname) \ { \ if ( incx == 1 && incy == 1 ) \ { \ - n_elem = n_elem_max; \ + const dim_t n_elem = n_elem_max; \ \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - x1 = x + (j )*ldx + (0 )*incx; \ - y1 = y + (j )*ldy + (0 )*incy; \ -\ - ctype_x* restrict chi1 = x1; \ - ctype_y* restrict psi1 = y1; \ + ctype_x* restrict x1 = x + (j )*ldx + (0 )*incx; \ + ctype_y* restrict y1 = y + (j )*ldy + (0 )*incy; \ \ - for ( i = 0; i < n_elem; ++i ) \ + for ( dim_t i = 0; i < n_elem; ++i ) \ { \ - PASTEMAC3(chx,chy,chy,xpbys)( chi1[i], *beta, psi1[i] ); \ + PASTEMAC3(chx,chy,chy,xpbys)( x1[i], *beta, y1[i] ); \ } \ } \ } \ else \ { \ - n_elem = n_elem_max; \ + const dim_t n_elem = n_elem_max; \ \ - for ( j = 0; j < n_iter; ++j ) \ + for ( dim_t j = 0; j < n_iter; ++j ) \ { \ - x1 = x + (j )*ldx + (0 )*incx; \ - y1 = y + (j )*ldy + (0 )*incy; \ + ctype_x* restrict x1 = x + (j )*ldx + (0 )*incx; \ + ctype_y* restrict y1 = y + (j )*ldy + (0 )*incy; \ \ ctype_x* restrict chi1 = x1; \ ctype_y* restrict psi1 = y1; \ \ - for ( i = 0; i < n_elem; ++i ) \ + for ( dim_t i = 0; i < n_elem; ++i ) \ { \ PASTEMAC3(chx,chy,chy,xpbys)( *chi1, *beta, *psi1 ); \ \ diff --git a/frame/1m/bli_l1m_unb_var1.h b/frame/1m/bli_l1m_unb_var1.h index 81be9fe808..0364d4b7cd 100644 --- a/frame/1m/bli_l1m_unb_var1.h +++ b/frame/1m/bli_l1m_unb_var1.h @@ -40,7 +40,7 @@ #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,_unb_var1) \ +void PASTEMAC2(ch,opname,_unb_var1) \ ( \ doff_t diagoffx, \ diag_t diagx, \ @@ -62,7 +62,7 @@ INSERT_GENTPROT_BASIC0( subm ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,_unb_var1) \ +void PASTEMAC2(ch,opname,_unb_var1) \ ( \ doff_t diagoffx, \ diag_t diagx, \ @@ -84,7 +84,7 @@ INSERT_GENTPROT_BASIC0( scal2m ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,_unb_var1) \ +void PASTEMAC2(ch,opname,_unb_var1) \ ( \ conj_t conjalpha, \ doff_t diagoffx, \ @@ -105,7 +105,7 @@ INSERT_GENTPROT_BASIC0( setm ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,_unb_var1) \ +void PASTEMAC2(ch,opname,_unb_var1) \ ( \ doff_t diagoffx, \ diag_t diagx, \ @@ -126,7 +126,7 @@ INSERT_GENTPROT_BASIC0( xpbym ) #undef GENTPROT2 #define GENTPROT2( ctype_x, ctype_y, chx, chy, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC3(chx,chy,opname,_unb_var1) \ +void PASTEMAC3(chx,chy,opname,_unb_var1) \ ( \ doff_t diagoffx, \ diag_t diagx, \ diff --git a/frame/1m/other/bli_scalm_cntl.c b/frame/1m/other/bli_scalm_cntl.c index afff3c1646..b815dc530d 100644 --- a/frame/1m/other/bli_scalm_cntl.c +++ b/frame/1m/other/bli_scalm_cntl.c @@ -36,7 +36,7 @@ cntl_t* bli_scalm_cntl_create_node ( - void* var_func, + void_fp var_func, cntl_t* sub_node ) { diff --git a/frame/1m/other/bli_scalm_cntl.h b/frame/1m/other/bli_scalm_cntl.h index 0d589f2073..32a02f5dab 100644 --- a/frame/1m/other/bli_scalm_cntl.h +++ b/frame/1m/other/bli_scalm_cntl.h @@ -35,6 +35,6 @@ cntl_t* bli_scalm_cntl_create_node ( - void* var_func, + void_fp var_func, cntl_t* sub_node ); diff --git a/frame/1m/packm/bli_packm.h b/frame/1m/packm/bli_packm.h index fbc02e3926..88657a7128 100644 --- a/frame/1m/packm/bli_packm.h +++ b/frame/1m/packm/bli_packm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,29 +33,25 @@ */ +#include "bli_packm_alloc.h" #include "bli_packm_cntl.h" #include "bli_packm_check.h" #include "bli_packm_init.h" #include "bli_packm_int.h" +#include "bli_packm_scalar.h" #include "bli_packm_part.h" -#include "bli_packm_var.h" - #include "bli_packm_struc_cxk.h" -#include "bli_packm_struc_cxk_4mi.h" -#include "bli_packm_struc_cxk_3mis.h" -#include "bli_packm_struc_cxk_rih.h" #include "bli_packm_struc_cxk_1er.h" #include "bli_packm_cxk.h" -#include "bli_packm_cxk_4mi.h" -#include "bli_packm_cxk_3mis.h" -#include "bli_packm_cxk_rih.h" #include "bli_packm_cxk_1er.h" // Mixed datatype support. #ifdef BLIS_ENABLE_GEMM_MD -#include "bli_packm_md.h" +#include "bli_packm_struc_cxk_md.h" #endif +#include "bli_packm_blk_var1.h" + diff --git a/frame/1m/packm/bli_packm_alloc.c b/frame/1m/packm/bli_packm_alloc.c new file mode 100644 index 0000000000..b12a93ddc0 --- /dev/null +++ b/frame/1m/packm/bli_packm_alloc.c @@ -0,0 +1,119 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2016, Hewlett Packard Enterprise Development LP + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void* bli_packm_alloc + ( + siz_t size_needed, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + // Query the pack buffer type from the control tree node. + packbuf_t pack_buf_type = bli_cntl_packm_params_pack_buf_type( cntl ); + + return bli_packm_alloc_ex + ( + size_needed, + pack_buf_type, + rntm, + cntl, + thread + ); +} + +void* bli_packm_alloc_ex + ( + siz_t size_needed, + packbuf_t pack_buf_type, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + // Query the address of the mem_t entry within the control tree node. + mem_t* cntl_mem_p = bli_cntl_pack_mem( cntl ); + + mem_t* local_mem_p; + mem_t local_mem_s; + + siz_t cntl_mem_size = 0; + + if ( bli_mem_is_alloc( cntl_mem_p ) ) + cntl_mem_size = bli_mem_size( cntl_mem_p ); + + if ( cntl_mem_size < size_needed ) + { + if ( bli_thread_am_ochief( thread ) ) + { + // The chief thread releases the existing block associated with + // the mem_t entry in the control tree, and then re-acquires a + // new block, saving the associated mem_t entry to local_mem_s. + if ( bli_mem_is_alloc( cntl_mem_p ) ) + { + bli_pba_release + ( + rntm, + cntl_mem_p + ); + } + + bli_pba_acquire_m + ( + rntm, + size_needed, + pack_buf_type, + &local_mem_s + ); + } + + // Broadcast the address of the chief thread's local mem_t entry to + // all threads. + local_mem_p = bli_thread_broadcast( thread, &local_mem_s ); + + // Save the chief thread's local mem_t entry to the mem_t field in + // this thread's control tree node. + *cntl_mem_p = *local_mem_p; + + // Barrier so that the master thread doesn't return from the function + // before we are done reading. + bli_thread_barrier( thread ); + } + + return bli_mem_buffer( cntl_mem_p ); +} + diff --git a/frame/3/herk/bli_herk_front.h b/frame/1m/packm/bli_packm_alloc.h similarity index 83% rename from frame/3/herk/bli_herk_front.h rename to frame/1m/packm/bli_packm_alloc.h index 44778a450a..5a5cf126b1 100644 --- a/frame/3/herk/bli_herk_front.h +++ b/frame/1m/packm/bli_packm_alloc.h @@ -32,13 +32,20 @@ */ -void bli_herk_front +BLIS_EXPORT_BLIS void* bli_packm_alloc ( - obj_t* alpha, - obj_t* a, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl + siz_t size_needed, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread ); + +BLIS_EXPORT_BLIS void* bli_packm_alloc_ex + ( + siz_t size_needed, + packbuf_t pack_buf_type, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ); + diff --git a/frame/1m/packm/bli_packm_blk_var1.c b/frame/1m/packm/bli_packm_blk_var1.c index 3f753a914d..edeeae2b98 100644 --- a/frame/1m/packm/bli_packm_blk_var1.c +++ b/frame/1m/packm/bli_packm_blk_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,35 +35,6 @@ #include "blis.h" -#define FUNCPTR_T packm_fp - -typedef void (*FUNCPTR_T) - ( - struc_t strucc, - doff_t diagoffc, - diag_t diagc, - uplo_t uploc, - trans_t transc, - pack_t schema, - bool_t invdiag, - bool_t revifup, - bool_t reviflo, - dim_t m, - dim_t n, - dim_t m_max, - dim_t n_max, - void* kappa, - void* c, inc_t rs_c, inc_t cs_c, - void* p, inc_t rs_p, inc_t cs_p, - inc_t is_p, - dim_t pd_p, inc_t ps_p, - void* packm_ker, - cntx_t* cntx, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY(ftypes,packm_blk_var1); - static func_t packm_struc_cxk_kers[BLIS_NUM_PACK_SCHEMA_TYPES] = { @@ -71,689 +42,273 @@ static func_t packm_struc_cxk_kers[BLIS_NUM_PACK_SCHEMA_TYPES] = // 0000 row/col panels { { bli_spackm_struc_cxk, bli_cpackm_struc_cxk, bli_dpackm_struc_cxk, bli_zpackm_struc_cxk, } }, -// 0001 row/col panels: 4m interleaved - { { NULL, bli_cpackm_struc_cxk_4mi, - NULL, bli_zpackm_struc_cxk_4mi, } }, -// 0010 row/col panels: 3m interleaved - { { NULL, bli_cpackm_struc_cxk_3mis, - NULL, bli_zpackm_struc_cxk_3mis, } }, -// 0011 row/col panels: 4m separated (NOT IMPLEMENTED) - { { NULL, NULL, - NULL, NULL, } }, -// 0100 row/col panels: 3m separated - { { NULL, bli_cpackm_struc_cxk_3mis, - NULL, bli_zpackm_struc_cxk_3mis, } }, -// 0101 row/col panels: real only - { { NULL, bli_cpackm_struc_cxk_rih, - NULL, bli_zpackm_struc_cxk_rih, } }, -// 0110 row/col panels: imaginary only - { { NULL, bli_cpackm_struc_cxk_rih, - NULL, bli_zpackm_struc_cxk_rih, } }, -// 0111 row/col panels: real+imaginary only - { { NULL, bli_cpackm_struc_cxk_rih, - NULL, bli_zpackm_struc_cxk_rih, } }, -// 1000 row/col panels: 1m-expanded (1e) +// 0001 row/col panels: 1m-expanded (1e) { { NULL, bli_cpackm_struc_cxk_1er, NULL, bli_zpackm_struc_cxk_1er, } }, -// 1001 row/col panels: 1m-reordered (1r) +// 0010 row/col panels: 1m-reordered (1r) { { NULL, bli_cpackm_struc_cxk_1er, NULL, bli_zpackm_struc_cxk_1er, } }, }; +static void_fp GENARRAY2_ALL(packm_struc_cxk_md,packm_struc_cxk_md); void bli_packm_blk_var1 ( obj_t* c, obj_t* p, cntx_t* cntx, + rntm_t* rntm, cntl_t* cntl, - thrinfo_t* t + thrinfo_t* thread ) { -#ifdef BLIS_ENABLE_GEMM_MD - // Call a different packm implementation when the storage and target - // datatypes differ. - if ( bli_obj_dt( c ) != bli_obj_target_dt( c ) ) - { - bli_packm_blk_var1_md( c, p, cntx, cntl, t ); + // Extract various fields from the control tree. + pack_t schema = bli_cntl_packm_params_pack_schema( cntl ); + bool invdiag = bli_cntl_packm_params_does_invert_diag( cntl ); + bool revifup = bli_cntl_packm_params_rev_iter_if_upper( cntl ); + bool reviflo = bli_cntl_packm_params_rev_iter_if_lower( cntl ); + + // Every thread initializes p and determines the size of memory + // block needed (which gets embedded into the otherwise "blank" mem_t + // entry in the control tree node). Return early if no packing is required. + if ( !bli_packm_init( c, p, cntx, rntm, cntl, thread ) ) return; - } -#endif - num_t dt_p = bli_obj_dt( p ); + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_packm_int_check( c, p, cntx ); - struc_t strucc = bli_obj_struc( c ); - doff_t diagoffc = bli_obj_diag_offset( c ); - diag_t diagc = bli_obj_diag( c ); - uplo_t uploc = bli_obj_uplo( c ); - trans_t transc = bli_obj_conjtrans_status( c ); - pack_t schema = bli_obj_pack_schema( p ); - bool_t invdiag = bli_obj_has_inverted_diag( p ); - bool_t revifup = bli_obj_is_pack_rev_if_upper( p ); - bool_t reviflo = bli_obj_is_pack_rev_if_lower( p ); + num_t dt_c = bli_obj_dt( c ); + dim_t dt_c_size = bli_dt_size( dt_c ); - dim_t m_p = bli_obj_length( p ); - dim_t n_p = bli_obj_width( p ); - dim_t m_max_p = bli_obj_padded_length( p ); - dim_t n_max_p = bli_obj_padded_width( p ); + num_t dt_p = bli_obj_dt( p ); + dim_t dt_p_size = bli_dt_size( dt_p ); - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); + struc_t strucc = bli_obj_struc( c ); + doff_t diagoffc = bli_obj_diag_offset( c ); + diag_t diagc = bli_obj_diag( c ); + uplo_t uploc = bli_obj_uplo( c ); + conj_t conjc = bli_obj_conj_status( c ); - void* buf_p = bli_obj_buffer_at_off( p ); - inc_t rs_p = bli_obj_row_stride( p ); - inc_t cs_p = bli_obj_col_stride( p ); - inc_t is_p = bli_obj_imag_stride( p ); - dim_t pd_p = bli_obj_panel_dim( p ); - inc_t ps_p = bli_obj_panel_stride( p ); + dim_t iter_dim = bli_obj_length( p ); + dim_t panel_len_full = bli_obj_width( p ); + dim_t panel_len_max = bli_obj_padded_width( p ); - obj_t kappa; - void* buf_kappa; + char* c_cast = bli_obj_buffer_at_off( c ); + inc_t incc = bli_obj_row_stride( c ); + inc_t ldc = bli_obj_col_stride( c ); + dim_t panel_dim_off = bli_obj_row_off( c ); + dim_t panel_len_off = bli_obj_col_off( c ); - func_t* packm_kers; - void* packm_ker; + char* p_cast = bli_obj_buffer( p ); + inc_t ldp = bli_obj_col_stride( p ); + inc_t is_p = bli_obj_imag_stride( p ); + dim_t panel_dim_max = bli_obj_panel_dim( p ); + inc_t ps_p = bli_obj_panel_stride( p ); - FUNCPTR_T f; + doff_t diagoffc_inc = ( doff_t )panel_dim_max; + obj_t kappa_local; + char* kappa_cast = bli_packm_scalar( &kappa_local, p ); - // Treatment of kappa (ie: packing during scaling) depends on - // whether we are executing an induced method. - if ( bli_is_nat_packed( schema ) ) - { - // This branch is for native execution, where we assume that - // the micro-kernel will always apply the alpha scalar of the - // higher-level operation. Thus, we use BLIS_ONE for kappa so - // that the underlying packm implementation does not perform - // any scaling during packing. - buf_kappa = bli_obj_buffer_for_const( dt_p, &BLIS_ONE ); - } - else // if ( bli_is_ind_packed( schema ) ) - { - obj_t* kappa_p; - - // The value for kappa we use will depend on whether the scalar - // attached to A has a nonzero imaginary component. If it does, - // then we will apply the scalar during packing to facilitate - // implementing induced complex domain algorithms in terms of - // real domain micro-kernels. (In the aforementioned situation, - // applying a real scalar is easy, but applying a complex one is - // harder, so we avoid the need altogether with the code below.) - if ( bli_obj_scalar_has_nonzero_imag( p ) ) - { - //printf( "applying non-zero imag kappa\n" ); + // we use the default lookup table to determine the right func_t + // for the current schema. + func_t* packm_kers = &packm_struc_cxk_kers[ bli_pack_schema_index( schema ) ]; - // Detach the scalar. - bli_obj_scalar_detach( p, &kappa ); + // Query the datatype-specific function pointer from the func_t object. + packm_ker_vft packm_ker_cast = bli_func_get_dt( dt_p, packm_kers ); - // Reset the attached scalar (to 1.0). - bli_obj_scalar_reset( p ); + // For mixed-precision gemm, select the proper kernel (only dense panels). + if ( dt_c != dt_p ) + { + packm_ker_cast = packm_struc_cxk_md[ dt_c ][ dt_p ]; + } - kappa_p = κ - } - else - { - // If the internal scalar of A has only a real component, then - // we will apply it later (in the micro-kernel), and so we will - // use BLIS_ONE to indicate no scaling during packing. - kappa_p = &BLIS_ONE; - } + // Query the address of the packm params field of the obj_t. The user might + // have set this field in order to specify a custom packm kernel. + packm_blk_var1_params_t* params = bli_obj_pack_params( c ); - // Acquire the buffer to the kappa chosen above. - buf_kappa = bli_obj_buffer_for_1x1( dt_p, kappa_p ); + if ( params && params->ukr_fn[ dt_c ][ dt_p ] ) + { + // Query the user-provided packing kernel from the obj_t. If provided, + // this overrides the kernel determined above. + packm_ker_cast = params->ukr_fn[ dt_c ][ dt_p ]; } + /* Compute the total number of iterations we'll need. */ + dim_t n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); -#if 0 - if ( bli_is_4mi_packed( schema ) ) packm_kers = packm_struc_cxk_4mi_kers; - else if ( bli_is_3mi_packed( schema ) || - bli_is_3ms_packed( schema ) ) packm_kers = packm_struc_cxk_3mis_kers; - else if ( bli_is_ro_packed( schema ) || - bli_is_io_packed( schema ) || - bli_is_rpi_packed( schema ) ) packm_kers = packm_struc_cxk_rih_kers; - else packm_kers = packm_struc_cxk_kers; -#else - // The original idea here was to read the packm_ukr from the context - // if it is non-NULL. The problem is, it requires that we be able to - // assume that the packm_ukr field is initialized to NULL, which it - // currently is not. - - //func_t* cntx_packm_kers = bli_cntx_get_packm_ukr( cntx ); + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ + dim_t ic0, ip0; + doff_t ic_inc, ip_inc; - //if ( bli_func_is_null_dt( dt_c, cntx_packm_kers ) ) + if ( ( revifup && bli_is_upper( uploc ) && bli_is_triangular( strucc ) ) || + ( reviflo && bli_is_lower( uploc ) && bli_is_triangular( strucc ) ) ) { - // If the packm structure-aware kernel func_t in the context is - // NULL (which is the default value after the context is created), - // we use the default lookup table to determine the right func_t - // for the current schema. - const dim_t i = bli_pack_schema_index( schema ); - - packm_kers = &packm_struc_cxk_kers[ i ]; + ic0 = (n_iter - 1) * panel_dim_max; + ic_inc = -panel_dim_max; + ip0 = n_iter - 1; + ip_inc = -1; } -#if 0 - else // cntx's packm func_t overrides + else { - // If the packm structure-aware kernel func_t in the context is - // non-NULL (ie: assumed to be valid), we use that instead. - //packm_kers = bli_cntx_packm_ukrs( cntx ); - packm_kers = cntx_packm_kers; + ic0 = 0; + ic_inc = panel_dim_max; + ip0 = 0; + ip_inc = 1; } -#endif -#endif - // Query the datatype-specific function pointer from the func_t object. - packm_ker = bli_func_get_dt( dt_p, packm_kers ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_p]; - - // Invoke the function. - f( strucc, - diagoffc, - diagc, - uploc, - transc, - schema, - invdiag, - revifup, - reviflo, - m_p, - n_p, - m_max_p, - n_max_p, - buf_kappa, - buf_c, rs_c, cs_c, - buf_p, rs_p, cs_p, - is_p, - pd_p, ps_p, - packm_ker, - cntx, - t ); -} + // Query the number of threads and thread ids from the current thread's + // packm thrinfo_t node. + const dim_t nt = bli_thread_n_way( thread ); + const dim_t tid = bli_thread_work_id( thread ); + // Determine the thread range and increment using the current thread's + // packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + // will depend on whether slab or round-robin partitioning was requested + // at configure-time. + dim_t it_start, it_end, it_inc; + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); -#undef GENTFUNCR -#define GENTFUNCR( ctype, ctype_r, ch, chr, opname, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - trans_t transc, \ - pack_t schema, \ - bool_t invdiag, \ - bool_t revifup, \ - bool_t reviflo, \ - dim_t m, \ - dim_t n, \ - dim_t m_max, \ - dim_t n_max, \ - void* kappa, \ - void* c, inc_t rs_c, inc_t cs_c, \ - void* p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - dim_t pd_p, inc_t ps_p, \ - void* packm_ker, \ - cntx_t* cntx, \ - thrinfo_t* thread \ - ) \ -{ \ - PASTECH2(ch,opname,_ker_ft) packm_ker_cast = packm_ker; \ -\ - ctype* restrict kappa_cast = kappa; \ - ctype* restrict c_cast = c; \ - ctype* restrict p_cast = p; \ - ctype* restrict c_begin; \ - ctype* restrict p_begin; \ -\ - dim_t iter_dim; \ - dim_t n_iter; \ - dim_t it, ic, ip; \ - dim_t ic0, ip0; \ - doff_t ic_inc, ip_inc; \ - doff_t diagoffc_i; \ - doff_t diagoffc_inc; \ - dim_t panel_len_full; \ - dim_t panel_len_i; \ - dim_t panel_len_max; \ - dim_t panel_len_max_i; \ - dim_t panel_dim_i; \ - dim_t panel_dim_max; \ - dim_t panel_off_i; \ - inc_t vs_c; \ - inc_t ldc; \ - inc_t ldp, p_inc; \ - dim_t* m_panel_full; \ - dim_t* n_panel_full; \ - dim_t* m_panel_use; \ - dim_t* n_panel_use; \ - dim_t* m_panel_max; \ - dim_t* n_panel_max; \ - conj_t conjc; \ - bool_t row_stored; \ - bool_t col_stored; \ - inc_t is_p_use; \ - dim_t ss_num; \ - dim_t ss_den; \ -\ - ctype* restrict c_use; \ - ctype* restrict p_use; \ - doff_t diagoffp_i; \ -\ -\ - /* If C is zeros and part of a triangular matrix, then we don't need - to pack it. */ \ - if ( bli_is_zeros( uploc ) && \ - bli_is_triangular( strucc ) ) return; \ -\ - /* Extract the conjugation bit from the transposition argument. */ \ - conjc = bli_extract_conj( transc ); \ -\ - /* If c needs a transposition, induce it so that we can more simply - express the remaining parameters and code. */ \ - if ( bli_does_trans( transc ) ) \ - { \ - bli_swap_incs( &rs_c, &cs_c ); \ - bli_negate_diag_offset( &diagoffc ); \ - bli_toggle_uplo( &uploc ); \ - bli_toggle_trans( &transc ); \ - } \ -\ - /* Create flags to incidate row or column storage. Note that the - schema bit that encodes row or column is describing the form of - micro-panel, not the storage in the micro-panel. Hence the - mismatch in "row" and "column" semantics. */ \ - row_stored = bli_is_col_packed( schema ); \ - col_stored = bli_is_row_packed( schema ); \ -\ - /* If the row storage flag indicates row storage, then we are packing - to column panels; otherwise, if the strides indicate column storage, - we are packing to row panels. */ \ - if ( row_stored ) \ - { \ - /* Prepare to pack to row-stored column panels. */ \ - iter_dim = n; \ - panel_len_full = m; \ - panel_len_max = m_max; \ - panel_dim_max = pd_p; \ - ldc = rs_c; \ - vs_c = cs_c; \ - diagoffc_inc = -( doff_t )panel_dim_max; \ - ldp = rs_p; \ - m_panel_full = &m; \ - n_panel_full = &panel_dim_i; \ - m_panel_use = &panel_len_i; \ - n_panel_use = &panel_dim_i; \ - m_panel_max = &panel_len_max_i; \ - n_panel_max = &panel_dim_max; \ - } \ - else /* if ( col_stored ) */ \ - { \ - /* Prepare to pack to column-stored row panels. */ \ - iter_dim = m; \ - panel_len_full = n; \ - panel_len_max = n_max; \ - panel_dim_max = pd_p; \ - ldc = cs_c; \ - vs_c = rs_c; \ - diagoffc_inc = ( doff_t )panel_dim_max; \ - ldp = cs_p; \ - m_panel_full = &panel_dim_i; \ - n_panel_full = &n; \ - m_panel_use = &panel_dim_i; \ - n_panel_use = &panel_len_i; \ - m_panel_max = &panel_dim_max; \ - n_panel_max = &panel_len_max_i; \ - } \ -\ - /* Compute the storage stride scaling. Usually this is just 1. However, - in the case of interleaved 3m, we need to scale by 3/2, and in the - cases of real-only, imag-only, or summed-only, we need to scale by - 1/2. In both cases, we are compensating for the fact that pointer - arithmetic occurs in terms of complex elements rather than real - elements. */ \ - if ( bli_is_3mi_packed( schema ) ) { ss_num = 3; ss_den = 2; } \ - else if ( bli_is_3ms_packed( schema ) ) { ss_num = 1; ss_den = 2; } \ - else if ( bli_is_rih_packed( schema ) ) { ss_num = 1; ss_den = 2; } \ - else { ss_num = 1; ss_den = 1; } \ -\ - /* Compute the total number of iterations we'll need. */ \ - n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ -\ - /* Set the initial values and increments for indices related to C and P - based on whether reverse iteration was requested. */ \ - if ( ( revifup && bli_is_upper( uploc ) && bli_is_triangular( strucc ) ) || \ - ( reviflo && bli_is_lower( uploc ) && bli_is_triangular( strucc ) ) ) \ - { \ - ic0 = (n_iter - 1) * panel_dim_max; \ - ic_inc = -panel_dim_max; \ - ip0 = n_iter - 1; \ - ip_inc = -1; \ - } \ - else \ - { \ - ic0 = 0; \ - ic_inc = panel_dim_max; \ - ip0 = 0; \ - ip_inc = 1; \ - } \ -\ - p_begin = p_cast; \ -\ - /* Query the number of threads and thread ids from the current thread's - packm thrinfo_t node. */ \ - const dim_t nt = bli_thread_n_way( thread ); \ - const dim_t tid = bli_thread_work_id( thread ); \ -\ - dim_t it_start, it_end, it_inc; \ -\ - /* Determine the thread range and increment using the current thread's - packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() - will depend on whether slab or round-robin partitioning was requested - at configure-time. */ \ - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ -\ - /* Iterate over every logical micropanel in the source matrix. */ \ - for ( ic = ic0, ip = ip0, it = 0; it < n_iter; \ - ic += ic_inc, ip += ip_inc, it += 1 ) \ - { \ - panel_dim_i = bli_min( panel_dim_max, iter_dim - ic ); \ -\ - diagoffc_i = diagoffc + (ip )*diagoffc_inc; \ - c_begin = c_cast + (ic )*vs_c; \ -\ - if ( bli_is_triangular( strucc ) && \ - bli_is_unstored_subpart_n( diagoffc_i, uploc, *m_panel_full, *n_panel_full ) ) \ - { \ - /* This case executes if the panel belongs to a triangular - matrix AND is completely unstored (ie: zero). If the panel - is unstored, we do nothing. (Notice that we don't even - increment p_begin.) */ \ -\ - continue; \ - } \ - else if ( bli_is_triangular( strucc ) && \ - bli_intersects_diag_n( diagoffc_i, *m_panel_full, *n_panel_full ) ) \ - { \ - /* This case executes if the panel belongs to a triangular - matrix AND is diagonal-intersecting. Notice that we - cannot bury the following conditional logic into - packm_struc_cxk() because we need to know the value of - panel_len_max_i so we can properly increment p_inc. */ \ -\ - /* Sanity check. Diagonals should not intersect the short end of - a micro-panel. If they do, then somehow the constraints on - cache blocksizes being a whole multiple of the register - blocksizes was somehow violated. */ \ - if ( ( col_stored && diagoffc_i < 0 ) || \ - ( row_stored && diagoffc_i > 0 ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ - if ( ( row_stored && bli_is_upper( uploc ) ) || \ - ( col_stored && bli_is_lower( uploc ) ) ) \ - { \ - panel_off_i = 0; \ - panel_len_i = bli_abs( diagoffc_i ) + panel_dim_i; \ - panel_len_max_i = bli_min( bli_abs( diagoffc_i ) + panel_dim_max, \ - panel_len_max ); \ - diagoffp_i = diagoffc_i; \ - } \ - else /* if ( ( row_stored && bli_is_lower( uploc ) ) || \ - ( col_stored && bli_is_upper( uploc ) ) ) */ \ - { \ - panel_off_i = bli_abs( diagoffc_i ); \ - panel_len_i = panel_len_full - panel_off_i; \ - panel_len_max_i = panel_len_max - panel_off_i; \ - diagoffp_i = 0; \ - } \ -\ - c_use = c_begin + (panel_off_i )*ldc; \ - p_use = p_begin; \ -\ - /* We need to re-compute the imaginary stride as a function of - panel_len_max_i since triangular packed matrices have panels - of varying lengths. NOTE: This imaginary stride value is - only referenced by the packm kernels for induced methods. */ \ - is_p_use = ldp * panel_len_max_i; \ -\ - /* We nudge the imaginary stride up by one if it is odd. */ \ - is_p_use += ( bli_is_odd( is_p_use ) ? 1 : 0 ); \ -\ - /* NOTE: We MUST use round-robin partitioning when packing - micropanels of a triangular matrix. Hermitian/symmetric - and general packing may use slab or round-robin, depending - on which was selected at configure-time. */ \ - if ( bli_packm_my_iter_rr( it, it_start, it_end, tid, nt ) ) \ - { \ - packm_ker_cast( strucc, \ - diagoffp_i, \ - diagc, \ - uploc, \ - conjc, \ - schema, \ - invdiag, \ - *m_panel_use, \ - *n_panel_use, \ - *m_panel_max, \ - *n_panel_max, \ - kappa_cast, \ - c_use, rs_c, cs_c, \ - p_use, rs_p, cs_p, \ - is_p_use, \ - cntx ); \ - } \ -\ - /* NOTE: This value is usually LESS than ps_p because triangular - matrices usually have several micro-panels that are shorter - than a "full" micro-panel. */ \ - p_inc = ( is_p_use * ss_num ) / ss_den; \ - } \ - else if ( bli_is_herm_or_symm( strucc ) ) \ - { \ - /* This case executes if the panel belongs to a Hermitian or - symmetric matrix, which includes stored, unstored, and - diagonal-intersecting panels. */ \ -\ - c_use = c_begin; \ - p_use = p_begin; \ -\ - panel_len_i = panel_len_full; \ - panel_len_max_i = panel_len_max; \ -\ - is_p_use = is_p; \ -\ - /* The definition of bli_packm_my_iter() will depend on whether slab - or round-robin partitioning was requested at configure-time. */ \ - if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ - { \ - packm_ker_cast( strucc, \ - diagoffc_i, \ - diagc, \ - uploc, \ - conjc, \ - schema, \ - invdiag, \ - *m_panel_use, \ - *n_panel_use, \ - *m_panel_max, \ - *n_panel_max, \ - kappa_cast, \ - c_use, rs_c, cs_c, \ - p_use, rs_p, cs_p, \ - is_p_use, \ - cntx ); \ - } \ -\ - p_inc = ps_p; \ - } \ - else \ - { \ - /* This case executes if the panel is general, or, if the - panel is part of a triangular matrix and is neither unstored - (ie: zero) nor diagonal-intersecting. */ \ -\ - c_use = c_begin; \ - p_use = p_begin; \ -\ - panel_len_i = panel_len_full; \ - panel_len_max_i = panel_len_max; \ -\ - is_p_use = is_p; \ -\ - /* The definition of bli_packm_my_iter() will depend on whether slab - or round-robin partitioning was requested at configure-time. */ \ - if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ - { \ - packm_ker_cast( BLIS_GENERAL, \ - 0, \ - diagc, \ - BLIS_DENSE, \ - conjc, \ - schema, \ - invdiag, \ - *m_panel_use, \ - *n_panel_use, \ - *m_panel_max, \ - *n_panel_max, \ - kappa_cast, \ - c_use, rs_c, cs_c, \ - p_use, rs_p, cs_p, \ - is_p_use, \ - cntx ); \ - } \ -\ - /* NOTE: This value is equivalent to ps_p. */ \ - p_inc = ps_p; \ - } \ -\ - p_begin += p_inc; \ -\ - } \ -} + char* p_begin = p_cast; -INSERT_GENTFUNCR_BASIC( packm, packm_blk_var1 ) + // Iterate over every logical micropanel in the source matrix. + for ( dim_t ic = ic0, ip = ip0, it = 0; it < n_iter; + ic += ic_inc, ip += ip_inc, it += 1 ) + { + dim_t panel_dim_i = bli_min( panel_dim_max, iter_dim - ic ); + dim_t panel_dim_off_i = panel_dim_off + ic; + + doff_t diagoffc_i = diagoffc + (ip )*diagoffc_inc; + char* c_begin = c_cast + (ic )*incc*dt_c_size; + + inc_t p_inc = ps_p; + + // NOTE: We MUST use round-robin partitioning when packing + // micropanels of a triangular matrix. Hermitian/symmetric + // and general packing may use slab or round-robin, depending + // on which was selected at configure-time. + // The definition of bli_packm_my_iter() will depend on whether slab + // or round-robin partitioning was requested at configure-time. + bool my_iter = bli_is_triangular( strucc ) + ? bli_packm_my_iter_rr( it, it_start, it_end, tid, nt ) + : bli_packm_my_iter ( it, it_start, it_end, tid, nt ); + + if ( bli_is_triangular( strucc ) && + bli_is_unstored_subpart_n( diagoffc_i, uploc, panel_dim_i, panel_len_full ) ) + { + // This case executes if the panel belongs to a triangular + // matrix AND is completely unstored (ie: zero). If the panel + // is unstored, we do nothing. (Notice that we don't even + // increment p_begin.) + continue; + } + else if ( bli_is_triangular( strucc ) && + bli_intersects_diag_n( diagoffc_i, panel_dim_i, panel_len_full ) ) + { + // This case executes if the panel belongs to a triangular + // matrix AND is diagonal-intersecting. Notice that we + // cannot bury the following conditional logic into + // packm_struc_cxk() because we need to know the value of + // panel_len_max_i so we can properly increment p_inc. + + // Sanity check. Diagonals should not intersect the short end of + // a micro-panel. If they do, then somehow the constraints on + // cache blocksizes being a whole multiple of the register + // blocksizes was somehow violated. + if ( diagoffc_i < 0 ) + bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); + + dim_t panel_off_i; + dim_t panel_len_i; + dim_t panel_len_max_i; + + if ( bli_is_lower( uploc ) ) + { + panel_off_i = 0; + panel_len_i = bli_abs( diagoffc_i ) + panel_dim_i; + panel_len_max_i = bli_min( bli_abs( diagoffc_i ) + panel_dim_max, + panel_len_max ); + } + else // if ( bli_is_upper( uploc ) ) + { + panel_off_i = bli_abs( diagoffc_i ); + panel_len_i = panel_len_full - panel_off_i; + panel_len_max_i = panel_len_max - panel_off_i; + } + + dim_t panel_len_off_i = panel_off_i + panel_len_off; + + char* c_use = c_begin + (panel_off_i )*ldc*dt_c_size; + char* p_use = p_begin; + + // We need to re-compute the imaginary stride as a function of + // panel_len_max_i since triangular packed matrices have panels + // of varying lengths. NOTE: This imaginary stride value is + // only referenced by the packm kernels for induced methods. + inc_t is_p_use = ldp * panel_len_max_i; + + // We nudge the imaginary stride up by one if it is odd. + is_p_use += ( bli_is_odd( is_p_use ) ? 1 : 0 ); + + if ( my_iter ) + { + packm_ker_cast( strucc, + diagc, + uploc, + conjc, + schema, + invdiag, + panel_dim_i, + panel_len_i, + panel_dim_max, + panel_len_max_i, + panel_dim_off_i, + panel_len_off_i, + kappa_cast, + c_use, incc, ldc, + p_use, ldp, + is_p_use, + cntx, + params ); + } + + // NOTE: This value is usually LESS than ps_p because triangular + // matrices usually have several micro-panels that are shorter + // than a "full" micro-panel. + p_inc = is_p_use; + } + else + { + // This case executes if the panel is either dense, or belongs + // to a Hermitian or symmetric matrix, which includes stored, + // unstored, and diagonal-intersecting panels. + + if ( my_iter ) + { + packm_ker_cast( bli_is_triangular( strucc ) ? BLIS_GENERAL : strucc, + diagc, + uploc, + conjc, + schema, + invdiag, + panel_dim_i, + panel_len_full, + panel_dim_max, + panel_len_max, + panel_dim_off_i, + panel_len_off, + kappa_cast, + c_begin, incc, ldc, + p_begin, ldp, is_p, + cntx, + params ); + } + } + p_begin += p_inc*dt_p_size; + } +} -/* -if ( row_stored ) \ -PASTEMAC(ch,fprintm)( stdout, "packm_var2: b", m, n, \ - c_cast, rs_c, cs_c, "%4.1f", "" ); \ -if ( col_stored ) \ -PASTEMAC(ch,fprintm)( stdout, "packm_var2: a", m, n, \ - c_cast, rs_c, cs_c, "%4.1f", "" ); \ -*/ -/* -if ( row_stored ) \ -PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: b packed", *m_panel_max, *n_panel_max, \ - p_use, rs_p, cs_p, "%5.2f", "" ); \ -else \ -PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: a packed", *m_panel_max, *n_panel_max, \ - p_use, rs_p, cs_p, "%5.2f", "" ); \ -*/ \ -\ -/* -if ( col_stored ) { \ - if ( bli_thread_work_id( thread ) == 0 ) \ - { \ - printf( "packm_blk_var1: thread %lu (a = %p, ap = %p)\n", bli_thread_work_id( thread ), c_use, p_use ); \ - fflush( stdout ); \ - PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: a", *m_panel_use, *n_panel_use, \ - ( ctype* )c_use, rs_c, cs_c, "%4.1f", "" ); \ - PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: ap", *m_panel_max, *n_panel_max, \ - ( ctype* )p_use, rs_p, cs_p, "%4.1f", "" ); \ - fflush( stdout ); \ - } \ -bli_thread_obarrier( thread ); \ - if ( bli_thread_work_id( thread ) == 1 ) \ - { \ - printf( "packm_blk_var1: thread %lu (a = %p, ap = %p)\n", bli_thread_work_id( thread ), c_use, p_use ); \ - fflush( stdout ); \ - PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: a", *m_panel_use, *n_panel_use, \ - ( ctype* )c_use, rs_c, cs_c, "%4.1f", "" ); \ - PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: ap", *m_panel_max, *n_panel_max, \ - ( ctype* )p_use, rs_p, cs_p, "%4.1f", "" ); \ - fflush( stdout ); \ - } \ -bli_thread_obarrier( thread ); \ -} \ -else { \ - if ( bli_thread_work_id( thread ) == 0 ) \ - { \ - printf( "packm_blk_var1: thread %lu (b = %p, bp = %p)\n", bli_thread_work_id( thread ), c_use, p_use ); \ - fflush( stdout ); \ - PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: b", *m_panel_use, *n_panel_use, \ - ( ctype* )c_use, rs_c, cs_c, "%4.1f", "" ); \ - PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: bp", *m_panel_max, *n_panel_max, \ - ( ctype* )p_use, rs_p, cs_p, "%4.1f", "" ); \ - fflush( stdout ); \ - } \ -bli_thread_obarrier( thread ); \ - if ( bli_thread_work_id( thread ) == 1 ) \ - { \ - printf( "packm_blk_var1: thread %lu (b = %p, bp = %p)\n", bli_thread_work_id( thread ), c_use, p_use ); \ - fflush( stdout ); \ - PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: b", *m_panel_use, *n_panel_use, \ - ( ctype* )c_use, rs_c, cs_c, "%4.1f", "" ); \ - PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: bp", *m_panel_max, *n_panel_max, \ - ( ctype* )p_use, rs_p, cs_p, "%4.1f", "" ); \ - fflush( stdout ); \ - } \ -bli_thread_obarrier( thread ); \ -} \ -*/ -/* - if ( bli_is_4mi_packed( schema ) ) { \ - printf( "packm_var2: is_p_use = %lu\n", is_p_use ); \ - if ( col_stored ) { \ - if ( 0 ) \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: a_r", *m_panel_use, *n_panel_use, \ - ( ctype_r* )c_use, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: ap_r", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: ap_i", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use + is_p_use, rs_p, cs_p, "%4.1f", "" ); \ - } \ - if ( row_stored ) { \ - if ( 0 ) \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: b_r", *m_panel_use, *n_panel_use, \ - ( ctype_r* )c_use, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_r", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_i", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use + is_p_use, rs_p, cs_p, "%4.1f", "" ); \ - } \ - } \ -*/ -/* - PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_rpi", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ -*/ -/* - if ( row_stored ) { \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: b_r", *m_panel_max, *n_panel_max, \ - ( ctype_r* )c_use, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: b_i", *m_panel_max, *n_panel_max, \ - (( ctype_r* )c_use)+rs_c, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_r", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ - inc_t is_b = rs_p * *m_panel_max; \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_i", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use + is_b, rs_p, cs_p, "%4.1f", "" ); \ - } \ -*/ -/* - if ( col_stored ) { \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: a_r", *m_panel_max, *n_panel_max, \ - ( ctype_r* )c_use, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: a_i", *m_panel_max, *n_panel_max, \ - (( ctype_r* )c_use)+rs_c, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: ap_r", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: ap_i", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use + p_inc, rs_p, cs_p, "%4.1f", "" ); \ - } \ -*/ diff --git a/frame/1m/packm/bli_packm_blk_var1.h b/frame/1m/packm/bli_packm_blk_var1.h new file mode 100644 index 0000000000..9cda5828b5 --- /dev/null +++ b/frame/1m/packm/bli_packm_blk_var1.h @@ -0,0 +1,59 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// +// packm params types. +// + +typedef struct +{ + // Type of C Type of P + packm_ker_vft ukr_fn[BLIS_NUM_FP_TYPES][BLIS_NUM_FP_TYPES]; +} packm_blk_var1_params_t; + +// +// Prototype object-based interfaces. +// + +BLIS_EXPORT_BLIS void bli_packm_blk_var1 + ( + obj_t* c, + obj_t* p, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* t + ); + diff --git a/frame/1m/packm/bli_packm_blk_var1_md.c b/frame/1m/packm/bli_packm_blk_var1_md.c deleted file mode 100644 index 8d4906c50e..0000000000 --- a/frame/1m/packm/bli_packm_blk_var1_md.c +++ /dev/null @@ -1,344 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#ifdef BLIS_ENABLE_GEMM_MD - -#define FUNCPTR_T packm_fp - -typedef void (*FUNCPTR_T)( - trans_t transc, - pack_t schema, - dim_t m, - dim_t n, - dim_t m_max, - dim_t n_max, - void* kappa, - void* c, inc_t rs_c, inc_t cs_c, - void* p, inc_t rs_p, inc_t cs_p, - inc_t is_p, - dim_t pd_p, inc_t ps_p, - cntx_t* cntx, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY2_ALL(ftypes,packm_blk_var1_md); - - -void bli_packm_blk_var1_md - ( - obj_t* c, - obj_t* p, - cntx_t* cntx, - cntl_t* cntl, - thrinfo_t* t - ) -{ - num_t dt_c = bli_obj_dt( c ); - num_t dt_p = bli_obj_dt( p ); - - trans_t transc = bli_obj_conjtrans_status( c ); - pack_t schema = bli_obj_pack_schema( p ); - - dim_t m_p = bli_obj_length( p ); - dim_t n_p = bli_obj_width( p ); - dim_t m_max_p = bli_obj_padded_length( p ); - dim_t n_max_p = bli_obj_padded_width( p ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - void* buf_p = bli_obj_buffer_at_off( p ); - inc_t rs_p = bli_obj_row_stride( p ); - inc_t cs_p = bli_obj_col_stride( p ); - inc_t is_p = bli_obj_imag_stride( p ); - dim_t pd_p = bli_obj_panel_dim( p ); - inc_t ps_p = bli_obj_panel_stride( p ); - - obj_t kappa; - void* buf_kappa; - - FUNCPTR_T f; - - - // Treatment of kappa (ie: packing during scaling) depends on - // whether we are executing an induced method. - if ( bli_is_nat_packed( schema ) ) - { - // This branch is for native execution, where we assume that - // the micro-kernel will always apply the alpha scalar of the - // higher-level operation. Thus, we use BLIS_ONE for kappa so - // that the underlying packm implementation does not perform - // any scaling during packing. - buf_kappa = bli_obj_buffer_for_const( dt_p, &BLIS_ONE ); - } - else // if ( bli_is_ind_packed( schema ) ) - { - obj_t* kappa_p; - - // The value for kappa we use will depend on whether the scalar - // attached to A has a nonzero imaginary component. If it does, - // then we will apply the scalar during packing to facilitate - // implementing induced complex domain algorithms in terms of - // real domain micro-kernels. (In the aforementioned situation, - // applying a real scalar is easy, but applying a complex one is - // harder, so we avoid the need altogether with the code below.) - if ( bli_obj_scalar_has_nonzero_imag( p ) ) - { - // Detach the scalar. - bli_obj_scalar_detach( p, &kappa ); - - // Reset the attached scalar (to 1.0). - bli_obj_scalar_reset( p ); - - kappa_p = κ - } - else - { - // If the internal scalar of A has only a real component, then - // we will apply it later (in the micro-kernel), and so we will - // use BLIS_ONE to indicate no scaling during packing. - kappa_p = &BLIS_ONE; - } - - // Acquire the buffer to the kappa chosen above. - buf_kappa = bli_obj_buffer_for_1x1( dt_p, kappa_p ); - } - - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_c][dt_p]; - - // Invoke the function. - f( - transc, - schema, - m_p, - n_p, - m_max_p, - n_max_p, - buf_kappa, - buf_c, rs_c, cs_c, - buf_p, rs_p, cs_p, - is_p, - pd_p, ps_p, - cntx, - t ); -} - - -#undef GENTFUNC2 -#define GENTFUNC2( ctype_c, ctype_p, chc, chp, varname ) \ -\ -void PASTEMAC2(chc,chp,varname) \ - ( \ - trans_t transc, \ - pack_t schema, \ - dim_t m, \ - dim_t n, \ - dim_t m_max, \ - dim_t n_max, \ - void* kappa, \ - void* c, inc_t rs_c, inc_t cs_c, \ - void* p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - dim_t pd_p, inc_t ps_p, \ - cntx_t* cntx, \ - thrinfo_t* thread \ - ) \ -{ \ - ctype_p* restrict kappa_cast = kappa; \ - ctype_c* restrict c_cast = c; \ - ctype_p* restrict p_cast = p; \ - ctype_c* restrict c_begin; \ - ctype_p* restrict p_begin; \ -\ - dim_t iter_dim; \ - dim_t n_iter; \ - dim_t it, ic, ip; \ - doff_t ic_inc, ip_inc; \ - dim_t panel_len_full; \ - dim_t panel_len_i; \ - dim_t panel_len_max; \ - dim_t panel_len_max_i; \ - dim_t panel_dim_i; \ - dim_t panel_dim_max; \ - inc_t vs_c; \ - inc_t p_inc; \ - dim_t* m_panel_use; \ - dim_t* n_panel_use; \ - dim_t* m_panel_max; \ - dim_t* n_panel_max; \ - conj_t conjc; \ - bool_t row_stored; \ - bool_t col_stored; \ -\ - ctype_c* restrict c_use; \ - ctype_p* restrict p_use; \ -\ -\ - /* Extract the conjugation bit from the transposition argument. */ \ - conjc = bli_extract_conj( transc ); \ -\ - /* If c needs a transposition, induce it so that we can more simply - express the remaining parameters and code. */ \ - if ( bli_does_trans( transc ) ) \ - { \ - bli_swap_incs( &rs_c, &cs_c ); \ - bli_toggle_trans( &transc ); \ - } \ -\ - /* Create flags to incidate row or column storage. Note that the - schema bit that encodes row or column is describing the form of - micro-panel, not the storage in the micro-panel. Hence the - mismatch in "row" and "column" semantics. */ \ - row_stored = bli_is_col_packed( schema ); \ - col_stored = bli_is_row_packed( schema ); \ -\ - ( void )col_stored; \ -\ - /* If the row storage flag indicates row storage, then we are packing - to column panels; otherwise, if the strides indicate column storage, - we are packing to row panels. */ \ - if ( row_stored ) \ - { \ - /* Prepare to pack to row-stored column panels. */ \ - iter_dim = n; \ - panel_len_full = m; \ - panel_len_max = m_max; \ - panel_dim_max = pd_p; \ - vs_c = cs_c; \ - m_panel_use = &panel_len_i; \ - n_panel_use = &panel_dim_i; \ - m_panel_max = &panel_len_max_i; \ - n_panel_max = &panel_dim_max; \ - } \ - else /* if ( col_stored ) */ \ - { \ - /* Prepare to pack to column-stored row panels. */ \ - iter_dim = m; \ - panel_len_full = n; \ - panel_len_max = n_max; \ - panel_dim_max = pd_p; \ - vs_c = rs_c; \ - m_panel_use = &panel_dim_i; \ - n_panel_use = &panel_len_i; \ - m_panel_max = &panel_dim_max; \ - n_panel_max = &panel_len_max_i; \ - } \ -\ - /* Compute the total number of iterations we'll need. */ \ - n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ -\ - { \ - ic_inc = panel_dim_max; \ - ip_inc = 1; \ - } \ -\ - p_begin = p_cast; \ -\ - /* Query the number of threads and thread ids from the current thread's - packm thrinfo_t node. */ \ - const dim_t nt = bli_thread_n_way( thread ); \ - const dim_t tid = bli_thread_work_id( thread ); \ -\ - /* Suppress unused variable warnings when slab partitioning is enabled, - since the slab-based definition of bli_packm_my_iter() does not - actually use tid or nt. */ \ - ( void )nt; ( void )tid; \ -\ - dim_t it_start, it_end, it_inc; \ -\ - /* Determine the thread range and increment using the current thread's - packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() - will depend on whether slab or round-robin partitioning was requested - at configure-time. */ \ - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ -\ - for ( ic = 0, ip = 0, it = 0; it < n_iter; \ - ic += ic_inc, ip += ip_inc, it += 1 ) \ - { \ - panel_dim_i = bli_min( panel_dim_max, iter_dim - ic ); \ -\ - c_begin = c_cast + (ic )*vs_c; \ -\ - { \ - c_use = c_begin; \ - p_use = p_begin; \ -\ - panel_len_i = panel_len_full; \ - panel_len_max_i = panel_len_max; \ -\ - if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ - { \ - PASTEMAC2(chc,chp,packm_struc_cxk_md) \ - ( \ - conjc, \ - schema, \ - *m_panel_use, \ - *n_panel_use, \ - *m_panel_max, \ - *n_panel_max, \ - kappa_cast, \ - c_use, rs_c, cs_c, \ - p_use, rs_p, cs_p, \ - is_p, \ - cntx \ - ); \ - } \ -\ - p_inc = ps_p; \ - } \ -\ -/* -if ( row_stored ) \ -PASTEMAC(chp,fprintm)( stdout, "packm_blk_var1_md: b packed", *m_panel_max, *n_panel_max, \ - p_use, rs_p, cs_p, "%5.2f", "" ); \ -else \ -PASTEMAC(chp,fprintm)( stdout, "packm_blk_var1_md: a packed", *m_panel_max, *n_panel_max, \ - p_use, rs_p, cs_p, "%5.2f", "" ); \ -*/ \ -\ - p_begin += p_inc; \ -\ - } \ -} - -INSERT_GENTFUNC2_BASIC0( packm_blk_var1_md ) -INSERT_GENTFUNC2_MIXDP0( packm_blk_var1_md ) - -#endif diff --git a/frame/1m/packm/bli_packm_cntl.c b/frame/1m/packm/bli_packm_cntl.c index 12083f3be1..e99ed9cf3d 100644 --- a/frame/1m/packm/bli_packm_cntl.c +++ b/frame/1m/packm/bli_packm_cntl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,16 +35,15 @@ #include "blis.h" -cntl_t* bli_packm_cntl_create_node +BLIS_EXPORT_BLIS cntl_t* bli_packm_cntl_create_node ( rntm_t* rntm, - void* var_func, - void* packm_var_func, + void_fp var_func, bszid_t bmid_m, bszid_t bmid_n, - bool_t does_invert_diag, - bool_t rev_iter_if_upper, - bool_t rev_iter_if_lower, + bool does_invert_diag, + bool rev_iter_if_upper, + bool rev_iter_if_lower, pack_t pack_schema, packbuf_t pack_buf_type, cntl_t* sub_node @@ -62,7 +61,6 @@ cntl_t* bli_packm_cntl_create_node // Initialize the packm_params_t struct. params->size = sizeof( packm_params_t ); - params->var_func = packm_var_func; params->bmid_m = bmid_m; params->bmid_n = bmid_n; params->does_invert_diag = does_invert_diag; diff --git a/frame/1m/packm/bli_packm_cntl.h b/frame/1m/packm/bli_packm_cntl.h index fef603ab0e..14bfe1ce85 100644 --- a/frame/1m/packm/bli_packm_cntl.h +++ b/frame/1m/packm/bli_packm_cntl.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,53 +36,47 @@ struct packm_params_s { uint64_t size; // size field must be present and come first. - packm_var_oft var_func; bszid_t bmid_m; bszid_t bmid_n; - bool_t does_invert_diag; - bool_t rev_iter_if_upper; - bool_t rev_iter_if_lower; + bool does_invert_diag; + bool rev_iter_if_upper; + bool rev_iter_if_lower; pack_t pack_schema; packbuf_t pack_buf_type; }; typedef struct packm_params_s packm_params_t; -static packm_var_oft bli_cntl_packm_params_var_func( cntl_t* cntl ) -{ - packm_params_t* ppp = ( packm_params_t* )cntl->params; return ppp->var_func; -} - -static bszid_t bli_cntl_packm_params_bmid_m( cntl_t* cntl ) +BLIS_INLINE bszid_t bli_cntl_packm_params_bmid_m( cntl_t* cntl ) { packm_params_t* ppp = ( packm_params_t* )cntl->params; return ppp->bmid_m; } -static bszid_t bli_cntl_packm_params_bmid_n( cntl_t* cntl ) +BLIS_INLINE bszid_t bli_cntl_packm_params_bmid_n( cntl_t* cntl ) { packm_params_t* ppp = ( packm_params_t* )cntl->params; return ppp->bmid_n; } -static bool_t bli_cntl_packm_params_does_invert_diag( cntl_t* cntl ) +BLIS_INLINE bool bli_cntl_packm_params_does_invert_diag( cntl_t* cntl ) { packm_params_t* ppp = ( packm_params_t* )cntl->params; return ppp->does_invert_diag; } -static bool_t bli_cntl_packm_params_rev_iter_if_upper( cntl_t* cntl ) +BLIS_INLINE bool bli_cntl_packm_params_rev_iter_if_upper( cntl_t* cntl ) { packm_params_t* ppp = ( packm_params_t* )cntl->params; return ppp->rev_iter_if_upper; } -static bool_t bli_cntl_packm_params_rev_iter_if_lower( cntl_t* cntl ) +BLIS_INLINE bool bli_cntl_packm_params_rev_iter_if_lower( cntl_t* cntl ) { packm_params_t* ppp = ( packm_params_t* )cntl->params; return ppp->rev_iter_if_lower; } -static pack_t bli_cntl_packm_params_pack_schema( cntl_t* cntl ) +BLIS_INLINE pack_t bli_cntl_packm_params_pack_schema( cntl_t* cntl ) { packm_params_t* ppp = ( packm_params_t* )cntl->params; return ppp->pack_schema; } -static packbuf_t bli_cntl_packm_params_pack_buf_type( cntl_t* cntl ) +BLIS_INLINE packbuf_t bli_cntl_packm_params_pack_buf_type( cntl_t* cntl ) { packm_params_t* ppp = ( packm_params_t* )cntl->params; return ppp->pack_buf_type; } @@ -92,13 +86,12 @@ static packbuf_t bli_cntl_packm_params_pack_buf_type( cntl_t* cntl ) cntl_t* bli_packm_cntl_create_node ( rntm_t* rntm, - void* var_func, - void* packm_var_func, + void_fp var_func, bszid_t bmid_m, bszid_t bmid_n, - bool_t does_invert_diag, - bool_t rev_iter_if_upper, - bool_t rev_iter_if_lower, + bool does_invert_diag, + bool rev_iter_if_upper, + bool rev_iter_if_lower, pack_t pack_schema, packbuf_t pack_buf_type, cntl_t* sub_node diff --git a/frame/1m/packm/bli_packm_cxk.c b/frame/1m/packm/bli_packm_cxk.c index 59f99dd18f..ea0418cae5 100644 --- a/frame/1m/packm/bli_packm_cxk.c +++ b/frame/1m/packm/bli_packm_cxk.c @@ -40,6 +40,7 @@ void PASTEMAC(ch,opname) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t panel_dim, \ dim_t panel_dim_max, \ dim_t panel_len, \ @@ -73,36 +74,52 @@ void PASTEMAC(ch,opname) \ the outer (panel_dim_max - panel_dim) rows or columns of the micropanel. (Note that these rows/columns correspond to values beyond the edge of matrix A.) The kernel intrinsically knows its - own panel_dim_max, since that corresponds to the kernel's register - blocksize. However, we need to pass in panel_len_max because the - bottom-right edge case of trsm_lu will need all elements above the - extended diagonal and beyond (to the right of) the bottom-right - element to be initialized to zero so the trsm portion of the - computational kernel will operate with zeros for those iterations. + own panel_dim_max, since that corresponds to the packm micropanel's + normal width (corresponding to the gemm microkernel's register + blocksize (mr or nr). However, we *do* need to pass in panel_len_max + because the bottom-right edge case of trsm_lu will need all + elements above the extended diagonal and beyond (to the right of) + the bottom-right element to be initialized to zero so the trsm + portion of the computational kernel will operate with zeros for + those iterations. - As an example, if trsm_lu is executed on a 6x6 matrix, and the - gemmtrsm kernel uses MR = 6, the computation will begin with the - edge case, which is the bottom 2x2 matrix marked with x's. Code - in bli_packm_tri_cxk() will extend the diagonal as identity into - the remaining portion of the micropanel. But before that happens, - the packm kernel must have set the 0's shown below. (Unreferenced - elements are marked with '.'.) + For example, if trsm_lu is executed on an 10x10 triangular matrix, + and the gemmtrsm kernel uses MR = 6, the computation will begin + with the edge case, which is the bottom-right 4x4 upper triangular + matrix. Code in bli_packm_tri_cxk() will extend the diagonal as + identity into the remaining portion of the micropanel. But before + that happens, the packm kernel must have set the 0's added in + step (3) below. - x x 0 0 0 0 - . x 0 0 0 0 - . . 1 0 0 0 - . . . 1 0 0 - . . . . 1 0 - . . . . . 1 + packm kernel packm kernel packm kernel packm_tri_cxk + step 1: step 2: step 3: step 4: - In this case, panel_dim will be 2 because two rows of data are - copied from A, panel_len will be 2 because those two rows span - two columns of A, and panel_len_max will be 6 because there are a - total of 6 columns that can be written to, 4 of which lie beyond - the values copied from A. */ \ + x x x x . . x x x x . . x x x x 0 0 x x x x 0 0 + ? x x x . . ? x x x . . ? x x x 0 0 ? x x x 0 0 + ? ? x x . . -> ? ? x x . . -> ? ? x x 0 0 -> ? ? x x 0 0 + ? ? ? x . . ? ? ? x . . ? ? ? x 0 0 ? ? ? x 0 0 + . . . . . . 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 + . . . . . . 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 + + x Copied from A; valid element. + ? Copied from A, but value is unknown and unused. + . Uninitialized. + 0 Initialized to zero. + 1 Initialized to one. + + NOTE: In step 5 (not shown), bli_packm_tri_cxk() sets the ?'s + to zero. This is not needed to support trsm, but rather to + support trmm. (Both use the same packing format and code.) + + In this case, panel_dim will be 4 because four rows of data are + copied from A, panel_len will be 4 because those four rows span + four columns of A, and panel_len_max will be 6 because there are a + total of 6 columns that can be written to in the packed micropanel, + 2 of which lie beyond the values copied from A. */ \ f \ ( \ conja, \ + schema, \ panel_dim, \ panel_len, \ panel_len_max, \ diff --git a/frame/1m/packm/bli_packm_cxk.h b/frame/1m/packm/bli_packm_cxk.h index be089f05c9..1402a53c91 100644 --- a/frame/1m/packm/bli_packm_cxk.h +++ b/frame/1m/packm/bli_packm_cxk.h @@ -39,6 +39,7 @@ void PASTEMAC(ch,varname) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t panel_dim, \ dim_t panel_dim_max, \ dim_t panel_len, \ diff --git a/frame/1m/packm/bli_packm_cxk_4mi.c b/frame/1m/packm/bli_packm_cxk_4mi.c deleted file mode 100644 index c22f551cca..0000000000 --- a/frame/1m/packm/bli_packm_cxk_4mi.c +++ /dev/null @@ -1,146 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - conj_t conja, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* kappa, \ - ctype* a, inc_t inca, inc_t lda, \ - ctype* p, inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - /* Note that we use panel_dim_max, not panel_dim, to query the packm - kernel function pointer. This means that we always use the same - kernel, even for edge cases. */ \ - num_t dt = PASTEMAC(ch,type); \ - l1mkr_t ker_id = panel_dim_max; \ -\ - PASTECH2(ch,opname,_ker_ft) f; \ -\ - /* Query the context for the packm kernel corresponding to the current - panel dimension, or kernel id. If the id is invalid, the function will - return NULL. */ \ - f = bli_cntx_get_packm_ker_dt( dt, ker_id, cntx ); \ -\ - /* If there exists a kernel implementation for the micro-panel dimension - provided, we invoke the implementation. Otherwise, we use scal2m. */ \ - if ( f != NULL ) \ - { \ - f \ - ( \ - conja, \ - panel_dim, \ - panel_len, \ - panel_len_max, \ - kappa, \ - a, inca, lda, \ - p, is_p, ldp, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Treat the micro-panel as panel_dim x panel_len and column-stored - (unit row stride). */ \ -\ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - panel_dim, \ - panel_len, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ - if ( panel_dim != panel_dim_max ) \ - { \ - const dim_t i = panel_dim; \ - const dim_t m_edge = panel_dim_max - i; \ - const dim_t n_edge = panel_len_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -\ - /* If panel_len < panel_len_max, then we zero those unused columns. */ \ - if ( panel_len != panel_len_max ) \ - { \ - const dim_t j = panel_len; \ - const dim_t m_edge = panel_dim_max; \ - const dim_t n_edge = panel_len_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC0( packm_cxk_4mi ) - diff --git a/frame/1m/packm/bli_packm_init.c b/frame/1m/packm/bli_packm_init.c index 0e749efe69..5a7d716fe6 100644 --- a/frame/1m/packm/bli_packm_init.c +++ b/frame/1m/packm/bli_packm_init.c @@ -35,12 +35,14 @@ #include "blis.h" -siz_t bli_packm_init +bool bli_packm_init ( - obj_t* a, + obj_t* c, obj_t* p, cntx_t* cntx, - cntl_t* cntl + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread ) { bli_init_once(); @@ -51,185 +53,27 @@ siz_t bli_packm_init // suitable block of memory from the memory allocator (if such a block // of memory has not already been allocated previously). - bszid_t bmult_id_m; - bszid_t bmult_id_n; - bool_t does_invert_diag; - bool_t rev_iter_if_upper; - bool_t rev_iter_if_lower; - pack_t schema; - //packbuf_t pack_buf_type; - siz_t size_needed; - // Check parameters. if ( bli_error_checking_is_enabled() ) - bli_packm_init_check( a, p, cntx ); + bli_packm_init_check( c, p, cntx ); - // Extract various fields from the control tree. - bmult_id_m = bli_cntl_packm_params_bmid_m( cntl ); - bmult_id_n = bli_cntl_packm_params_bmid_n( cntl ); - does_invert_diag = bli_cntl_packm_params_does_invert_diag( cntl ); - rev_iter_if_upper = bli_cntl_packm_params_rev_iter_if_upper( cntl ); - rev_iter_if_lower = bli_cntl_packm_params_rev_iter_if_lower( cntl ); - schema = bli_cntl_packm_params_pack_schema( cntl ); - //pack_buf_type = bli_cntl_packm_params_pack_buf_type( cntl ); - -#if 0 - // Let us now check to see if the object has already been packed. First - // we check if it has been packed to an unspecified (row or column) - // format, in which case we can alias the object and return. - // NOTE: The reason we don't need to even look at the control tree in - // this case is as follows: an object's pack status is only set to - // BLIS_PACKED_UNSPEC for situations when the actual format used is - // not important, as long as its packed into contiguous rows or - // contiguous columns. A good example of this is packing for matrix - // operands in the level-2 operations. - if ( bli_obj_pack_schema( a ) == BLIS_PACKED_UNSPEC ) - { - bli_obj_alias_to( a, p ); - return 0; - } - - // Now we check if the object has already been packed to the desired - // schema (as encoded in the control tree). If so, we can alias and - // return 0. - // NOTE: In most cases, an object's pack status will be BLIS_NOT_PACKED - // and thus packing will be called for (but in some cases packing has - // already taken place, or does not need to take place, and so that will - // be indicated by the pack status). Also, not all combinations of - // current pack status and desired pack schema are valid. - if ( bli_obj_pack_schema( a ) == pack_schema ) - { - bli_obj_alias_to( a, p ); - return 0; - } -#endif + // We begin by copying the fields of A. + bli_obj_alias_to( c, p ); // If the object is marked as being filled with zeros, then we can skip // the packm operation entirely and alias. - if ( bli_obj_is_zeros( a ) ) - { - bli_obj_alias_to( a, p ); - return 0; - } - -#if 0 - pack_t schema; - - if ( bli_cntx_method( cntx ) != BLIS_NAT ) - { - // We now ignore the pack_schema field in the control tree and - // extract the schema from the context, depending on whether we are - // preparing to pack a block of A or panel of B. For A and B, we must - // obtain the schema from the context since the induced methods reuse - // the same control trees used by native execution, and those induced - // methods specify the schema used by the current execution phase - // within the context (whereas the control tree does not change). - - if ( pack_buf_type == BLIS_BUFFER_FOR_A_BLOCK ) - { - schema = bli_cntx_schema_a_block( cntx ); - } - else if ( pack_buf_type == BLIS_BUFFER_FOR_B_PANEL ) - { - schema = bli_cntx_schema_b_panel( cntx ); - } - else // if ( pack_buf_type == BLIS_BUFFER_FOR_C_PANEL ) - { - schema = bli_cntl_packm_params_pack_schema( cntl ); - } - } - else // ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - // For native execution, we obtain the schema from the control tree - // node. (Notice that it doesn't matter if the pack_buf_type is for - // A or B.) - schema = bli_cntl_packm_params_pack_schema( cntl ); - } - // This is no longer needed now that we branch between native and - // non-native cases above. -#if 0 - if ( pack_buf_type == BLIS_BUFFER_FOR_C_PANEL ) - { - // If we get a request to pack C for some reason, it is likely - // not part of an induced method, and so it would be safe (and - // necessary) to read the pack schema from the control tree. - schema = bli_cntl_packm_params_pack_schema( cntl ); - } -#endif -#endif - - // Prepare a few other variables based on properties of the control - // tree. - - invdiag_t invert_diag; - packord_t pack_ord_if_up; - packord_t pack_ord_if_lo; - - if ( does_invert_diag ) invert_diag = BLIS_INVERT_DIAG; - else invert_diag = BLIS_NO_INVERT_DIAG; - - if ( rev_iter_if_upper ) pack_ord_if_up = BLIS_PACK_REV_IF_UPPER; - else pack_ord_if_up = BLIS_PACK_FWD_IF_UPPER; - - if ( rev_iter_if_lower ) pack_ord_if_lo = BLIS_PACK_REV_IF_LOWER; - else pack_ord_if_lo = BLIS_PACK_FWD_IF_LOWER; - - // Initialize object p for the final packed matrix. - size_needed - = - bli_packm_init_pack - ( - invert_diag, - schema, - pack_ord_if_up, - pack_ord_if_lo, - bmult_id_m, - bmult_id_n, - a, - p, - cntx - ); - - // Return the size needed for memory allocation of the packed buffer. - return size_needed; -} - - -siz_t bli_packm_init_pack - ( - invdiag_t invert_diag, - pack_t schema, - packord_t pack_ord_if_up, - packord_t pack_ord_if_lo, - bszid_t bmult_id_m, - bszid_t bmult_id_n, - obj_t* a, - obj_t* p, - cntx_t* cntx - ) -{ - bli_init_once(); - - num_t dt_tar = bli_obj_target_dt( a ); - num_t dt_scalar = bli_obj_scalar_dt( a ); - trans_t transa = bli_obj_onlytrans_status( a ); - dim_t m_a = bli_obj_length( a ); - dim_t n_a = bli_obj_width( a ); - dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx ); - dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx ); - dim_t bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx ); - dim_t bmult_n_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_n, cntx ); + if ( bli_obj_is_zeros( c ) ) + return false; - dim_t m_p, n_p; - dim_t m_p_pad, n_p_pad; - siz_t size_p; - siz_t elem_size_p; - inc_t rs_p, cs_p; - inc_t is_p; - - - // We begin by copying the fields of A. - bli_obj_alias_to( a, p ); + // Extract various fields from the control tree. + bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl ); + bszid_t bmult_id_n = bli_cntl_packm_params_bmid_n( cntl ); + pack_t schema = bli_cntl_packm_params_pack_schema( cntl ); + num_t dt_tar = bli_obj_target_dt( c ); + num_t dt_scalar = bli_obj_scalar_dt( c ); + dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx ); + dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx ); + dim_t bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx ); // Typecast the internal scalar value to the target datatype. // Note that if the typecasting is needed, this must happen BEFORE we @@ -241,51 +85,21 @@ siz_t bli_packm_init_pack // Update the storage datatype of P to be the target datatype of A. bli_obj_set_dt( dt_tar, p ); + bli_obj_set_elem_size( bli_dt_size( dt_tar ), p ); - // Update the dimension fields to explicitly reflect a transposition, - // if needed. - // Then, clear the conjugation and transposition fields from the object - // since matrix packing in BLIS is deemed to take care of all conjugation - // and transposition necessary. - // Then, we adjust the properties of P when A needs a transposition. - // We negate the diagonal offset, and if A is upper- or lower-stored, - // we either toggle the uplo of P. - // Finally, if we mark P as dense since we assume that all matrices, - // regardless of structure, will be densified. - bli_obj_set_dims_with_trans( transa, m_a, n_a, p ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, p ); - if ( bli_does_trans( transa ) ) - { - bli_obj_negate_diag_offset( p ); - if ( bli_obj_is_upper_or_lower( a ) ) - bli_obj_toggle_uplo( p ); - } + // Store the pack schema to the object. + bli_obj_set_pack_schema( schema, p ); - // If we are packing micro-panels, mark P as dense. Otherwise, we are - // probably being called in the context of a level-2 operation, in - // which case we do not want to overwrite the uplo field of P (inherited - // from A) with BLIS_DENSE because that information may be needed by - // the level-2 operation's unblocked variant to decide whether to - // execute a "lower" or "upper" branch of code. - if ( bli_is_panel_packed( schema ) ) - { - bli_obj_set_uplo( BLIS_DENSE, p ); - } + // Clear the conjugation field from the object since matrix packing + // in BLIS is deemed to take care of all conjugation necessary. + bli_obj_set_conj( BLIS_NO_CONJUGATE, p ); + + // Since we are packing micropanels, mark P as dense. + bli_obj_set_uplo( BLIS_DENSE, p ); // Reset the view offsets to (0,0). bli_obj_set_offs( 0, 0, p ); - // Set the invert diagonal field. - bli_obj_set_invert_diag( invert_diag, p ); - - // Set the pack status of P to the pack schema prescribed in the control - // tree node. - bli_obj_set_pack_schema( schema, p ); - - // Set the packing order bits. - bli_obj_set_pack_order_if_upper( pack_ord_if_up, p ); - bli_obj_set_pack_order_if_lower( pack_ord_if_lo, p ); - // Compute the dimensions padded by the dimension multiples. These // dimensions will be the dimensions of the packed matrices, including // zero-padding, and will be used by the macro- and micro-kernels. @@ -293,10 +107,10 @@ siz_t bli_packm_init_pack // in P) and aligning them to the dimension multiples (typically equal // to register blocksizes). This does waste a little bit of space for // level-2 operations, but that's okay with us. - m_p = bli_obj_length( p ); - n_p = bli_obj_width( p ); - m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def ); - n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def ); + dim_t m_p = bli_obj_length( p ); + dim_t n_p = bli_obj_width( p ); + dim_t m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def ); + dim_t n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def ); // Save the padded dimensions into the packed object. It is important // to save these dimensions since they represent the actual dimensions @@ -304,258 +118,70 @@ siz_t bli_packm_init_pack bli_obj_set_padded_dims( m_p_pad, n_p_pad, p ); // Now we prepare to compute strides, align them, and compute the - // total number of bytes needed for the packed buffer. The caller - // will then use that value to acquire an appropriate block of memory - // from the memory allocator. + // total number of bytes needed for the packed buffer. Then we use + // that value to acquire an appropriate block of memory from the + // memory allocator. // Extract the element size for the packed object. - elem_size_p = bli_obj_elem_size( p ); - - // Set the row and column strides of p based on the pack schema. - if ( bli_is_row_packed( schema ) && - !bli_is_panel_packed( schema ) ) - { - // For regular row storage, the padded width of our matrix - // should be used for the row stride, with the column stride set - // to one. By using the WIDTH of the mem_t region, we allow for - // zero-padding (if necessary/desired) along the right edge of - // the matrix. - rs_p = n_p_pad; - cs_p = 1; - - // Align the leading dimension according to the heap stride - // alignment size so that the second, third, etc rows begin at - // aligned addresses. - rs_p = bli_align_dim_to_size( rs_p, elem_size_p, - BLIS_HEAP_STRIDE_ALIGN_SIZE ); - - // Store the strides in P. - bli_obj_set_strides( rs_p, cs_p, p ); - - // Compute the size of the packed buffer. - size_p = m_p_pad * rs_p * elem_size_p; - } - else if ( bli_is_col_packed( schema ) && - !bli_is_panel_packed( schema ) ) - { - // For regular column storage, the padded length of our matrix - // should be used for the column stride, with the row stride set - // to one. By using the LENGTH of the mem_t region, we allow for - // zero-padding (if necessary/desired) along the bottom edge of - // the matrix. - cs_p = m_p_pad; - rs_p = 1; - - // Align the leading dimension according to the heap stride - // alignment size so that the second, third, etc columns begin at - // aligned addresses. - cs_p = bli_align_dim_to_size( cs_p, elem_size_p, - BLIS_HEAP_STRIDE_ALIGN_SIZE ); - - // Store the strides in P. - bli_obj_set_strides( rs_p, cs_p, p ); - - // Compute the size of the packed buffer. - size_p = cs_p * n_p_pad * elem_size_p; - } - else if ( bli_is_row_packed( schema ) && - bli_is_panel_packed( schema ) ) - { - dim_t m_panel; - dim_t ps_p, ps_p_orig; - - // The panel dimension (for each datatype) should be equal to the - // default (logical) blocksize multiple in the m dimension. - m_panel = bmult_m_def; - - // The "column stride" of a row panel packed object is interpreted as - // the column stride WITHIN a panel. Thus, this is equal to the - // packing (storage) blocksize multiple (which may be equal to the - // default (logical) blocksize multiple. - cs_p = bmult_m_pack; - - // The "row stride" of a row panel packed object is interpreted - // as the row stride WITHIN a panel. Thus, it is unit. - rs_p = 1; - - // The "panel stride" of a panel packed object is interpreted as the - // distance between the (0,0) element of panel k and the (0,0) - // element of panel k+1. We use the padded width computed above to - // allow for zero-padding (if necessary/desired) along the far end - // of each panel (ie: the right edge of the matrix). Zero-padding - // can also occur along the long edge of the last panel if the m - // dimension of the matrix is not a whole multiple of MR. - ps_p = cs_p * n_p_pad; - - // As a general rule, we don't want panel strides to be odd. This - // is primarily motivated by our desire to support interleaved 3m - // micro-panels, in which case we have to scale the panel stride - // by 3/2. That division by 2 means the numerator (prior to being - // scaled by 3) must be even. - if ( bli_is_odd( ps_p ) ) ps_p += 1; - - // Preserve this early panel stride value for use later, if needed. - ps_p_orig = ps_p; - - // Here, we adjust the panel stride, if necessary. Remember: ps_p is - // always interpreted as being in units of the datatype of the object - // which is not necessarily how the micro-panels will be stored. For - // interleaved 3m, we will increase ps_p by 50%, and for ro/io/rpi, - // we halve ps_p. Why? Because the macro-kernel indexes in units of - // the complex datatype. So these changes "trick" it into indexing - // the correct amount. - if ( bli_is_3mi_packed( schema ) ) - { - ps_p = ( ps_p * 3 ) / 2; - } - else if ( bli_is_3ms_packed( schema ) || - bli_is_ro_packed( schema ) || - bli_is_io_packed( schema ) || - bli_is_rpi_packed( schema ) ) - { - // The division by 2 below assumes that ps_p is an even number. - // However, it is possible that, at this point, ps_p is an odd. - // If it is indeed odd, we nudge it higher. - if ( bli_is_odd( ps_p ) ) ps_p += 1; - - // Despite the fact that the packed micro-panels will contain - // real elements, the panel stride that we store in the obj_t - // (which is passed into the macro-kernel) needs to be in units - // of complex elements, since the macro-kernel will index through - // micro-panels via complex pointer arithmetic for trmm/trsm. - // Since the indexing "increment" will be twice as large as each - // actual stored element, we divide the panel_stride by 2. - ps_p = ps_p / 2; - } - - // Set the imaginary stride (in units of fundamental elements) for - // 3m and 4m (separated or interleaved). We use ps_p_orig since - // that variable tracks the number of real part elements contained - // within each micro-panel of the source matrix. Therefore, this - // is the number of real elements that must be traversed before - // reaching the imaginary part (3mi/4mi) of the packed micro-panel, - // or the real part of the next micro-panel (3ms). - if ( bli_is_3mi_packed( schema ) ) is_p = ps_p_orig; - else if ( bli_is_4mi_packed( schema ) ) is_p = ps_p_orig; - else if ( bli_is_3ms_packed( schema ) ) is_p = ps_p_orig * ( m_p_pad / m_panel ); - else is_p = 1; - - // Store the strides and panel dimension in P. - bli_obj_set_strides( rs_p, cs_p, p ); - bli_obj_set_imag_stride( is_p, p ); - bli_obj_set_panel_dim( m_panel, p ); - bli_obj_set_panel_stride( ps_p, p ); - bli_obj_set_panel_length( m_panel, p ); - bli_obj_set_panel_width( n_p, p ); - - // Compute the size of the packed buffer. - size_p = ps_p * ( m_p_pad / m_panel ) * elem_size_p; - } - else if ( bli_is_col_packed( schema ) && - bli_is_panel_packed( schema ) ) - { - dim_t n_panel; - dim_t ps_p, ps_p_orig; - - // The panel dimension (for each datatype) should be equal to the - // default (logical) blocksize multiple in the n dimension. - n_panel = bmult_n_def; - - // The "row stride" of a column panel packed object is interpreted as - // the row stride WITHIN a panel. Thus, this is equal to the - // packing (storage) blocksize multiple (which may be equal to the - // default (logical) blocksize multiple. - rs_p = bmult_n_pack; - - // The "column stride" of a column panel packed object is interpreted - // as the column stride WITHIN a panel. Thus, it is unit. - cs_p = 1; - - // The "panel stride" of a panel packed object is interpreted as the - // distance between the (0,0) element of panel k and the (0,0) - // element of panel k+1. We use the padded length computed above to - // allow for zero-padding (if necessary/desired) along the far end - // of each panel (ie: the bottom edge of the matrix). Zero-padding - // can also occur along the long edge of the last panel if the n - // dimension of the matrix is not a whole multiple of NR. - ps_p = m_p_pad * rs_p; - - // As a general rule, we don't want panel strides to be odd. This - // is primarily motivated by our desire to support interleaved 3m - // micro-panels, in which case we have to scale the panel stride - // by 3/2. That division by 2 means the numerator (prior to being - // scaled by 3) must be even. - if ( bli_is_odd( ps_p ) ) ps_p += 1; - - // Preserve this early panel stride value for use later, if needed. - ps_p_orig = ps_p; - - // Here, we adjust the panel stride, if necessary. Remember: ps_p is - // always interpreted as being in units of the datatype of the object - // which is not necessarily how the micro-panels will be stored. For - // interleaved 3m, we will increase ps_p by 50%, and for ro/io/rpi, - // we halve ps_p. Why? Because the macro-kernel indexes in units of - // the complex datatype. So these changes "trick" it into indexing - // the correct amount. - if ( bli_is_3mi_packed( schema ) ) - { - ps_p = ( ps_p * 3 ) / 2; - } - else if ( bli_is_3ms_packed( schema ) || - bli_is_ro_packed( schema ) || - bli_is_io_packed( schema ) || - bli_is_rpi_packed( schema ) ) - { - // The division by 2 below assumes that ps_p is an even number. - // However, it is possible that, at this point, ps_p is an odd. - // If it is indeed odd, we nudge it higher. - if ( bli_is_odd( ps_p ) ) ps_p += 1; - - // Despite the fact that the packed micro-panels will contain - // real elements, the panel stride that we store in the obj_t - // (which is passed into the macro-kernel) needs to be in units - // of complex elements, since the macro-kernel will index through - // micro-panels via complex pointer arithmetic for trmm/trsm. - // Since the indexing "increment" will be twice as large as each - // actual stored element, we divide the panel_stride by 2. - ps_p = ps_p / 2; - } - - // Set the imaginary stride (in units of fundamental elements) for - // 3m and 4m (separated or interleaved). We use ps_p_orig since - // that variable tracks the number of real part elements contained - // within each micro-panel of the source matrix. Therefore, this - // is the number of real elements that must be traversed before - // reaching the imaginary part (3mi/4mi) of the packed micro-panel, - // or the real part of the next micro-panel (3ms). - if ( bli_is_3mi_packed( schema ) ) is_p = ps_p_orig; - else if ( bli_is_4mi_packed( schema ) ) is_p = ps_p_orig; - else if ( bli_is_3ms_packed( schema ) ) is_p = ps_p_orig * ( n_p_pad / n_panel ); - else is_p = 1; - - // Store the strides and panel dimension in P. - bli_obj_set_strides( rs_p, cs_p, p ); - bli_obj_set_imag_stride( is_p, p ); - bli_obj_set_panel_dim( n_panel, p ); - bli_obj_set_panel_stride( ps_p, p ); - bli_obj_set_panel_length( m_p, p ); - bli_obj_set_panel_width( n_panel, p ); - - // Compute the size of the packed buffer. - size_p = ps_p * ( n_p_pad / n_panel ) * elem_size_p; - } - else - { - // NOTE: When implementing block storage, we only need to implement - // the following two cases: - // - row-stored blocks in row-major order - // - column-stored blocks in column-major order - // The other two combinations coincide with that of packed row-panel - // and packed column- panel storage. - - size_p = 0; - } - - return size_p; + siz_t elem_size_p = bli_obj_elem_size( p ); + + // The panel dimension (for each datatype) should be equal to the + // default (logical) blocksize multiple in the m dimension. + dim_t m_panel = bmult_m_def; + + // The "column stride" of a row-micropanel packed object is interpreted + // as the column stride WITHIN a micropanel. Thus, this is equal to the + // packing (storage) blocksize multiple, which may be equal to the + // default (logical) blocksize multiple). + inc_t cs_p = bmult_m_pack; + + // The "row stride" of a row-micropanel packed object is interpreted + // as the row stride WITHIN a micropanel. Thus, it is unit. + inc_t rs_p = 1; + + // The "panel stride" of a micropanel packed object is interpreted as + // the distance between the (0,0) element of panel k and the (0,0) + // element of panel k+1. We use the padded width computed above to + // allow for zero-padding (if necessary/desired) along the far end + // of each micropanel (ie: the right edge of the matrix). Zero-padding + // can also occur along the long edge of the last micropanel if the m + // dimension of the matrix is not a whole multiple of MR. + inc_t ps_p = cs_p * n_p_pad; + + // As a general rule, we don't want micropanel strides to be odd. There + // are very few instances where this can happen, but we've seen it happen + // more than zero times (such as for certain small problems), and so we + // check for it here. + if ( bli_is_odd( ps_p ) ) ps_p += 1; + + // Set the imaginary stride (in units of fundamental elements). + // This is the number of real elements that must be traversed before + // reaching the imaginary part of the packed micropanel. NOTE: the + // imaginary stride is mostly vestigial and left over from the 3m + // and 4m implementations. + inc_t is_p = 1; + + // Store the strides and panel dimension in P. + bli_obj_set_strides( rs_p, cs_p, p ); + bli_obj_set_imag_stride( is_p, p ); + bli_obj_set_panel_dim( m_panel, p ); + bli_obj_set_panel_stride( ps_p, p ); + bli_obj_set_panel_length( m_panel, p ); + bli_obj_set_panel_width( n_p, p ); + + // Compute the size of the packed buffer. + siz_t size_p = ps_p * ( m_p_pad / m_panel ) * elem_size_p; + + // If the requested size is zero, then we don't need to do any allocation. + if ( size_p == 0 ) + return false; + + // Update the buffer address in p to point to the buffer associated + // with the mem_t entry acquired from the memory broker (now cached in + // the control tree node). + void* buffer = bli_packm_alloc( size_p, rntm, cntl, thread ); + bli_obj_set_buffer( buffer, p ); + + return true; } diff --git a/frame/1m/packm/bli_packm_init.h b/frame/1m/packm/bli_packm_init.h index 6896ab913e..152c6f15cd 100644 --- a/frame/1m/packm/bli_packm_init.h +++ b/frame/1m/packm/bli_packm_init.h @@ -32,24 +32,13 @@ */ -siz_t bli_packm_init +BLIS_EXPORT_BLIS bool bli_packm_init ( obj_t* a, obj_t* p, cntx_t* cntx, - cntl_t* cntl - ); - -siz_t bli_packm_init_pack - ( - invdiag_t invert_diag, - pack_t schema, - packord_t pack_ord_if_up, - packord_t pack_ord_if_lo, - bszid_t bmult_id_m, - bszid_t bmult_id_n, - obj_t* a, - obj_t* p, - cntx_t* cntx + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread ); diff --git a/frame/1m/packm/bli_packm_int.c b/frame/1m/packm/bli_packm_int.c index 6dc9ec85af..c9a2bb9db2 100644 --- a/frame/1m/packm/bli_packm_int.c +++ b/frame/1m/packm/bli_packm_int.c @@ -39,59 +39,19 @@ void bli_packm_int obj_t* a, obj_t* p, cntx_t* cntx, + rntm_t* rntm, cntl_t* cntl, thrinfo_t* thread ) { bli_init_once(); - packm_var_oft f; + // Extract the function pointer from the object. + packm_var_oft f = bli_obj_pack_fn( a ); - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_packm_int_check( a, p, cntx ); - - // Sanity check; A should never have a zero dimension. If we must support - // it, then we should fold it into the next alias-and-early-exit block. - //if ( bli_obj_has_zero_dim( a ) ) bli_abort(); - - // Let us now check to see if the object has already been packed. First - // we check if it has been packed to an unspecified (row or column) - // format, in which case we can return, since by now aliasing has already - // taken place in packm_init(). - // NOTE: The reason we don't need to even look at the control tree in - // this case is as follows: an object's pack status is only set to - // BLIS_PACKED_UNSPEC for situations when the actual format used is - // not important, as long as its packed into contiguous rows or - // contiguous columns. A good example of this is packing for matrix - // operands in the level-2 operations. - if ( bli_obj_pack_schema( a ) == BLIS_PACKED_UNSPEC ) - { - return; - } - - // At this point, we can be assured that cntl is not NULL. Now we check - // if the object has already been packed to the desired schema (as en- - // coded in the control tree). If so, we can return, as above. - // NOTE: In most cases, an object's pack status will be BLIS_NOT_PACKED - // and thus packing will be called for (but in some cases packing has - // already taken place, or does not need to take place, and so that will - // be indicated by the pack status). Also, not all combinations of - // current pack status and desired pack schema are valid. - if ( bli_obj_pack_schema( a ) == bli_cntl_packm_params_pack_schema( cntl ) ) - { - return; - } - - // If the object is marked as being filled with zeros, then we can skip - // the packm operation entirely. - if ( bli_obj_is_zeros( a ) ) - { - return; - } - - // Extract the function pointer from the current control tree node. - f = bli_cntl_packm_params_var_func( cntl ); + // Barrier so that we know threads are done with previous computation + // with the same packing buffer before starting to pack. + bli_thread_barrier( thread ); // Invoke the variant with kappa_use. f @@ -99,8 +59,12 @@ void bli_packm_int a, p, cntx, + rntm, cntl, thread ); + + // Barrier so that packing is done before computation. + bli_thread_barrier( thread ); } diff --git a/frame/1m/packm/bli_packm_int.h b/frame/1m/packm/bli_packm_int.h index 573a299d67..16a5c2c34d 100644 --- a/frame/1m/packm/bli_packm_int.h +++ b/frame/1m/packm/bli_packm_int.h @@ -37,6 +37,7 @@ void bli_packm_int obj_t* a, obj_t* p, cntx_t* cntx, + rntm_t* rntm, cntl_t* cntl, thrinfo_t* thread ); diff --git a/frame/1m/packm/bli_packm_scalar.c b/frame/1m/packm/bli_packm_scalar.c new file mode 100644 index 0000000000..f613028c93 --- /dev/null +++ b/frame/1m/packm/bli_packm_scalar.c @@ -0,0 +1,76 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2016, Hewlett Packard Enterprise Development LP + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void* bli_packm_scalar( obj_t* kappa, obj_t* p ) +{ + num_t dt_p = bli_obj_dt( p ); + pack_t schema = bli_obj_pack_schema( p ); + + // The value for kappa we use will depends on whether the scalar + // attached to A has a nonzero imaginary component. If it does, + // then we will apply the scalar during packing to facilitate + // implementing induced complex domain algorithms in terms of + // real domain micro-kernels. (In the aforementioned situation, + // applying a real scalar is easy, but applying a complex one is + // harder, so we avoid the need altogether with the code below.) + if ( bli_obj_scalar_has_nonzero_imag( p ) && + !bli_is_nat_packed( schema ) ) + { + //printf( "applying non-zero imag kappa\n_p" ); + + // Detach the scalar. + bli_obj_scalar_detach( p, kappa ); + + // Reset the attached scalar (to 1.0). + bli_obj_scalar_reset( p ); + + return bli_obj_buffer_for_1x1( dt_p, kappa ); + } + // This branch is also for native execution, where we assume that + // the micro-kernel will always apply the alpha scalar of the + // higher-level operation. Thus, we use BLIS_ONE for kappa so + // that the underlying packm implementation does not perform + // any scaling during packing. + else + { + // If the internal scalar of A has only a real component, then + // we will apply it later (in the micro-kernel), and so we will + // use BLIS_ONE to indicate no scaling during packing. + return bli_obj_buffer_for_1x1( dt_p, &BLIS_ONE ); + } +} + diff --git a/frame/3/her2k/bli_her2k.h b/frame/1m/packm/bli_packm_scalar.h similarity index 96% rename from frame/3/her2k/bli_her2k.h rename to frame/1m/packm/bli_packm_scalar.h index 02975c2b51..3745accf9d 100644 --- a/frame/3/her2k/bli_her2k.h +++ b/frame/1m/packm/bli_packm_scalar.h @@ -32,5 +32,5 @@ */ -#include "bli_her2k_front.h" +BLIS_EXPORT_BLIS void* bli_packm_scalar( obj_t* kappa, obj_t* p ); diff --git a/frame/1m/packm/bli_packm_struc_cxk.c b/frame/1m/packm/bli_packm_struc_cxk.c index b86a9ebbd0..2a52c42def 100644 --- a/frame/1m/packm/bli_packm_struc_cxk.c +++ b/frame/1m/packm/bli_packm_struc_cxk.c @@ -40,57 +40,24 @@ void PASTEMAC(ch,varname) \ ( \ struc_t strucc, \ - doff_t diagoffc, \ diag_t diagc, \ uplo_t uploc, \ conj_t conjc, \ pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ + bool invdiag, \ + dim_t panel_dim, \ + dim_t panel_len, \ + dim_t panel_dim_max, \ + dim_t panel_len_max, \ + dim_t panel_dim_off, \ + dim_t panel_len_off, \ ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + ctype* restrict c, inc_t incc, inc_t ldc, \ + ctype* restrict p, inc_t ldp, \ inc_t is_p, \ cntx_t* cntx \ ) \ { \ - dim_t panel_dim; \ - dim_t panel_dim_max; \ - dim_t panel_len; \ - dim_t panel_len_max; \ - inc_t incc, ldc; \ - inc_t ldp; \ -\ -\ - /* Determine the dimensions and relative strides of the micro-panel - based on its pack schema. */ \ - if ( bli_is_col_packed( schema ) ) \ - { \ - /* Prepare to pack to row-stored column panel. */ \ - panel_dim = n_panel; \ - panel_dim_max = n_panel_max; \ - panel_len = m_panel; \ - panel_len_max = m_panel_max; \ - incc = cs_c; \ - ldc = rs_c; \ - ldp = rs_p; \ - } \ - else /* if ( bli_is_row_packed( schema ) ) */ \ - { \ - /* Prepare to pack to column-stored row panel. */ \ - panel_dim = m_panel; \ - panel_dim_max = m_panel_max; \ - panel_len = n_panel; \ - panel_len_max = n_panel_max; \ - incc = rs_c; \ - ldc = cs_c; \ - ldp = cs_p; \ - } \ -\ -\ /* Handle micro-panel packing based on the structure of the matrix being packed. */ \ if ( bli_is_general( strucc ) ) \ @@ -100,6 +67,7 @@ void PASTEMAC(ch,varname) \ PASTEMAC(ch,kername) \ ( \ conjc, \ + schema, \ panel_dim, \ panel_dim_max, \ panel_len, \ @@ -117,23 +85,21 @@ void PASTEMAC(ch,varname) \ PASTEMAC(ch,packm_herm_cxk) \ ( \ strucc, \ - diagoffc, \ + diagc, \ uploc, \ conjc, \ schema, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ + invdiag, \ panel_dim, \ - panel_dim_max, \ panel_len, \ + panel_dim_max, \ panel_len_max, \ + panel_dim_off, \ + panel_len_off, \ kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - ldp, \ + c, incc, ldc, \ + p, ldp, \ + is_p, \ cntx \ ); \ } \ @@ -144,130 +110,24 @@ void PASTEMAC(ch,varname) \ PASTEMAC(ch,packm_tri_cxk) \ ( \ strucc, \ - diagoffc, \ diagc, \ uploc, \ conjc, \ schema, \ invdiag, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ panel_dim, \ - panel_dim_max, \ panel_len, \ + panel_dim_max, \ panel_len_max, \ + panel_dim_off, \ + panel_len_off, \ kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - ldp, \ + c, incc, ldc, \ + p, ldp, \ + is_p, \ cntx \ ); \ } \ -\ -\ - /* If m_panel < m_panel_max, or n_panel < n_panel_max, we would normally - fill the edge region (the bottom m_panel_max - m_panel rows or right- - side n_panel_max - n_panel columns) of the micropanel with zeros. - However, this responsibility has been moved to the packm microkernel. - This change allows experts to use custom kernels that pack to custom - packing formats when the problem size is not a nice multiple of the - register blocksize. */ \ -\ -/* - if ( m_panel != m_panel_max ) \ - { \ - ctype* restrict zero = PASTEMAC(ch,0); \ - dim_t i = m_panel; \ - dim_t m_edge = m_panel_max - i; \ - dim_t n_edge = n_panel_max; \ - ctype* p_edge = p + (i )*rs_p; \ -\ - PASTEMAC2(ch,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero, \ - p_edge, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -\ - if ( n_panel != n_panel_max ) \ - { \ - ctype* restrict zero = PASTEMAC(ch,0); \ - dim_t j = n_panel; \ - dim_t m_edge = m_panel_max; \ - dim_t n_edge = n_panel_max - j; \ - ctype* p_edge = p + (j )*cs_p; \ -\ - PASTEMAC2(ch,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero, \ - p_edge, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -*/ \ -\ -\ - if ( bli_is_triangular( strucc ) ) \ - { \ - /* If this panel is an edge case in both panel dimension and length, - then it must be a bottom-right corner case. Set the part of the - diagonal that extends into the zero-padded region to identity. - NOTE: This is actually only necessary when packing for trsm, as - it helps prevent NaNs and Infs from creeping into the computation. - However, we set the region to identity for trmm as well. Those - 1.0's end up getting muliplied by the 0.0's in the zero-padded - region of the other matrix, so there is no harm in this. */ \ - if ( m_panel != m_panel_max && \ - n_panel != n_panel_max ) \ - { \ - ctype* restrict one = PASTEMAC(ch,1); \ - dim_t i = m_panel; \ - dim_t j = n_panel; \ - dim_t m_br = m_panel_max - i; \ - dim_t n_br = n_panel_max - j; \ - ctype* p_br = p + (i )*rs_p + (j )*cs_p; \ -\ - PASTEMAC2(ch,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - m_br, \ - n_br, \ - one, \ - p_br, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ -\ -/* - if ( bli_is_col_packed( schema ) ) \ - PASTEMAC(ch,fprintm)( stdout, "packm_struc_cxk: bp copied", m_panel_max, n_panel_max, \ - p, rs_p, cs_p, "%4.1f", "" ); \ - else if ( bli_is_row_packed( schema ) ) \ - PASTEMAC(ch,fprintm)( stdout, "packm_struc_cxk: ap copied", m_panel_max, n_panel_max, \ - p, rs_p, cs_p, "%4.1f", "" ); \ -*/ \ } INSERT_GENTFUNC_BASIC( packm_struc_cxk, packm_cxk ) @@ -281,42 +141,31 @@ INSERT_GENTFUNC_BASIC( packm_struc_cxk, packm_cxk ) void PASTEMAC(ch,varname) \ ( \ struc_t strucc, \ - doff_t diagoffc, \ + diag_t diagc, \ uplo_t uploc, \ conj_t conjc, \ pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ + bool invdiag, \ dim_t panel_dim, \ - dim_t panel_dim_max, \ dim_t panel_len, \ + dim_t panel_dim_max, \ dim_t panel_len_max, \ + dim_t panel_dim_off, \ + dim_t panel_len_off, \ ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ + ctype* restrict c, inc_t incc, inc_t ldc, \ + ctype* restrict p, inc_t ldp, \ + inc_t is_p, \ cntx_t* cntx \ ) \ { \ - doff_t diagoffc_abs; \ - dim_t i, j; \ - bool_t row_stored; \ - bool_t col_stored; \ -\ -\ - /* Create flags to incidate row or column storage. Note that the - schema bit that encodes row or column is describing the form of - micro-panel, not the storage in the micro-panel. Hence the - mismatch in "row" and "column" semantics. */ \ - row_stored = bli_is_col_packed( schema ); \ - col_stored = bli_is_row_packed( schema ); \ + doff_t diagoffc = panel_dim_off - panel_len_off; \ + doff_t diagoffc_abs; \ + dim_t i, j; \ \ /* Handle the case where the micro-panel does NOT intersect the diagonal separately from the case where it does intersect. */ \ - if ( !bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) \ + if ( !bli_intersects_diag_n( diagoffc, panel_dim, panel_len ) ) \ { \ /* If the current panel is unstored, we need to make a few adjustments so we refer to the data where it is actually @@ -324,10 +173,10 @@ void PASTEMAC(ch,varname) \ implicitly assumes we are operating on a dense panel within a larger symmetric or Hermitian matrix, since a general matrix would not contain any unstored region.) */ \ - if ( bli_is_unstored_subpart_n( diagoffc, uploc, m_panel, n_panel ) ) \ + if ( bli_is_unstored_subpart_n( diagoffc, uploc, panel_dim, panel_len ) ) \ { \ - c = c + diagoffc * ( doff_t )cs_c + \ - -diagoffc * ( doff_t )rs_c; \ + c = c + diagoffc * ( doff_t )ldc + \ + -diagoffc * ( doff_t )incc; \ bli_swap_incs( &incc, &ldc ); \ \ if ( bli_is_hermitian( strucc ) ) \ @@ -338,6 +187,7 @@ void PASTEMAC(ch,varname) \ PASTEMAC(ch,kername) \ ( \ conjc, \ + schema, \ panel_dim, \ panel_dim_max, \ panel_len, \ @@ -348,7 +198,7 @@ void PASTEMAC(ch,varname) \ cntx \ ); \ } \ - else /* if ( bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) */ \ + else /* if ( bli_intersects_diag_n( diagoffc, panel_dim, panel_len ) ) */ \ { \ ctype* restrict c10; \ ctype* restrict p10; \ @@ -368,14 +218,12 @@ void PASTEMAC(ch,varname) \ a micro-panel. If they do, then somehow the constraints on cache blocksizes being a whole multiple of the register blocksizes was somehow violated. */ \ - if ( ( col_stored && diagoffc < 0 ) || \ - ( row_stored && diagoffc > 0 ) ) \ + if ( diagoffc < 0 ) \ bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ \ diagoffc_abs = bli_abs( diagoffc ); \ \ - if ( ( row_stored && bli_is_upper( uploc ) ) || \ - ( col_stored && bli_is_lower( uploc ) ) ) \ + if ( bli_is_lower( uploc ) ) \ { \ p10_dim = panel_dim; \ p10_len = diagoffc_abs; \ @@ -391,8 +239,8 @@ void PASTEMAC(ch,varname) \ diagoffc12 = diagoffc_abs - j; \ p12 = p + (j )*ldp; \ c12 = c + (j )*ldc; \ - c12 = c12 + diagoffc12 * ( doff_t )cs_c + \ - -diagoffc12 * ( doff_t )rs_c; \ + c12 = c12 + diagoffc12 * ( doff_t )ldc + \ + -diagoffc12 * ( doff_t )incc; \ incc12 = ldc; \ ldc12 = incc; \ conjc12 = conjc; \ @@ -400,16 +248,15 @@ void PASTEMAC(ch,varname) \ if ( bli_is_hermitian( strucc ) ) \ bli_toggle_conj( &conjc12 ); \ } \ - else /* if ( ( row_stored && bli_is_lower( uploc ) ) || \ - ( col_stored && bli_is_upper( uploc ) ) ) */ \ + else /* if ( bli_is_upper( uploc ) ) */ \ { \ p10_dim = panel_dim; \ p10_len = diagoffc_abs + panel_dim; \ diagoffc10 = diagoffc; \ p10 = p; \ c10 = c; \ - c10 = c10 + diagoffc10 * ( doff_t )cs_c + \ - -diagoffc10 * ( doff_t )rs_c; \ + c10 = c10 + diagoffc10 * ( doff_t )ldc + \ + -diagoffc10 * ( doff_t )incc; \ incc10 = ldc; \ ldc10 = incc; \ conjc10 = conjc; \ @@ -436,6 +283,7 @@ void PASTEMAC(ch,varname) \ PASTEMAC(ch,kername) \ ( \ conjc10, \ + schema, \ p10_dim, \ panel_dim_max, \ p10_len, \ @@ -455,6 +303,7 @@ void PASTEMAC(ch,varname) \ PASTEMAC(ch,kername) \ ( \ conjc12, \ + schema, \ p12_dim, \ panel_dim_max, \ p12_len, \ @@ -482,8 +331,8 @@ void PASTEMAC(ch,varname) \ transc, \ p11_m, \ p11_n, \ - c11, rs_c, cs_c, \ - p11, rs_p, cs_p, \ + c11, incc, ldc, \ + p11, 1, ldp, \ cntx, \ NULL \ ); \ @@ -499,7 +348,7 @@ void PASTEMAC(ch,varname) \ { \ PASTEMAC(ch,seti0s)( *pi11 ); \ \ - pi11 += rs_p + cs_p; \ + pi11 += 1 + ldp; \ } \ } \ \ @@ -515,7 +364,7 @@ void PASTEMAC(ch,varname) \ p11_m, \ p11_n, \ kappa, \ - p11, rs_p, cs_p, \ + p11, 1, ldp, \ cntx, \ NULL \ ); \ @@ -535,32 +384,31 @@ INSERT_GENTFUNC_BASIC( packm_herm_cxk, packm_cxk ) void PASTEMAC(ch,varname) \ ( \ struc_t strucc, \ - doff_t diagoffp, \ diag_t diagc, \ uplo_t uploc, \ conj_t conjc, \ pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ + bool invdiag, \ dim_t panel_dim, \ - dim_t panel_dim_max, \ dim_t panel_len, \ + dim_t panel_dim_max, \ dim_t panel_len_max, \ + dim_t panel_dim_off, \ + dim_t panel_len_off, \ ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ + ctype* restrict c, inc_t incc, inc_t ldc, \ + ctype* restrict p, inc_t ldp, \ + inc_t is_p, \ cntx_t* cntx \ ) \ { \ + doff_t diagoffc = panel_dim_off - panel_len_off; \ +\ /* Pack the panel. */ \ PASTEMAC(ch,kername) \ ( \ conjc, \ + schema, \ panel_dim, \ panel_dim_max, \ panel_len, \ @@ -579,11 +427,11 @@ void PASTEMAC(ch,varname) \ PASTEMAC2(ch,setd,BLIS_TAPI_EX_SUF) \ ( \ BLIS_NO_CONJUGATE, \ - diagoffp, \ - m_panel, \ - n_panel, \ + diagoffc, \ + panel_dim, \ + panel_len, \ kappa, \ - p, rs_p, cs_p, \ + p, 1, ldp, \ cntx, \ NULL \ ); \ @@ -594,10 +442,10 @@ void PASTEMAC(ch,varname) \ { \ PASTEMAC2(ch,invertd,BLIS_TAPI_EX_SUF) \ ( \ - diagoffp, \ - m_panel, \ - n_panel, \ - p, rs_p, cs_p, \ + diagoffc, \ + panel_dim, \ + panel_len, \ + p, 1, ldp, \ cntx, \ NULL \ ); \ @@ -616,23 +464,53 @@ void PASTEMAC(ch,varname) \ uplo_t uplop = uploc; \ \ bli_toggle_uplo( &uplop ); \ - bli_shift_diag_offset_to_shrink_uplo( uplop, &diagoffp ); \ + bli_shift_diag_offset_to_shrink_uplo( uplop, &diagoffc ); \ \ PASTEMAC2(ch,setm,BLIS_TAPI_EX_SUF) \ ( \ BLIS_NO_CONJUGATE, \ - diagoffp, \ + diagoffc, \ BLIS_NONUNIT_DIAG, \ uplop, \ - m_panel, \ - n_panel, \ + panel_dim, \ + panel_len, \ zero, \ - p, rs_p, cs_p, \ + p, 1, ldp, \ cntx, \ NULL \ ); \ } \ \ + /* If this panel is an edge case in both panel dimension and length, + then it must be a bottom-right corner case. Set the part of the + diagonal that extends into the zero-padded region to identity. + NOTE: This is actually only necessary when packing for trsm, as + it helps prevent NaNs and Infs from creeping into the computation. + However, we set the region to identity for trmm as well. Those + 1.0's end up getting muliplied by the 0.0's in the zero-padded + region of the other matrix, so there is no harm in this. */ \ + if ( panel_dim != panel_dim_max && \ + panel_len != panel_len_max ) \ + { \ + ctype* restrict one = PASTEMAC(ch,1); \ + dim_t i = panel_dim; \ + dim_t j = panel_len; \ + dim_t m_br = panel_dim_max - i; \ + dim_t n_br = panel_len_max - j; \ + ctype* p_br = p + (i ) + (j )*ldp; \ +\ + PASTEMAC2(ch,setd,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + 0, \ + m_br, \ + n_br, \ + one, \ + p_br, 1, ldp, \ + cntx, \ + NULL \ + ); \ + } \ } INSERT_GENTFUNC_BASIC( packm_tri_cxk, packm_cxk ) diff --git a/frame/1m/packm/bli_packm_struc_cxk.h b/frame/1m/packm/bli_packm_struc_cxk.h index 5b486d776a..973a02612b 100644 --- a/frame/1m/packm/bli_packm_struc_cxk.h +++ b/frame/1m/packm/bli_packm_struc_cxk.h @@ -38,84 +38,25 @@ void PASTEMAC(ch,varname) \ ( \ struc_t strucc, \ - doff_t diagoffp, \ diag_t diagc, \ uplo_t uploc, \ conj_t conjc, \ pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROT_BASIC0( packm_struc_cxk ) - - - -#undef GENTPROT -#define GENTPROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ + bool invdiag, \ dim_t panel_dim, \ - dim_t panel_dim_max, \ dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROT_BASIC0( packm_herm_cxk ) - - - -#undef GENTPROT -#define GENTPROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ dim_t panel_dim_max, \ - dim_t panel_len, \ dim_t panel_len_max, \ + dim_t panel_dim_off, \ + dim_t panel_len_off, \ ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ + ctype* restrict c, inc_t incc, inc_t ldc, \ + ctype* restrict p, inc_t ldp, \ + inc_t is_p, \ cntx_t* cntx \ ); +INSERT_GENTPROT_BASIC0( packm_struc_cxk ) +INSERT_GENTPROT_BASIC0( packm_herm_cxk ) INSERT_GENTPROT_BASIC0( packm_tri_cxk ) diff --git a/frame/1m/packm/bli_packm_struc_cxk_1er.c b/frame/1m/packm/bli_packm_struc_cxk_1er.c index 038ee1b8f5..b3be9dff95 100644 --- a/frame/1m/packm/bli_packm_struc_cxk_1er.c +++ b/frame/1m/packm/bli_packm_struc_cxk_1er.c @@ -40,57 +40,25 @@ void PASTEMAC(ch,varname) \ ( \ struc_t strucc, \ - doff_t diagoffc, \ diag_t diagc, \ uplo_t uploc, \ conj_t conjc, \ pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ + bool invdiag, \ + dim_t panel_dim, \ + dim_t panel_len, \ + dim_t panel_dim_max, \ + dim_t panel_len_max, \ + dim_t panel_dim_off, \ + dim_t panel_len_off, \ ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + ctype* restrict c, inc_t incc, inc_t ldc, \ + ctype* restrict p, inc_t ldp, \ inc_t is_p, \ - cntx_t* cntx \ + cntx_t* cntx, \ + void* params \ ) \ { \ - dim_t panel_dim; \ - dim_t panel_dim_max; \ - dim_t panel_len; \ - dim_t panel_len_max; \ - inc_t incc, ldc; \ - inc_t ldp; \ -\ -\ - /* Determine the dimensions and relative strides of the micro-panel - based on its pack schema. */ \ - if ( bli_is_col_packed( schema ) ) \ - { \ - /* Prepare to pack to row-stored column panel. */ \ - panel_dim = n_panel; \ - panel_dim_max = n_panel_max; \ - panel_len = m_panel; \ - panel_len_max = m_panel_max; \ - incc = cs_c; \ - ldc = rs_c; \ - ldp = rs_p; \ - } \ - else /* if ( bli_is_row_packed( schema ) ) */ \ - { \ - /* Prepare to pack to column-stored row panel. */ \ - panel_dim = m_panel; \ - panel_dim_max = m_panel_max; \ - panel_len = n_panel; \ - panel_len_max = n_panel_max; \ - incc = rs_c; \ - ldc = cs_c; \ - ldp = cs_p; \ - } \ -\ -\ /* Handle micro-panel packing based on the structure of the matrix being packed. */ \ if ( bli_is_general( strucc ) ) \ @@ -108,7 +76,7 @@ void PASTEMAC(ch,varname) \ kappa, \ c, incc, ldc, \ p, ldp, \ - cntx \ + cntx \ ); \ } \ else if ( bli_is_herm_or_symm( strucc ) ) \ @@ -118,24 +86,23 @@ void PASTEMAC(ch,varname) \ PASTEMAC(ch,packm_herm_cxk_1er) \ ( \ strucc, \ - diagoffc, \ + diagc, \ uploc, \ conjc, \ schema, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ + invdiag, \ panel_dim, \ - panel_dim_max, \ panel_len, \ + panel_dim_max, \ panel_len_max, \ + panel_dim_off, \ + panel_len_off, \ kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - ldp, \ - cntx \ + c, incc, ldc, \ + p, ldp, \ + is_p, \ + cntx, \ + params \ ); \ } \ else /* ( bli_is_triangular( strucc ) ) */ \ @@ -145,125 +112,25 @@ void PASTEMAC(ch,varname) \ PASTEMAC(ch,packm_tri_cxk_1er) \ ( \ strucc, \ - diagoffc, \ diagc, \ uploc, \ conjc, \ schema, \ invdiag, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ panel_dim, \ - panel_dim_max, \ panel_len, \ + panel_dim_max, \ panel_len_max, \ + panel_dim_off, \ + panel_len_off, \ kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - ldp, \ - cntx \ - ); \ - } \ -\ -\ - /* If m_panel < m_panel_max, or n_panel < n_panel_max, we would normally - fill the edge region (the bottom m_panel_max - m_panel rows or right- - side n_panel_max - n_panel columns) of the micropanel with zeros. - However, this responsibility has been moved to the packm microkernel. - This change allows experts to use custom kernels that pack to custom - packing formats when the problem size is not a nice multiple of the - register blocksize. */ \ -/* - if ( m_panel != m_panel_max ) \ - { \ - ctype* restrict zero = PASTEMAC(ch,0); \ - dim_t offm = m_panel; \ - dim_t offn = 0; \ - dim_t m_edge = m_panel_max - m_panel; \ - dim_t n_edge = n_panel_max; \ -\ - PASTEMAC(ch,set1ms_mxn) \ - ( \ - schema, \ - offm, \ - offn, \ - m_edge, \ - n_edge, \ - zero, \ - p, rs_p, cs_p, ldp \ - ); \ - } \ -\ - if ( n_panel != n_panel_max ) \ - { \ - ctype* restrict zero = PASTEMAC(ch,0); \ - dim_t offm = 0; \ - dim_t offn = n_panel; \ - dim_t m_edge = m_panel_max; \ - dim_t n_edge = n_panel_max - n_panel; \ -\ - PASTEMAC(ch,set1ms_mxn) \ - ( \ - schema, \ - offm, \ - offn, \ - m_edge, \ - n_edge, \ - zero, \ - p, rs_p, cs_p, ldp \ + c, incc, ldc, \ + p, ldp, \ + is_p, \ + cntx, \ + params \ ); \ } \ -*/ \ -\ - if ( bli_is_triangular( strucc ) ) \ - { \ - /* If this micro-panel is an edge case in both panel dimension and - length, then it must be a bottom-right corner case, which - typically only happens for micro-panels being packed for trsm. - (It also happens for trmm if kr > 1.) Here, we set the part of - the diagonal that extends into the zero-padded region to - identity. This prevents NaNs and Infs from creeping into the - computation. If this code does execute for trmm, it is okay, - because those 1.0's that extend into the bottom-right region - end up getting muliplied by the 0.0's in the zero-padded region - of the other matrix. */ \ - if ( m_panel != m_panel_max && \ - n_panel != n_panel_max ) \ - { \ - ctype* restrict one = PASTEMAC(ch,1); \ - dim_t offm = m_panel; \ - dim_t offn = n_panel; \ - dim_t m_edge = m_panel_max - m_panel; \ - dim_t n_edge = n_panel_max - n_panel; \ -\ - PASTEMAC(ch,set1ms_mxn_diag) \ - ( \ - schema, \ - offm, \ - offn, \ - m_edge, \ - n_edge, \ - one, \ - p, rs_p, cs_p, ldp \ - ); \ - } \ - } \ -\ -\ -/* - if ( bli_is_1r_packed( schema ) ) { \ - PASTEMAC(chr,fprintm)( stdout, "packm_struc_cxk_1er (1r): bp", m_panel_max, 2*n_panel_max, \ - ( ctype_r* )p, rs_p, cs_p, "%4.1f", "" ); \ - } \ - \ - if ( bli_is_1e_packed( schema ) ) { \ - PASTEMAC(chr,fprintm)( stdout, "packm_struc_cxk_1er (1e): ap", 2*m_panel_max, 2*n_panel_max, \ - ( ctype_r* )p, rs_p, cs_p, "%4.1f", "" ); \ - } \ -*/ \ } INSERT_GENTFUNCCO_BASIC( packm_struc_cxk_1er, packm_cxk_1er ) @@ -277,42 +144,32 @@ INSERT_GENTFUNCCO_BASIC( packm_struc_cxk_1er, packm_cxk_1er ) void PASTEMAC(ch,varname) \ ( \ struc_t strucc, \ - doff_t diagoffc, \ + diag_t diagc, \ uplo_t uploc, \ conj_t conjc, \ pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ + bool invdiag, \ dim_t panel_dim, \ - dim_t panel_dim_max, \ dim_t panel_len, \ + dim_t panel_dim_max, \ dim_t panel_len_max, \ + dim_t panel_dim_off, \ + dim_t panel_len_off, \ ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ - cntx_t* cntx \ + ctype* restrict c, inc_t incc, inc_t ldc, \ + ctype* restrict p, inc_t ldp, \ + inc_t is_p, \ + cntx_t* cntx, \ + void* params \ ) \ { \ - doff_t diagoffc_abs; \ - dim_t j; \ - bool_t row_stored; \ - bool_t col_stored; \ -\ -\ - /* Create flags to incidate row or column storage. Note that the - schema bit that encodes row or column is describing the form of - micro-panel, not the storage in the micro-panel. Hence the - mismatch in "row" and "column" semantics. */ \ - row_stored = bli_is_col_packed( schema ); \ - col_stored = bli_is_row_packed( schema ); \ + doff_t diagoffc = panel_dim_off - panel_len_off; \ + doff_t diagoffc_abs; \ + dim_t j; \ \ /* Handle the case where the micro-panel does NOT intersect the diagonal separately from the case where it does intersect. */ \ - if ( !bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) \ + if ( !bli_intersects_diag_n( diagoffc, panel_dim, panel_len ) ) \ { \ /* If the current panel is unstored, we need to make a few adjustments so we refer to the data where it is actually @@ -320,10 +177,10 @@ void PASTEMAC(ch,varname) \ implicitly assumes we are operating on a dense panel within a larger symmetric or Hermitian matrix, since a general matrix would not contain any unstored region.) */ \ - if ( bli_is_unstored_subpart_n( diagoffc, uploc, m_panel, n_panel ) ) \ + if ( bli_is_unstored_subpart_n( diagoffc, uploc, panel_dim, panel_len ) ) \ { \ - c = c + diagoffc * ( doff_t )cs_c + \ - -diagoffc * ( doff_t )rs_c; \ + c = c + diagoffc * ( doff_t )ldc + \ + -diagoffc * ( doff_t )incc; \ bli_swap_incs( &incc, &ldc ); \ \ if ( bli_is_hermitian( strucc ) ) \ @@ -345,7 +202,7 @@ void PASTEMAC(ch,varname) \ cntx \ ); \ } \ - else /* if ( bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) */ \ + else /* if ( bli_intersects_diag_n( diagoffc, panel_dim, panel_len ) ) */ \ { \ ctype* restrict c10; \ ctype* restrict p10; \ @@ -366,14 +223,12 @@ void PASTEMAC(ch,varname) \ a micro-panel. If they do, then somehow the constraints on cache blocksizes being a whole multiple of the register blocksizes was somehow violated. */ \ - if ( ( col_stored && diagoffc < 0 ) || \ - ( row_stored && diagoffc > 0 ) ) \ + if ( diagoffc < 0 ) \ bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ \ diagoffc_abs = bli_abs( diagoffc ); \ \ - if ( ( row_stored && bli_is_upper( uploc ) ) || \ - ( col_stored && bli_is_lower( uploc ) ) ) \ + if ( bli_is_lower( uploc ) ) \ { \ p10_dim = panel_dim; \ p10_len = diagoffc_abs; \ @@ -389,8 +244,8 @@ void PASTEMAC(ch,varname) \ diagoffc12 = diagoffc_abs - j; \ p12 = p + (j )*ldp; \ c12 = c + (j )*ldc; \ - c12 = c12 + diagoffc12 * ( doff_t )cs_c + \ - -diagoffc12 * ( doff_t )rs_c; \ + c12 = c12 + diagoffc12 * ( doff_t )ldc + \ + -diagoffc12 * ( doff_t )incc; \ incc12 = ldc; \ ldc12 = incc; \ conjc12 = conjc; \ @@ -398,16 +253,15 @@ void PASTEMAC(ch,varname) \ if ( bli_is_hermitian( strucc ) ) \ bli_toggle_conj( &conjc12 ); \ } \ - else /* if ( ( row_stored && bli_is_lower( uploc ) ) || \ - ( col_stored && bli_is_upper( uploc ) ) ) */ \ + else /* if ( bli_is_upper( uploc ) ) */ \ { \ p10_dim = panel_dim; \ p10_len = diagoffc_abs + panel_dim; \ diagoffc10 = diagoffc; \ p10 = p; \ c10 = c; \ - c10 = c10 + diagoffc10 * ( doff_t )cs_c + \ - -diagoffc10 * ( doff_t )rs_c; \ + c10 = c10 + diagoffc10 * ( doff_t )ldc + \ + -diagoffc10 * ( doff_t )incc; \ incc10 = ldc; \ ldc10 = incc; \ conjc10 = conjc; \ @@ -478,8 +332,8 @@ void PASTEMAC(ch,varname) \ conjc, \ panel_dim, \ kappa, \ - c11, rs_c, cs_c, \ - p11, rs_p, cs_p, ldp \ + c11, incc, ldc, \ + p11, 1, ldp, ldp \ ); \ \ /* If we are packing a micro-panel with Hermitian structure, @@ -495,8 +349,8 @@ void PASTEMAC(ch,varname) \ if ( bli_is_hermitian( strucc ) ) \ { \ ctype_r* restrict c11_r = ( ctype_r* )c11; \ - const dim_t rs_c2 = 2*rs_c; \ - const dim_t cs_c2 = 2*cs_c; \ + const dim_t incc2 = 2*incc; \ + const dim_t ldc2 = 2*ldc; \ \ PASTEMAC3(ch,chr,ch,scal21ms_mxn_diag) \ ( \ @@ -504,8 +358,8 @@ void PASTEMAC(ch,varname) \ panel_dim, \ panel_dim, \ kappa, \ - c11_r, rs_c2, cs_c2, \ - p11, rs_p, cs_p, ldp \ + c11_r, incc2, ldc2, \ + p11, 1, ldp, ldp \ ); \ } \ } \ @@ -523,30 +377,28 @@ INSERT_GENTFUNCCO_BASIC( packm_herm_cxk_1er, packm_cxk_1er ) void PASTEMAC(ch,varname) \ ( \ struc_t strucc, \ - doff_t diagoffp, \ diag_t diagc, \ uplo_t uploc, \ conj_t conjc, \ pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ + bool invdiag, \ dim_t panel_dim, \ - dim_t panel_dim_max, \ dim_t panel_len, \ + dim_t panel_dim_max, \ dim_t panel_len_max, \ + dim_t panel_dim_off, \ + dim_t panel_len_off, \ ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ - cntx_t* cntx \ + ctype* restrict c, inc_t incc, inc_t ldc, \ + ctype* restrict p, inc_t ldp, \ + inc_t is_p, \ + cntx_t* cntx, \ + void* params \ ) \ { \ - doff_t diagoffp_abs = bli_abs( diagoffp ); \ - ctype* p11 = p + (diagoffp_abs )*ldp; \ + doff_t diagoffc = panel_dim_off - panel_len_off; \ + doff_t diagoffc_abs = bli_abs( diagoffc ); \ + ctype* p11 = p + (diagoffc_abs )*ldp; \ \ \ /* Pack the panel. */ \ @@ -579,7 +431,7 @@ void PASTEMAC(ch,varname) \ panel_dim, \ panel_dim, \ kappa, \ - p11, rs_p, cs_p, ldp \ + p11, 1, ldp, ldp \ ); \ } \ \ @@ -594,7 +446,7 @@ void PASTEMAC(ch,varname) \ 0, \ panel_dim, \ panel_dim, \ - p11, rs_p, cs_p, ldp \ + p11, 1, ldp, ldp \ ); \ } \ \ @@ -610,11 +462,11 @@ void PASTEMAC(ch,varname) \ { \ ctype* restrict zero = PASTEMAC(ch,0); \ uplo_t uplop = uploc; \ - doff_t diagoffp11_0 = 0; \ + doff_t diagoffc11_0 = 0; \ dim_t p11_0_dim = panel_dim - 1; \ \ bli_toggle_uplo( &uplop ); \ - bli_shift_diag_offset_to_shrink_uplo( uplop, &diagoffp11_0 ); \ + bli_shift_diag_offset_to_shrink_uplo( uplop, &diagoffc11_0 ); \ \ /* Note that this macro works a little differently than the setm operation. Here, we pass in the dimensions of only p11, rather @@ -622,20 +474,51 @@ void PASTEMAC(ch,varname) \ "shrunken" dimensions of p11, corresponding to the toggling and shrinking of the diagonal above. The macro will do the right thing, incrementing the pointer to p11 by the appropriate - leading dimension (cs_p or rs_p), and setting only the lower + leading dimension (ldp or rs_p), and setting only the lower or upper triangle to zero. */ \ PASTEMAC(ch,set1ms_mxn_uplo) \ ( \ schema, \ - diagoffp11_0, \ + diagoffc11_0, \ uplop, \ p11_0_dim, \ p11_0_dim, \ zero, \ - p11, rs_p, cs_p, ldp \ + p11, 1, ldp, ldp \ ); \ } \ } \ +\ + /* If this micro-panel is an edge case in both panel dimension and + length, then it must be a bottom-right corner case, which + typically only happens for micro-panels being packed for trsm. + (It also happens for trmm if kr > 1.) Here, we set the part of + the diagonal that extends into the zero-padded region to + identity. This prevents NaNs and Infs from creeping into the + computation. If this code does execute for trmm, it is okay, + because those 1.0's that extend into the bottom-right region + end up getting muliplied by the 0.0's in the zero-padded region + of the other matrix. */ \ + if ( panel_dim != panel_dim_max && \ + panel_len != panel_len_max ) \ + { \ + ctype* restrict one = PASTEMAC(ch,1); \ + dim_t offm = panel_dim; \ + dim_t offn = panel_len; \ + dim_t m_edge = panel_dim_max - panel_dim; \ + dim_t n_edge = panel_len_max - panel_len; \ +\ + PASTEMAC(ch,set1ms_mxn_diag) \ + ( \ + schema, \ + offm, \ + offn, \ + m_edge, \ + n_edge, \ + one, \ + p, 1, ldp, ldp \ + ); \ + } \ } INSERT_GENTFUNCCO_BASIC( packm_tri_cxk_1er, packm_cxk_1er ) diff --git a/frame/1m/packm/bli_packm_struc_cxk_1er.h b/frame/1m/packm/bli_packm_struc_cxk_1er.h index e63edf8f28..a953e93673 100644 --- a/frame/1m/packm/bli_packm_struc_cxk_1er.h +++ b/frame/1m/packm/bli_packm_struc_cxk_1er.h @@ -38,84 +38,26 @@ void PASTEMAC(ch,varname) \ ( \ struc_t strucc, \ - doff_t diagoffp, \ diag_t diagc, \ uplo_t uploc, \ conj_t conjc, \ pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_struc_cxk_1er ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ + bool invdiag, \ dim_t panel_dim, \ - dim_t panel_dim_max, \ dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_herm_cxk_1er ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ dim_t panel_dim_max, \ - dim_t panel_len, \ dim_t panel_len_max, \ + dim_t panel_dim_off, \ + dim_t panel_len_off, \ ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ - cntx_t* cntx \ + ctype* restrict c, inc_t incc, inc_t ldc, \ + ctype* restrict p, inc_t ldp, \ + inc_t is_p, \ + cntx_t* cntx, \ + void* params \ ); +INSERT_GENTPROTCO_BASIC0( packm_struc_cxk_1er ) +INSERT_GENTPROTCO_BASIC0( packm_herm_cxk_1er ) INSERT_GENTPROTCO_BASIC0( packm_tri_cxk_1er ) diff --git a/frame/1m/packm/bli_packm_struc_cxk_3mis.c b/frame/1m/packm/bli_packm_struc_cxk_3mis.c deleted file mode 100644 index 9d01b3b409..0000000000 --- a/frame/1m/packm/bli_packm_struc_cxk_3mis.c +++ /dev/null @@ -1,842 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ) \ -{ \ - dim_t panel_dim; \ - dim_t panel_dim_max; \ - dim_t panel_len; \ - dim_t panel_len_max; \ - inc_t incc, ldc; \ - inc_t ldp; \ -\ -\ - /* Determine the dimensions and relative strides of the micro-panel - based on its pack schema. */ \ - if ( bli_is_col_packed( schema ) ) \ - { \ - /* Prepare to pack to row-stored column panel. */ \ - panel_dim = n_panel; \ - panel_dim_max = n_panel_max; \ - panel_len = m_panel; \ - panel_len_max = m_panel_max; \ - incc = cs_c; \ - ldc = rs_c; \ - ldp = rs_p; \ - } \ - else /* if ( bli_is_row_packed( schema ) ) */ \ - { \ - /* Prepare to pack to column-stored row panel. */ \ - panel_dim = m_panel; \ - panel_dim_max = m_panel_max; \ - panel_len = n_panel; \ - panel_len_max = n_panel_max; \ - incc = rs_c; \ - ldc = cs_c; \ - ldp = cs_p; \ - } \ -\ -\ - /* Handle micro-panel packing based on the structure of the matrix - being packed. */ \ - if ( bli_is_general( strucc ) ) \ - { \ - /* For micro-panels of general matrices, we can call the pack - kernel front-end directly. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, is_p, ldp, \ - cntx \ - ); \ - } \ - else if ( bli_is_herm_or_symm( strucc ) ) \ - { \ - /* Call a helper function for micro-panels of Hermitian/symmetric - matrices. */ \ - PASTEMAC(ch,packm_herm_cxk_3mis) \ - ( \ - strucc, \ - diagoffc, \ - uploc, \ - conjc, \ - schema, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - is_p, ldp, \ - cntx \ - ); \ - } \ - else /* ( bli_is_triangular( strucc ) ) */ \ - { \ - /* Call a helper function for micro-panels of triangular - matrices. */ \ - PASTEMAC(ch,packm_tri_cxk_3mis) \ - ( \ - strucc, \ - diagoffc, \ - diagc, \ - uploc, \ - conjc, \ - schema, \ - invdiag, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - is_p, ldp, \ - cntx \ - ); \ - } \ -\ -\ - /* If m_panel < m_panel_max, or n_panel < n_panel_max, we would normally - fill the edge region (the bottom m_panel_max - m_panel rows or right- - side n_panel_max - n_panel columns) of the micropanel with zeros. - However, this responsibility has been moved to the packm microkernel. - This change allows experts to use custom kernels that pack to custom - packing formats when the problem size is not a nice multiple of the - register blocksize. */ \ -/* - if ( m_panel != m_panel_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t i = m_panel; \ - dim_t m_edge = m_panel_max - i; \ - dim_t n_edge = n_panel_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*rs_p; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*rs_p; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*rs_p; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -*/ \ -\ -/* - if ( n_panel != n_panel_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t j = n_panel; \ - dim_t m_edge = m_panel_max; \ - dim_t n_edge = n_panel_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*cs_p; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*cs_p; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*cs_p; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -*/ \ -\ -\ - if ( bli_is_triangular( strucc ) ) \ - { \ - /* If this panel is an edge case in both panel dimension and length, - then it must be a bottom-right corner case. Set the part of the - diagonal that extends into the zero-padded region to identity. - NOTE: This is actually only necessary when packing for trsm, as - it helps prevent NaNs and Infs from creeping into the computation. - However, we set the region to identity for trmm as well. Those - 1.0's end up getting muliplied by the 0.0's in the zero-padded - region of the other matrix, so there is no harm in this. */ \ - if ( m_panel != m_panel_max && \ - n_panel != n_panel_max ) \ - { \ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t i = m_panel; \ - dim_t j = n_panel; \ - dim_t m_br = m_panel_max - i; \ - dim_t n_br = n_panel_max - j; \ - ctype_r* p_br_r = ( ctype_r* )p + (i )*rs_p + (j )*cs_p; \ - ctype_r* p_br_i = ( ctype_r* )p + is_p + (i )*rs_p + (j )*cs_p; \ -\ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - m_br, \ - n_br, \ - one_r, \ - p_br_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - m_br, \ - n_br, \ - zero_r, \ - p_br_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_struc_cxk_3mis, packm_cxk_3mis ) - - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - doff_t diagoffc_abs; \ - dim_t i, j; \ - bool_t row_stored; \ - bool_t col_stored; \ -\ -\ - /* Create flags to incidate row or column storage. Note that the - schema bit that encodes row or column is describing the form of - micro-panel, not the storage in the micro-panel. Hence the - mismatch in "row" and "column" semantics. */ \ - row_stored = bli_is_col_packed( schema ); \ - col_stored = bli_is_row_packed( schema ); \ -\ -\ - /* Handle the case where the micro-panel does NOT intersect the - diagonal separately from the case where it does intersect. */ \ - if ( !bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) \ - { \ - /* If the current panel is unstored, we need to make a few - adjustments so we refer to the data where it is actually - stored, also taking conjugation into account. (Note this - implicitly assumes we are operating on a dense panel - within a larger symmetric or Hermitian matrix, since a - general matrix would not contain any unstored region.) */ \ - if ( bli_is_unstored_subpart_n( diagoffc, uploc, m_panel, n_panel ) ) \ - { \ - c = c + diagoffc * ( doff_t )cs_c + \ - -diagoffc * ( doff_t )rs_c; \ - bli_swap_incs( &incc, &ldc ); \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc ); \ - } \ -\ - /* Pack the full panel. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, is_p, ldp, \ - cntx \ - ); \ - } \ - else /* if ( bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) */ \ - { \ - ctype_r* restrict p_r = ( ctype_r* )p; \ -\ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict minus_one_r = PASTEMAC(chr,m1); \ -\ - ctype* restrict c10; \ - ctype_r* restrict p10; \ - dim_t p10_dim, p10_len; \ - inc_t incc10, ldc10; \ - doff_t diagoffc10; \ - conj_t conjc10; \ -\ - ctype* restrict c12; \ - ctype_r* restrict p12; \ - dim_t p12_dim, p12_len; \ - inc_t incc12, ldc12; \ - doff_t diagoffc12; \ - conj_t conjc12; \ -\ - /* Sanity check. Diagonals should not intersect the short end of - a micro-panel. If they do, then somehow the constraints on - cache blocksizes being a whole multiple of the register - blocksizes was somehow violated. */ \ - if ( ( col_stored && diagoffc < 0 ) || \ - ( row_stored && diagoffc > 0 ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ - diagoffc_abs = bli_abs( diagoffc ); \ -\ - if ( ( row_stored && bli_is_upper( uploc ) ) || \ - ( col_stored && bli_is_lower( uploc ) ) ) \ - { \ - p10_dim = panel_dim; \ - p10_len = diagoffc_abs; \ - p10 = p_r; \ - c10 = c; \ - incc10 = incc; \ - ldc10 = ldc; \ - conjc10 = conjc; \ -\ - p12_dim = panel_dim; \ - p12_len = panel_len - p10_len; \ - j = p10_len; \ - diagoffc12 = diagoffc_abs - j; \ - p12 = p_r + (j )*ldp; \ - c12 = c + (j )*ldc; \ - c12 = c12 + diagoffc12 * ( doff_t )cs_c + \ - -diagoffc12 * ( doff_t )rs_c; \ - incc12 = ldc; \ - ldc12 = incc; \ - conjc12 = conjc; \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc12 ); \ - } \ - else /* if ( ( row_stored && bli_is_lower( uploc ) ) || \ - ( col_stored && bli_is_upper( uploc ) ) ) */ \ - { \ - p10_dim = panel_dim; \ - p10_len = diagoffc_abs + panel_dim; \ - diagoffc10 = diagoffc; \ - p10 = p_r; \ - c10 = c; \ - c10 = c10 + diagoffc10 * ( doff_t )cs_c + \ - -diagoffc10 * ( doff_t )rs_c; \ - incc10 = ldc; \ - ldc10 = incc; \ - conjc10 = conjc; \ -\ - p12_dim = panel_dim; \ - p12_len = panel_len - p10_len; \ - j = p10_len; \ - p12 = p_r + (j )*ldp; \ - c12 = c + (j )*ldc; \ - incc12 = incc; \ - ldc12 = ldc; \ - conjc12 = conjc; \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc10 ); \ - } \ -\ - /* Pack to p10. For upper storage, this includes the unstored - triangle of c11. */ \ - /* NOTE: Since we're only packing partial panels here, we pass in - p1x_len as panel_len_max; otherwise, the packm kernel will zero- - fill the columns up to panel_len_max, which is not what we need - or want to happen. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc10, \ - p10_dim, \ - panel_dim_max, \ - p10_len, \ - p10_len, \ - kappa, \ - c10, incc10, ldc10, \ - ( ctype* )p10, is_p, ldp, \ - cntx \ - ); \ -\ - /* Pack to p12. For lower storage, this includes the unstored - triangle of c11. */ \ - /* NOTE: Since we're only packing partial panels here, we pass in - p1x_len as panel_len_max; otherwise, the packm kernel will zero- - fill the columns up to panel_len_max, which is not what we need - or want to happen. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc12, \ - p12_dim, \ - panel_dim_max, \ - p12_len, \ - p12_len, \ - kappa, \ - c12, incc12, ldc12, \ - ( ctype* )p12, is_p, ldp, \ - cntx \ - ); \ -\ - /* Pack the stored triangle of c11 to p11. */ \ - { \ - dim_t p11_m = panel_dim; \ - dim_t p11_n = panel_dim; \ - inc_t rs_c11 = 2*rs_c; \ - inc_t cs_c11 = 2*cs_c; \ - dim_t j2 = diagoffc_abs; \ - ctype* c11 = ( ctype* )c + (j2 )*ldc; \ - ctype_r* p11 = ( ctype_r* )p_r + (j2 )*ldp; \ - ctype_r* c11_r = ( ctype_r* )c11; \ - ctype_r* c11_i = ( ctype_r* )c11 + 1; \ - ctype_r* p11_r = ( ctype_r* )p11; \ - ctype_r* p11_i = ( ctype_r* )p11 + is_p; \ - ctype_r* alpha_r = one_r; \ - ctype_r* alpha_i = ( bli_is_conj( conjc ) ? minus_one_r : one_r ); \ - ctype_r kappa_r = PASTEMAC(ch,real)( *kappa ); \ - ctype_r kappa_i = PASTEMAC(ch,imag)( *kappa ); \ -\ - /* Copy the real part of the stored triangle of c11 to p11_r. */ \ - PASTEMAC2(chr,scal2m,BLIS_TAPI_EX_SUF) \ - ( \ - 0, \ - BLIS_NONUNIT_DIAG, \ - uploc, \ - BLIS_NO_TRANSPOSE, \ - p11_m, \ - p11_n, \ - alpha_r, \ - c11_r, rs_c11, cs_c11, \ - p11_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ -\ - /* Copy the imaginary part of the stored triangle of c11 to p11_i, - scaling by -1 if conjugation on c was requested. */ \ - PASTEMAC2(chr,scal2m,BLIS_TAPI_EX_SUF) \ - ( \ - 0, \ - BLIS_NONUNIT_DIAG, \ - uploc, \ - BLIS_NO_TRANSPOSE, \ - p11_m, \ - p11_n, \ - alpha_i, \ - c11_i, rs_c11, cs_c11, \ - p11_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ -\ - /* If source matrix c is Hermitian, we have to zero out the - imaginary components of the diagonal of p11 in case the - corresponding elements in c11 were not already zero. */ \ - if ( bli_is_hermitian( strucc ) ) \ - { \ - for ( i = 0; i < p11_m; ++i ) \ - { \ - ctype_r* pi11_i = p11_i + (i )*rs_p + (i )*cs_p; \ -\ - PASTEMAC(chr,set0s)( *pi11_i ); \ - } \ - } \ -\ - /* Apply kappa to the part of p11 that corresponds to the stored - part of c11 that was copied above. */ \ - if ( bli_is_upper( uploc ) ) \ - { \ - PASTEMAC(ch,scalris_mxn_u) \ - ( \ - 0, \ - p11_m, \ - p11_n, \ - &kappa_r, \ - &kappa_i, \ - p11_r, \ - p11_i, rs_p, cs_p \ - ); \ - } \ - else \ - { \ - PASTEMAC(ch,scalris_mxn_l) \ - ( \ - 0, \ - p11_m, \ - p11_n, \ - &kappa_r, \ - &kappa_i, \ - p11_r, \ - p11_i, rs_p, cs_p \ - ); \ - } \ -\ - /* Update the p11 section of the ri panel. It simply needs - to contain the sum of p11_r + p11_i. */ \ - { \ - ctype_r* p11_rpi = p11_i + is_p; \ -\ - for ( j = 0; j < p11_n; ++j ) \ - for ( i = 0; i < p11_m; ++i ) \ - { \ - ctype_r* pi11_r = p11_r + (i )*rs_p + (j )*cs_p; \ - ctype_r* pi11_i = p11_i + (i )*rs_p + (j )*cs_p; \ - ctype_r* pi11_rpi = p11_rpi + (i )*rs_p + (j )*cs_p; \ -\ - PASTEMAC(chr,add3s) \ - ( \ - *pi11_r, \ - *pi11_i, \ - *pi11_rpi \ - ); \ - } \ - } \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_herm_cxk_3mis, packm_cxk_3mis ) - - - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffp, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - /* Pack the panel. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, is_p, ldp, \ - cntx \ - ); \ -\ -\ - /* Tweak the panel according to its triangular structure */ \ - { \ - ctype_r* p_r = ( ctype_r* )p + 0; \ - ctype_r* p_i = ( ctype_r* )p + is_p; \ - ctype_r* p_rpi = ( ctype_r* )p + 2*is_p; \ -\ - dim_t j = bli_abs( diagoffp ); \ - ctype_r* p11_r = p_r + (j )*ldp; \ - ctype_r* p11_i = p_i + (j )*ldp; \ - ctype_r* p11_rpi = p_rpi + (j )*ldp; \ -\ - dim_t p11_m = m_panel; \ - dim_t p11_n = n_panel; \ -\ - dim_t min_p11_m_n; \ -\ - if ( diagoffp < 0 ) p11_m -= j; \ - else if ( diagoffp > 0 ) p11_n -= j; \ -\ - min_p11_m_n = bli_min( p11_m, p11_n ); \ -\ -\ - /* If the diagonal of c is implicitly unit, explicitly set the - the diagonal of the packed panel to kappa. */ \ - if ( bli_is_unit_diag( diagc ) ) \ - { \ - ctype_r kappa_r = PASTEMAC(ch,real)( *kappa ); \ - ctype_r kappa_i = PASTEMAC(ch,imag)( *kappa ); \ - dim_t i; \ -\ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - m_panel, \ - n_panel, \ - &kappa_r, \ - p_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - m_panel, \ - n_panel, \ - &kappa_i, \ - p_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ -\ - /* Update the diagonal of the p11 section of the rpi panel. - It simply needs to contain the sum of diagonals of p11_r - and p11_i. */ \ - for ( i = 0; i < min_p11_m_n; ++i ) \ - { \ - ctype_r* pi11_r = p11_r + (i )*rs_p + (i )*cs_p; \ - ctype_r* pi11_i = p11_i + (i )*rs_p + (i )*cs_p; \ - ctype_r* pi11_rpi = p11_rpi + (i )*rs_p + (i )*cs_p; \ -\ - PASTEMAC(chr,add3s)( *pi11_r, *pi11_i, *pi11_rpi ); \ - } \ - } \ -\ - /* If requested, invert the diagonal of the packed panel. Note - that we do not need to update the ri panel since inverted - diagonals are only needed by trsm, which does not use the - p11 section of the ri panel. */ \ - if ( invdiag == TRUE ) \ - { \ - dim_t i; \ -\ - for ( i = 0; i < min_p11_m_n; ++i ) \ - { \ - ctype_r* pi11_r = p11_r + (i )*rs_p + (i )*cs_p; \ - ctype_r* pi11_i = p11_i + (i )*rs_p + (i )*cs_p; \ -\ - PASTEMAC(ch,invertris)( *pi11_r, *pi11_i ); \ - } \ - } \ -\ - /* Set the region opposite the diagonal of p to zero. To do this, - we need to reference the "unstored" region on the other side of - the diagonal. This amounts to toggling uploc and then shifting - the diagonal offset to shrink the newly referenced region (by - one diagonal). Note that this zero-filling is not needed for - trsm, since the unstored region is not referenced by the trsm - micro-kernel; however, zero-filling is needed for trmm, which - uses the gemm micro-kernel.*/ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - uplo_t uplop = uploc; \ -\ - bli_toggle_uplo( &uplop ); \ - bli_shift_diag_offset_to_shrink_uplo( uplop, &diagoffp ); \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - BLIS_NONUNIT_DIAG, \ - uplop, \ - m_panel, \ - n_panel, \ - zero_r, \ - p_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - BLIS_NONUNIT_DIAG, \ - uplop, \ - m_panel, \ - n_panel, \ - zero_r, \ - p_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - BLIS_NONUNIT_DIAG, \ - uplop, \ - m_panel, \ - n_panel, \ - zero_r, \ - p_rpi, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_tri_cxk_3mis, packm_cxk_3mis ) - diff --git a/frame/1m/packm/bli_packm_struc_cxk_3mis.h b/frame/1m/packm/bli_packm_struc_cxk_3mis.h deleted file mode 100644 index 24f2c0fcbd..0000000000 --- a/frame/1m/packm/bli_packm_struc_cxk_3mis.h +++ /dev/null @@ -1,121 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffp, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_struc_cxk_3mis ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_herm_cxk_3mis ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_tri_cxk_3mis ) - diff --git a/frame/1m/packm/bli_packm_struc_cxk_4mi.c b/frame/1m/packm/bli_packm_struc_cxk_4mi.c deleted file mode 100644 index 3df849921d..0000000000 --- a/frame/1m/packm/bli_packm_struc_cxk_4mi.c +++ /dev/null @@ -1,757 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ) \ -{ \ - dim_t panel_dim; \ - dim_t panel_dim_max; \ - dim_t panel_len; \ - dim_t panel_len_max; \ - inc_t incc, ldc; \ - inc_t ldp; \ -\ -\ - /* Determine the dimensions and relative strides of the micro-panel - based on its pack schema. */ \ - if ( bli_is_col_packed( schema ) ) \ - { \ - /* Prepare to pack to row-stored column panel. */ \ - panel_dim = n_panel; \ - panel_dim_max = n_panel_max; \ - panel_len = m_panel; \ - panel_len_max = m_panel_max; \ - incc = cs_c; \ - ldc = rs_c; \ - ldp = rs_p; \ - } \ - else /* if ( bli_is_row_packed( schema ) ) */ \ - { \ - /* Prepare to pack to column-stored row panel. */ \ - panel_dim = m_panel; \ - panel_dim_max = m_panel_max; \ - panel_len = n_panel; \ - panel_len_max = n_panel_max; \ - incc = rs_c; \ - ldc = cs_c; \ - ldp = cs_p; \ - } \ -\ -\ - /* Handle micro-panel packing based on the structure of the matrix - being packed. */ \ - if ( bli_is_general( strucc ) ) \ - { \ - /* For micro-panels of general matrices, we can call the pack - kernel front-end directly. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, is_p, ldp, \ - cntx \ - ); \ - } \ - else if ( bli_is_herm_or_symm( strucc ) ) \ - { \ - /* Call a helper function for micro-panels of Hermitian/symmetric - matrices. */ \ - PASTEMAC(ch,packm_herm_cxk_4mi) \ - ( \ - strucc, \ - diagoffc, \ - uploc, \ - conjc, \ - schema, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - is_p, ldp, \ - cntx \ - ); \ - } \ - else /* ( bli_is_triangular( strucc ) ) */ \ - { \ - /* Call a helper function for micro-panels of triangular - matrices. */ \ - PASTEMAC(ch,packm_tri_cxk_4mi) \ - ( \ - strucc, \ - diagoffc, \ - diagc, \ - uploc, \ - conjc, \ - schema, \ - invdiag, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - is_p, ldp, \ - cntx \ - ); \ - } \ -\ -\ - /* If m_panel < m_panel_max, or n_panel < n_panel_max, we would normally - fill the edge region (the bottom m_panel_max - m_panel rows or right- - side n_panel_max - n_panel columns) of the micropanel with zeros. - However, this responsibility has been moved to the packm microkernel. - This change allows experts to use custom kernels that pack to custom - packing formats when the problem size is not a nice multiple of the - register blocksize. */ \ -/* - if ( m_panel != m_panel_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t i = m_panel; \ - dim_t m_edge = m_panel_max - i; \ - dim_t n_edge = n_panel_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*rs_p; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*rs_p; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -\ - if ( n_panel != n_panel_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t j = n_panel; \ - dim_t m_edge = m_panel_max; \ - dim_t n_edge = n_panel_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*cs_p; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*cs_p; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -*/ \ -\ -\ - if ( bli_is_triangular( strucc ) ) \ - { \ - /* If this panel is an edge case in both panel dimension and length, - then it must be a bottom-right corner case. Set the part of the - diagonal that extends into the zero-padded region to identity. - NOTE: This is actually only necessary when packing for trsm, as - it helps prevent NaNs and Infs from creeping into the computation. - However, we set the region to identity for trmm as well. Those - 1.0's end up getting muliplied by the 0.0's in the zero-padded - region of the other matrix, so there is no harm in this. */ \ - if ( m_panel != m_panel_max && \ - n_panel != n_panel_max ) \ - { \ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t i = m_panel; \ - dim_t j = n_panel; \ - dim_t m_br = m_panel_max - i; \ - dim_t n_br = n_panel_max - j; \ - ctype_r* p_br_r = ( ctype_r* )p + (i )*rs_p + (j )*cs_p; \ - ctype_r* p_br_i = ( ctype_r* )p + is_p + (i )*rs_p + (j )*cs_p; \ -\ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - m_br, \ - n_br, \ - one_r, \ - p_br_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - m_br, \ - n_br, \ - zero_r, \ - p_br_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_struc_cxk_4mi, packm_cxk_4mi ) - - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - doff_t diagoffc_abs; \ - dim_t i, j; \ - bool_t row_stored; \ - bool_t col_stored; \ -\ -\ - /* Create flags to incidate row or column storage. Note that the - schema bit that encodes row or column is describing the form of - micro-panel, not the storage in the micro-panel. Hence the - mismatch in "row" and "column" semantics. */ \ - row_stored = bli_is_col_packed( schema ); \ - col_stored = bli_is_row_packed( schema ); \ -\ -\ - /* Handle the case where the micro-panel does NOT intersect the - diagonal separately from the case where it does intersect. */ \ - if ( !bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) \ - { \ - /* If the current panel is unstored, we need to make a few - adjustments so we refer to the data where it is actually - stored, also taking conjugation into account. (Note this - implicitly assumes we are operating on a dense panel - within a larger symmetric or Hermitian matrix, since a - general matrix would not contain any unstored region.) */ \ - if ( bli_is_unstored_subpart_n( diagoffc, uploc, m_panel, n_panel ) ) \ - { \ - c = c + diagoffc * ( doff_t )cs_c + \ - -diagoffc * ( doff_t )rs_c; \ - bli_swap_incs( &incc, &ldc ); \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc ); \ - } \ -\ - /* Pack the full panel. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, is_p, ldp, \ - cntx \ - ); \ - } \ - else /* if ( bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) */ \ - { \ - ctype_r* restrict p_r = ( ctype_r* )p; \ -\ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict minus_one_r = PASTEMAC(chr,m1); \ -\ - ctype* restrict c10; \ - ctype_r* restrict p10; \ - dim_t p10_dim, p10_len; \ - inc_t incc10, ldc10; \ - doff_t diagoffc10; \ - conj_t conjc10; \ -\ - ctype* restrict c12; \ - ctype_r* restrict p12; \ - dim_t p12_dim, p12_len; \ - inc_t incc12, ldc12; \ - doff_t diagoffc12; \ - conj_t conjc12; \ -\ - /* Sanity check. Diagonals should not intersect the short end of - a micro-panel. If they do, then somehow the constraints on - cache blocksizes being a whole multiple of the register - blocksizes was somehow violated. */ \ - if ( ( col_stored && diagoffc < 0 ) || \ - ( row_stored && diagoffc > 0 ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ - diagoffc_abs = bli_abs( diagoffc ); \ -\ - if ( ( row_stored && bli_is_upper( uploc ) ) || \ - ( col_stored && bli_is_lower( uploc ) ) ) \ - { \ - p10_dim = panel_dim; \ - p10_len = diagoffc_abs; \ - p10 = p_r; \ - c10 = c; \ - incc10 = incc; \ - ldc10 = ldc; \ - conjc10 = conjc; \ -\ - p12_dim = panel_dim; \ - p12_len = panel_len - p10_len; \ - j = p10_len; \ - diagoffc12 = diagoffc_abs - j; \ - p12 = p_r + (j )*ldp; \ - c12 = c + (j )*ldc; \ - c12 = c12 + diagoffc12 * ( doff_t )cs_c + \ - -diagoffc12 * ( doff_t )rs_c; \ - incc12 = ldc; \ - ldc12 = incc; \ - conjc12 = conjc; \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc12 ); \ - } \ - else /* if ( ( row_stored && bli_is_lower( uploc ) ) || \ - ( col_stored && bli_is_upper( uploc ) ) ) */ \ - { \ - p10_dim = panel_dim; \ - p10_len = diagoffc_abs + panel_dim; \ - diagoffc10 = diagoffc; \ - p10 = p_r; \ - c10 = c; \ - c10 = c10 + diagoffc10 * ( doff_t )cs_c + \ - -diagoffc10 * ( doff_t )rs_c; \ - incc10 = ldc; \ - ldc10 = incc; \ - conjc10 = conjc; \ -\ - p12_dim = panel_dim; \ - p12_len = panel_len - p10_len; \ - j = p10_len; \ - p12 = p_r + (j )*ldp; \ - c12 = c + (j )*ldc; \ - incc12 = incc; \ - ldc12 = ldc; \ - conjc12 = conjc; \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc10 ); \ - } \ -\ - /* Pack to p10. For upper storage, this includes the unstored - triangle of c11. */ \ - /* NOTE: Since we're only packing partial panels here, we pass in - p1x_len as panel_len_max; otherwise, the packm kernel will zero- - fill the columns up to panel_len_max, which is not what we need - or want to happen. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc10, \ - p10_dim, \ - panel_dim_max, \ - p10_len, \ - p10_len, \ - kappa, \ - c10, incc10, ldc10, \ - ( ctype* )p10, is_p, ldp, \ - cntx \ - ); \ -\ - /* Pack to p12. For lower storage, this includes the unstored - triangle of c11. */ \ - /* NOTE: Since we're only packing partial panels here, we pass in - p1x_len as panel_len_max; otherwise, the packm kernel will zero- - fill the columns up to panel_len_max, which is not what we need - or want to happen. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc12, \ - p12_dim, \ - panel_dim_max, \ - p12_len, \ - p12_len, \ - kappa, \ - c12, incc12, ldc12, \ - ( ctype* )p12, is_p, ldp, \ - cntx \ - ); \ -\ - /* Pack the stored triangle of c11 to p11. */ \ - { \ - dim_t p11_m = panel_dim; \ - dim_t p11_n = panel_dim; \ - inc_t rs_c11 = 2*rs_c; \ - inc_t cs_c11 = 2*cs_c; \ - dim_t j2 = diagoffc_abs; \ - ctype* c11 = ( ctype* )c + (j2 )*ldc; \ - ctype_r* p11 = ( ctype_r* )p_r + (j2 )*ldp; \ - ctype_r* c11_r = ( ctype_r* )c11; \ - ctype_r* c11_i = ( ctype_r* )c11 + 1; \ - ctype_r* p11_r = ( ctype_r* )p11; \ - ctype_r* p11_i = ( ctype_r* )p11 + is_p; \ - ctype_r* alpha_r = one_r; \ - ctype_r* alpha_i = ( bli_is_conj( conjc ) ? minus_one_r : one_r ); \ - ctype_r kappa_r = PASTEMAC(ch,real)( *kappa ); \ - ctype_r kappa_i = PASTEMAC(ch,imag)( *kappa ); \ -\ - /* Copy the real part of the stored triangle of c11 to p11_r. */ \ - PASTEMAC2(chr,scal2m,BLIS_TAPI_EX_SUF) \ - ( \ - 0, \ - BLIS_NONUNIT_DIAG, \ - uploc, \ - BLIS_NO_TRANSPOSE, \ - p11_m, \ - p11_n, \ - alpha_r, \ - c11_r, rs_c11, cs_c11, \ - p11_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ -\ - /* Copy the imaginary part of the stored triangle of c11 to p11_i, - scaling by -1 if conjugation on c was requested. */ \ - PASTEMAC2(chr,scal2m,BLIS_TAPI_EX_SUF) \ - ( \ - 0, \ - BLIS_NONUNIT_DIAG, \ - uploc, \ - BLIS_NO_TRANSPOSE, \ - p11_m, \ - p11_n, \ - alpha_i, \ - c11_i, rs_c11, cs_c11, \ - p11_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ -\ - /* If source matrix c is Hermitian, we have to zero out the - imaginary components of the diagonal of p11 in case the - corresponding elements in c11 were not already zero. */ \ - if ( bli_is_hermitian( strucc ) ) \ - { \ - for ( i = 0; i < p11_m; ++i ) \ - { \ - ctype_r* pi11_i = p11_i + (i )*rs_p + (i )*cs_p; \ -\ - PASTEMAC(chr,set0s)( *pi11_i ); \ - } \ - } \ -\ - /* Apply kappa to the part of p11 that corresponds to the stored - part of c11 that was copied above. */ \ - if ( bli_is_upper( uploc ) ) \ - { \ - PASTEMAC(ch,scalris_mxn_u) \ - ( \ - 0, \ - p11_m, \ - p11_n, \ - &kappa_r, \ - &kappa_i, \ - p11_r, \ - p11_i, rs_p, cs_p \ - ); \ - } \ - else \ - { \ - PASTEMAC(ch,scalris_mxn_l) \ - ( \ - 0, \ - p11_m, \ - p11_n, \ - &kappa_r, \ - &kappa_i, \ - p11_r, \ - p11_i, rs_p, cs_p \ - ); \ - } \ -/* - PASTEMAC(chr,fprintm)( stdout, "packm_herm_cxk: ap_r copied", m_panel_max, n_panel_max, \ - p_r + 0*is_p, rs_p, cs_p, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_herm_cxk: ap_i copied", m_panel_max, n_panel_max, \ - p_r + 1*is_p, rs_p, cs_p, "%4.1f", "" ); \ -*/ \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_herm_cxk_4mi, packm_cxk_4mi ) - - - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffp, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - /* Pack the panel. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, is_p, ldp, \ - cntx \ - ); \ -\ -\ - /* Tweak the panel according to its triangular structure */ \ - { \ - ctype_r* p_r = ( ctype_r* )p; \ - ctype_r* p_i = ( ctype_r* )p + is_p; \ -\ - dim_t j = bli_abs( diagoffp ); \ - ctype_r* p11_r = p_r + (j )*ldp; \ - ctype_r* p11_i = p_i + (j )*ldp; \ -\ - /* If the diagonal of c is implicitly unit, explicitly set the - the diagonal of the packed panel to kappa. */ \ - if ( bli_is_unit_diag( diagc ) ) \ - { \ - ctype_r kappa_r = PASTEMAC(ch,real)( *kappa ); \ - ctype_r kappa_i = PASTEMAC(ch,imag)( *kappa ); \ -\ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - m_panel, \ - n_panel, \ - &kappa_r, \ - p_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - m_panel, \ - n_panel, \ - &kappa_i, \ - p_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -\ -\ - /* If requested, invert the diagonal of the packed panel. */ \ - if ( invdiag == TRUE ) \ - { \ - dim_t i; \ -\ - for ( i = 0; i < panel_dim; ++i ) \ - { \ - ctype_r* pi11_r = p11_r + (i )*rs_p + (i )*cs_p; \ - ctype_r* pi11_i = p11_i + (i )*rs_p + (i )*cs_p; \ -\ - PASTEMAC(ch,invertris)( *pi11_r, *pi11_i ); \ - } \ - } \ -\ -\ - /* Set the region opposite the diagonal of p to zero. To do this, - we need to reference the "unstored" region on the other side of - the diagonal. This amounts to toggling uploc and then shifting - the diagonal offset to shrink the newly referenced region (by - one diagonal). Note that this zero-filling is not needed for - trsm, since the unstored region is not referenced by the trsm - micro-kernel; however, zero-filling is needed for trmm, which - uses the gemm micro-kernel.*/ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - uplo_t uplop = uploc; \ -\ - bli_toggle_uplo( &uplop ); \ - bli_shift_diag_offset_to_shrink_uplo( uplop, &diagoffp ); \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - BLIS_NONUNIT_DIAG, \ - uplop, \ - m_panel, \ - n_panel, \ - zero_r, \ - p_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - BLIS_NONUNIT_DIAG, \ - uplop, \ - m_panel, \ - n_panel, \ - zero_r, \ - p_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_tri_cxk_4mi, packm_cxk_4mi ) - diff --git a/frame/1m/packm/bli_packm_struc_cxk_4mi.h b/frame/1m/packm/bli_packm_struc_cxk_4mi.h deleted file mode 100644 index f2e6636bfa..0000000000 --- a/frame/1m/packm/bli_packm_struc_cxk_4mi.h +++ /dev/null @@ -1,121 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffp, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_struc_cxk_4mi ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_herm_cxk_4mi ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_tri_cxk_4mi ) - diff --git a/frame/1m/packm/bli_packm_struc_cxk_md.c b/frame/1m/packm/bli_packm_struc_cxk_md.c index 52a1f9817f..650b6178c9 100644 --- a/frame/1m/packm/bli_packm_struc_cxk_md.c +++ b/frame/1m/packm/bli_packm_struc_cxk_md.c @@ -41,53 +41,26 @@ \ void PASTEMAC2(chc,chp,varname) \ ( \ + struc_t strucc, \ + diag_t diagc, \ + uplo_t uploc, \ conj_t conjc, \ pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ + bool invdiag, \ + dim_t panel_dim, \ + dim_t panel_len, \ + dim_t panel_dim_max, \ + dim_t panel_len_max, \ + dim_t panel_dim_off, \ + dim_t panel_len_off, \ ctype_p* restrict kappa, \ - ctype_c* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype_p* restrict p, inc_t rs_p, inc_t cs_p, \ + ctype_c* restrict c, inc_t incc, inc_t ldc, \ + ctype_p* restrict p, inc_t ldp, \ inc_t is_p, \ - cntx_t* cntx \ + cntx_t* cntx, \ + void* params \ ) \ { \ - dim_t panel_dim; \ - dim_t panel_dim_max; \ - dim_t panel_len; \ - dim_t panel_len_max; \ - inc_t incc, ldc; \ - inc_t ldp; \ -\ -\ - /* Determine the dimensions and relative strides of the micro-panel - based on its pack schema. */ \ - if ( bli_is_col_packed( schema ) ) \ - { \ - /* Prepare to pack to row-stored column panel. */ \ - panel_dim = n_panel; \ - panel_dim_max = n_panel_max; \ - panel_len = m_panel; \ - panel_len_max = m_panel_max; \ - incc = cs_c; \ - ldc = rs_c; \ - ldp = rs_p; \ - } \ - else /* if ( bli_is_row_packed( schema ) ) */ \ - { \ - /* Prepare to pack to column-stored row panel. */ \ - panel_dim = m_panel; \ - panel_dim_max = m_panel_max; \ - panel_len = n_panel; \ - panel_len_max = n_panel_max; \ - incc = rs_c; \ - ldc = cs_c; \ - ldp = cs_p; \ - } \ -\ -\ if ( bli_is_nat_packed( schema ) ) \ { \ /* Sanity check: Make sure that kappa is 1.0. Mixed-datatype alpha @@ -318,7 +291,7 @@ void PASTEMAC2(cha,chp,opname) \ conj_t conja, \ dim_t m, \ dim_t n, \ - ctype_p* restrict kappa, \ + ctype_p* restrict kappa, \ ctype_a* restrict a, inc_t inca, inc_t lda, \ ctype_p* restrict p, inc_t ldp \ ) \ @@ -445,7 +418,7 @@ void PASTEMAC2(cha,chp,opname) \ conj_t conja, \ dim_t m, \ dim_t n, \ - ctype_p* restrict kappa, \ + ctype_p* restrict kappa, \ ctype_a* restrict a, inc_t inca, inc_t lda, \ ctype_p* restrict p, inc_t ldp \ ) \ diff --git a/frame/1m/packm/bli_packm_struc_cxk_md.h b/frame/1m/packm/bli_packm_struc_cxk_md.h index 72ca67937f..f493838b3a 100644 --- a/frame/1m/packm/bli_packm_struc_cxk_md.h +++ b/frame/1m/packm/bli_packm_struc_cxk_md.h @@ -37,17 +37,24 @@ \ void PASTEMAC2(chc,chp,varname) \ ( \ + struc_t strucc, \ + diag_t diagc, \ + uplo_t uploc, \ conj_t conjc, \ pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ + bool invdiag, \ + dim_t panel_dim, \ + dim_t panel_len, \ + dim_t panel_dim_max, \ + dim_t panel_len_max, \ + dim_t panel_dim_off, \ + dim_t panel_len_off, \ ctype_p* restrict kappa, \ - ctype_c* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype_p* restrict p, inc_t rs_p, inc_t cs_p, \ + ctype_c* restrict c, inc_t incc, inc_t ldc, \ + ctype_p* restrict p, inc_t ldp, \ inc_t is_p, \ - cntx_t* cntx \ + cntx_t* cntx, \ + void* params \ ); INSERT_GENTPROT2_BASIC0( packm_struc_cxk_md ) diff --git a/frame/1m/packm/bli_packm_struc_cxk_rih.c b/frame/1m/packm/bli_packm_struc_cxk_rih.c deleted file mode 100644 index 32a7ec1a7e..0000000000 --- a/frame/1m/packm/bli_packm_struc_cxk_rih.c +++ /dev/null @@ -1,625 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ) \ -{ \ - dim_t panel_dim; \ - dim_t panel_dim_max; \ - dim_t panel_len; \ - dim_t panel_len_max; \ - inc_t incc, ldc; \ - inc_t ldp; \ -\ -\ - /* Determine the dimensions and relative strides of the micro-panel - based on its pack schema. */ \ - if ( bli_is_col_packed( schema ) ) \ - { \ - /* Prepare to pack to row-stored column panel. */ \ - panel_dim = n_panel; \ - panel_dim_max = n_panel_max; \ - panel_len = m_panel; \ - panel_len_max = m_panel_max; \ - incc = cs_c; \ - ldc = rs_c; \ - ldp = rs_p; \ - } \ - else /* if ( bli_is_row_packed( schema ) ) */ \ - { \ - /* Prepare to pack to column-stored row panel. */ \ - panel_dim = m_panel; \ - panel_dim_max = m_panel_max; \ - panel_len = n_panel; \ - panel_len_max = n_panel_max; \ - incc = rs_c; \ - ldc = cs_c; \ - ldp = cs_p; \ - } \ -\ -\ - /* Handle micro-panel packing based on the structure of the matrix - being packed. */ \ - if ( bli_is_general( strucc ) ) \ - { \ - /* For micro-panels of general matrices, we can call the pack - kernel front-end directly. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - schema, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, ldp, \ - cntx \ - ); \ - } \ - else if ( bli_is_herm_or_symm( strucc ) ) \ - { \ - /* Call a helper function for micro-panels of Hermitian/symmetric - matrices. */ \ - PASTEMAC(ch,packm_herm_cxk_rih) \ - ( \ - strucc, \ - diagoffc, \ - uploc, \ - conjc, \ - schema, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - ldp, \ - cntx \ - ); \ - } \ - else /* ( bli_is_triangular( strucc ) ) */ \ - { \ - /* Call a helper function for micro-panels of triangular - matrices. */ \ - PASTEMAC(ch,packm_tri_cxk_rih) \ - ( \ - strucc, \ - diagoffc, \ - diagc, \ - uploc, \ - conjc, \ - schema, \ - invdiag, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - ldp, \ - cntx \ - ); \ - } \ -\ -\ - /* If m_panel < m_panel_max, or n_panel < n_panel_max, we would normally - fill the edge region (the bottom m_panel_max - m_panel rows or right- - side n_panel_max - n_panel columns) of the micropanel with zeros. - However, this responsibility has been moved to the packm microkernel. - This change allows experts to use custom kernels that pack to custom - packing formats when the problem size is not a nice multiple of the - register blocksize. */ \ -/* - if ( m_panel != m_panel_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t i = m_panel; \ - dim_t m_edge = m_panel_max - i; \ - dim_t n_edge = n_panel_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*rs_p; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -\ - if ( n_panel != n_panel_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t j = n_panel; \ - dim_t m_edge = m_panel_max; \ - dim_t n_edge = n_panel_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*cs_p; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -*/ \ -\ -\ - if ( bli_is_triangular( strucc ) ) \ - { \ - /* If this panel is an edge case in both panel dimension and length, - then it must be a bottom-right corner case. Set the part of the - diagonal that extends into the zero-padded region to identity. - NOTE: This is actually only necessary when packing for trsm, as - it helps prevent NaNs and Infs from creeping into the computation. - However, we set the region to identity for trmm as well. Those - 1.0's end up getting muliplied by the 0.0's in the zero-padded - region of the other matrix, so there is no harm in this. */ \ - if ( m_panel != m_panel_max && \ - n_panel != n_panel_max ) \ - { \ - /* We don't need this case if we aren't supporting trsm. - Why? Because trmm's packm control tree node should be - using k dimension multiples of 1 (kr == 1), which means - there will never be zero padding at the far end of a - micro-panel. */ \ - } \ - } \ -\ -\ -/* - { \ - if ( bli_is_col_packed( schema ) ) \ - PASTEMAC(chr,fprintm)( stdout, "packm_struc_cxk_rih: bp copied", m_panel_max, n_panel_max, \ - ( ctype_r* )p, rs_p, cs_p, "%4.1f", "" ); \ - else if ( bli_is_row_packed( schema ) ) \ - PASTEMAC(chr,fprintm)( stdout, "packm_struc_cxk_rih: ap copied", m_panel_max, n_panel_max, \ - ( ctype_r* )p, rs_p, cs_p, "%4.1f", "" ); \ - } \ -*/ \ - \ -\ -} - -INSERT_GENTFUNCCO_BASIC( packm_struc_cxk_rih, packm_cxk_rih ) - - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - bool_t row_stored; \ - bool_t col_stored; \ - doff_t diagoffc_abs; \ - dim_t j; \ -\ -\ - /* Create flags to incidate row or column storage. Note that the - schema bit that encodes row or column is describing the form of - micro-panel, not the storage in the micro-panel. Hence the - mismatch in "row" and "column" semantics. */ \ - row_stored = bli_is_col_packed( schema ); \ - col_stored = bli_is_row_packed( schema ); \ -\ -\ - /* Handle the case where the micro-panel does NOT intersect the - diagonal separately from the case where it does intersect. */ \ - if ( !bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) \ - { \ - /* If the current panel is unstored, we need to make a few - adjustments so we refer to the data where it is actually - stored, also taking conjugation into account. (Note this - implicitly assumes we are operating on a dense panel - within a larger symmetric or Hermitian matrix, since a - general matrix would not contain any unstored region.) */ \ - if ( bli_is_unstored_subpart_n( diagoffc, uploc, m_panel, n_panel ) ) \ - { \ - c = c + diagoffc * ( doff_t )cs_c + \ - -diagoffc * ( doff_t )rs_c; \ - bli_swap_incs( &incc, &ldc ); \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc ); \ - } \ -\ - /* Pack the full panel. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - schema, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, ldp, \ - cntx \ - ); \ - } \ - else /* if ( bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) */ \ - { \ - ctype_r* restrict p_r = ( ctype_r* )p; \ -\ - ctype* restrict c10; \ - ctype_r* restrict p10; \ - dim_t p10_dim, p10_len; \ - inc_t incc10, ldc10; \ - doff_t diagoffc10; \ - conj_t conjc10; \ -\ - ctype* restrict c12; \ - ctype_r* restrict p12; \ - dim_t p12_dim, p12_len; \ - inc_t incc12, ldc12; \ - doff_t diagoffc12; \ - conj_t conjc12; \ -\ - /* Sanity check. Diagonals should not intersect the short end of - a micro-panel. If they do, then somehow the constraints on - cache blocksizes being a whole multiple of the register - blocksizes was somehow violated. */ \ - if ( ( col_stored && diagoffc < 0 ) || \ - ( row_stored && diagoffc > 0 ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ - diagoffc_abs = bli_abs( diagoffc ); \ -\ - if ( ( row_stored && bli_is_upper( uploc ) ) || \ - ( col_stored && bli_is_lower( uploc ) ) ) \ - { \ - p10_dim = panel_dim; \ - p10_len = diagoffc_abs; \ - p10 = p_r; \ - c10 = c; \ - incc10 = incc; \ - ldc10 = ldc; \ - conjc10 = conjc; \ -\ - p12_dim = panel_dim; \ - p12_len = panel_len - p10_len; \ - j = p10_len; \ - diagoffc12 = diagoffc_abs - j; \ - p12 = p_r + (j )*ldp; \ - c12 = c + (j )*ldc; \ - c12 = c12 + diagoffc12 * ( doff_t )cs_c + \ - -diagoffc12 * ( doff_t )rs_c; \ - incc12 = ldc; \ - ldc12 = incc; \ - conjc12 = conjc; \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc12 ); \ - } \ - else /* if ( ( row_stored && bli_is_lower( uploc ) ) || \ - ( col_stored && bli_is_upper( uploc ) ) ) */ \ - { \ - p10_dim = panel_dim; \ - p10_len = diagoffc_abs + panel_dim; \ - diagoffc10 = diagoffc; \ - p10 = p_r; \ - c10 = c; \ - c10 = c10 + diagoffc10 * ( doff_t )cs_c + \ - -diagoffc10 * ( doff_t )rs_c; \ - incc10 = ldc; \ - ldc10 = incc; \ - conjc10 = conjc; \ -\ - p12_dim = panel_dim; \ - p12_len = panel_len - p10_len; \ - j = p10_len; \ - p12 = p_r + (j )*ldp; \ - c12 = c + (j )*ldc; \ - incc12 = incc; \ - ldc12 = ldc; \ - conjc12 = conjc; \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc10 ); \ - } \ -\ - /* Pack to p10. For upper storage, this includes the unstored - triangle of c11. */ \ - /* NOTE: Since we're only packing partial panels here, we pass in - p1x_len as panel_len_max; otherwise, the packm kernel will zero- - fill the columns up to panel_len_max, which is not what we need - or want to happen. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc10, \ - schema, \ - p10_dim, \ - panel_dim_max, \ - p10_len, \ - p10_len, \ - kappa, \ - c10, incc10, ldc10, \ - ( ctype* )p10, ldp, \ - cntx \ - ); \ -\ - /* Pack to p12. For lower storage, this includes the unstored - triangle of c11. */ \ - /* NOTE: Since we're only packing partial panels here, we pass in - p1x_len as panel_len_max; otherwise, the packm kernel will zero- - fill the columns up to panel_len_max, which is not what we need - or want to happen. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc12, \ - schema, \ - p12_dim, \ - panel_dim_max, \ - p12_len, \ - p12_len, \ - kappa, \ - c12, incc12, ldc12, \ - ( ctype* )p12, ldp, \ - cntx \ - ); \ -\ - /* Pack the stored triangle of c11 to p11. */ \ - { \ - dim_t j2 = diagoffc_abs; \ - /*ctype_r* restrict p_r = ( ctype_r* )p;*/ \ - ctype* restrict c11 = c + (j2 )*ldc; \ - ctype_r* restrict p11_r = p_r + (j2 )*ldp; \ -\ - PASTEMAC(ch,scal2rihs_mxn_uplo) \ - ( \ - schema, \ - uploc, \ - conjc, \ - panel_dim, \ - kappa, \ - c11, rs_c, cs_c, \ - p11_r, rs_p, cs_p \ - ); \ -\ - /* If we are packing a micro-panel with Hermitian structure, - we must take special care of the diagonal. Now, if kappa - were guaranteed to be unit, all we would need to do is - explicitly zero out the imaginary part of the diagonal of - p11, in case the diagonal of the source matrix contained - garbage (non-zero) imaginary values. HOWEVER, since kappa - can be non-unit, things become a little more complicated. - In general, we must re-apply the kappa scalar to ONLY the - real part of the diagonal of the source matrix and save - the result to the diagonal of p11. */ \ - if ( bli_is_hermitian( strucc ) ) \ - { \ - PASTEMAC3(ch,chr,ch,scal2rihs_mxn_diag) \ - ( \ - schema, \ - panel_dim, \ - panel_dim, \ - kappa, \ - c11, rs_c, cs_c, \ - p11_r, rs_p, cs_p \ - ); \ - } \ -\ -/* - PASTEMAC(chr,fprintm)( stdout, "packm_herm_cxk: ap_r copied", m_panel_max, n_panel_max, \ - p_r + 0*is_p, rs_p, cs_p, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_herm_cxk: ap_i copied", m_panel_max, n_panel_max, \ - p_r + 1*is_p, rs_p, cs_p, "%4.1f", "" ); \ -*/ \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_herm_cxk_rih, packm_cxk_rih ) - - - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffp, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - /* Pack the panel. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - schema, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, ldp, \ - cntx \ - ); \ -\ -\ - /* Tweak the panel according to its triangular structure */ \ - { \ - ctype_r* p_r = ( ctype_r* )p; \ -\ - dim_t j = bli_abs( diagoffp ); \ - ctype_r* p11_r = p_r + (j )*ldp; \ -\ - /* If the diagonal of c is implicitly unit, explicitly set the - the diagonal of the packed panel to kappa. */ \ - if ( bli_is_unit_diag( diagc ) ) \ - { \ - PASTEMAC(ch,setrihs_mxn_diag) \ - ( \ - schema, \ - panel_dim, \ - panel_dim, \ - kappa, \ - p11_r, rs_p, cs_p \ - ); \ - } \ -\ -\ - /* If requested, invert the diagonal of the packed panel. */ \ - if ( invdiag == TRUE ) \ - { \ - /* We don't need this case if we aren't supporting trsm. */ \ - } \ -\ -\ - /* Set the region opposite the diagonal of p to zero. To do this, - we need to reference the "unstored" region on the other side of - the diagonal. This amounts to toggling uploc and then shifting - the diagonal offset to shrink the newly referenced region (by - one diagonal). */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - uplo_t uplop = uploc; \ -\ - bli_toggle_uplo( &uplop ); \ - bli_shift_diag_offset_to_shrink_uplo( uplop, &diagoffp ); \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - BLIS_NONUNIT_DIAG, \ - uplop, \ - m_panel, \ - n_panel, \ - zero_r, \ - p_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_tri_cxk_rih, packm_cxk_rih ) - diff --git a/frame/1m/packm/bli_packm_struc_cxk_rih.h b/frame/1m/packm/bli_packm_struc_cxk_rih.h deleted file mode 100644 index e87767e268..0000000000 --- a/frame/1m/packm/bli_packm_struc_cxk_rih.h +++ /dev/null @@ -1,121 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffp, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_struc_cxk_rih ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_herm_cxk_rih ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool_t invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_tri_cxk_rih ) - diff --git a/frame/1m/packm/bli_packm_thrinfo.c b/frame/1m/packm/bli_packm_thrinfo.c index 92162c4224..4b57971ef2 100644 --- a/frame/1m/packm/bli_packm_thrinfo.c +++ b/frame/1m/packm/bli_packm_thrinfo.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/1m/packm/bli_packm_thrinfo.h b/frame/1m/packm/bli_packm_thrinfo.h index 7d35cbc931..85b61931c1 100644 --- a/frame/1m/packm/bli_packm_thrinfo.h +++ b/frame/1m/packm/bli_packm_thrinfo.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/1m/packm/bli_packm_unb_var1.c b/frame/1m/packm/bli_packm_unb_var1.c deleted file mode 100644 index 6e72b3e9d0..0000000000 --- a/frame/1m/packm/bli_packm_unb_var1.c +++ /dev/null @@ -1,297 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#define FUNCPTR_T packm_fp - -typedef void (*FUNCPTR_T)( - struc_t strucc, - doff_t diagoffc, - diag_t diagc, - uplo_t uploc, - trans_t transc, - dim_t m, - dim_t n, - dim_t m_max, - dim_t n_max, - void* kappa, - void* c, inc_t rs_c, inc_t cs_c, - void* p, inc_t rs_p, inc_t cs_p, - cntx_t* cntx - ); - -static FUNCPTR_T GENARRAY(ftypes,packm_unb_var1); - - -void bli_packm_unb_var1 - ( - obj_t* c, - obj_t* p, - cntx_t* cntx, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_cp = bli_obj_dt( c ); - - struc_t strucc = bli_obj_struc( c ); - doff_t diagoffc = bli_obj_diag_offset( c ); - diag_t diagc = bli_obj_diag( c ); - uplo_t uploc = bli_obj_uplo( c ); - trans_t transc = bli_obj_conjtrans_status( c ); - - dim_t m_p = bli_obj_length( p ); - dim_t n_p = bli_obj_width( p ); - dim_t m_max_p = bli_obj_padded_length( p ); - dim_t n_max_p = bli_obj_padded_width( p ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - void* buf_p = bli_obj_buffer_at_off( p ); - inc_t rs_p = bli_obj_row_stride( p ); - inc_t cs_p = bli_obj_col_stride( p ); - - void* buf_kappa; - - FUNCPTR_T f; - - - // This variant assumes that the computational kernel will always apply - // the alpha scalar of the higher-level operation. Thus, we use BLIS_ONE - // for kappa so that the underlying packm implementation does not scale - // during packing. - buf_kappa = bli_obj_buffer_for_const( dt_cp, &BLIS_ONE ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_cp]; - - if( bli_thread_am_ochief( thread ) ) { - // Invoke the function. - f - ( - strucc, - diagoffc, - diagc, - uploc, - transc, - m_p, - n_p, - m_max_p, - n_max_p, - buf_kappa, - buf_c, rs_c, cs_c, - buf_p, rs_p, cs_p, - cntx - ); - } -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - trans_t transc, \ - dim_t m, \ - dim_t n, \ - dim_t m_max, \ - dim_t n_max, \ - void* kappa, \ - void* c, inc_t rs_c, inc_t cs_c, \ - void* p, inc_t rs_p, inc_t cs_p, \ - cntx_t* cntx \ - ) \ -{ \ - ctype* restrict kappa_cast = kappa; \ - ctype* restrict c_cast = c; \ - ctype* restrict p_cast = p; \ - ctype* restrict zero = PASTEMAC(ch,0); \ -\ - /* We begin by packing the region indicated by the parameters. If - matrix c is dense (either because the structure is general or - because the structure has already been "densified"), this ends - up being the only action we take. Note that if kappa is unit, - the data is simply copied (rather than scaled by one). */ \ - PASTEMAC2(ch,scal2m,BLIS_TAPI_EX_SUF) \ - ( \ - diagoffc, \ - diagc, \ - uploc, \ - transc, \ - m, \ - n, \ - kappa_cast, \ - c_cast, rs_c, cs_c, \ - p_cast, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ -\ - /* If uploc is upper or lower, then the structure of c is necessarily - non-dense (ie: Hermitian, symmetric, or triangular, where part of the - matrix is unstored). In these cases, we want to fill in the unstored - part of the matrix. How this is done depends on the structure of c. */ \ - if ( bli_is_upper_or_lower( uploc ) ) \ - { \ - /* The Hermitian and symmetric cases are almost identical, so we - handle them in one conditional block. */ \ - if ( bli_is_hermitian( strucc ) || bli_is_symmetric( strucc ) ) \ - { \ - /* First we must reflect the region referenced to the opposite - side of the diagonal. */ \ - c_cast = c_cast + diagoffc * ( doff_t )cs_c + \ - -diagoffc * ( doff_t )rs_c; \ - bli_negate_diag_offset( &diagoffc ); \ - bli_toggle_trans( &transc ); \ - if ( bli_is_upper( uploc ) ) diagoffc += 1; \ - else if ( bli_is_lower( uploc ) ) diagoffc -= 1; \ -\ - /* If c is Hermitian, we need to apply a conjugation when - copying the region opposite the diagonal. */ \ - if ( bli_is_hermitian( strucc ) ) \ - transc = bli_trans_toggled_conj( transc ); \ -\ - /* Copy the data from the region opposite the diagonal of c - (as specified by the original value of diagoffc). Notice - that we use a diag parameter of non-unit since we can - assume nothing about the neighboring off-diagonal. */ \ - PASTEMAC2(ch,scal2m,BLIS_TAPI_EX_SUF) \ - ( \ - diagoffc, \ - BLIS_NONUNIT_DIAG, \ - uploc, \ - transc, \ - m, \ - n, \ - kappa_cast, \ - c_cast, rs_c, cs_c, \ - p_cast, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ - else /* if ( bli_is_triangular( strucc ) ) */ \ - { \ - doff_t diagoffp = diagoffc; \ - uplo_t uplop = uploc; \ -\ - /* For this step we need the uplo and diagonal offset of p, which - we can derive from the parameters given. */ \ - if ( bli_does_trans( transc ) ) \ - { \ - bli_negate_diag_offset( &diagoffp ); \ - bli_toggle_uplo( &uplop ); \ - } \ -\ - /* For triangular matrices, we wish to reference the region - strictly opposite the diagonal of C. This amounts to - toggling uploc and then shifting the diagonal offset to - shrink the stored region (by one diagonal). */ \ - bli_toggle_uplo( &uplop ); \ - bli_shift_diag_offset_to_shrink_uplo( uplop, &diagoffp ); \ -\ - /* Set the region opposite the diagonal of p to zero. */ \ - PASTEMAC2(ch,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - BLIS_NONUNIT_DIAG, \ - uplop, \ - m, \ - n, \ - zero, \ - p_cast, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - /* The packed memory region was acquired/allocated with "aligned" - dimensions (ie: dimensions that were possibly inflated up to a - multiple). When these dimension are inflated, it creates empty - regions along the bottom and/or right edges of the matrix. If - eithe region exists, we set them to zero. This simplifies the - register level micro kernel in that it does not need to support - different register blockings for the edge cases. */ \ - if ( m != m_max ) \ - { \ - ctype* p_edge = p_cast + (m )*rs_p; \ -\ - PASTEMAC2(ch,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_max - m, \ - n_max, \ - zero, \ - p_edge, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -\ - if ( n != n_max ) \ - { \ - ctype* p_edge = p_cast + (n )*cs_p; \ -\ - PASTEMAC2(ch,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_max, \ - n_max - n, \ - zero, \ - p_edge, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNC_BASIC0( packm_unb_var1 ) - diff --git a/frame/1m/unpackm/bli_unpackm.h b/frame/1m/unpackm/bli_unpackm.h index b32d02d9ba..5e45428410 100644 --- a/frame/1m/unpackm/bli_unpackm.h +++ b/frame/1m/unpackm/bli_unpackm.h @@ -36,8 +36,6 @@ #include "bli_unpackm_check.h" #include "bli_unpackm_int.h" -#include "bli_unpackm_unb_var1.h" - #include "bli_unpackm_blk_var1.h" #include "bli_unpackm_cxk.h" diff --git a/frame/1m/unpackm/bli_unpackm_cntl.c b/frame/1m/unpackm/bli_unpackm_cntl.c index 46392269f4..95d0545bec 100644 --- a/frame/1m/unpackm/bli_unpackm_cntl.c +++ b/frame/1m/unpackm/bli_unpackm_cntl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,20 +38,21 @@ cntl_t* bli_unpackm_cntl_create_node ( rntm_t* rntm, - void* var_func, - void* unpackm_var_func, + void_fp var_func, + void_fp unpackm_var_func, cntl_t* sub_node ) { cntl_t* cntl; unpackm_params_t* params; + err_t r_val; // NOTE: If this function is ever called, figure out whether the // bli_malloc_intl() below needs to be changed to bli_sba_acquire(). bli_abort(); // Allocate an unpackm_params_t struct. - params = bli_malloc_intl( sizeof( unpackm_params_t ) ); + params = bli_malloc_intl( sizeof( unpackm_params_t ), &r_val ); // Initialize the unpackm_params_t struct. params->size = sizeof( unpackm_params_t ); diff --git a/frame/1m/unpackm/bli_unpackm_cntl.h b/frame/1m/unpackm/bli_unpackm_cntl.h index b282c3561e..5c41d94657 100644 --- a/frame/1m/unpackm/bli_unpackm_cntl.h +++ b/frame/1m/unpackm/bli_unpackm_cntl.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -49,8 +49,8 @@ typedef struct unpackm_params_s unpackm_params_t; cntl_t* bli_unpackm_cntl_create_node ( rntm_t* rntm, - void* var_func, - void* unpackm_var_func, + void_fp var_func, + void_fp unpackm_var_func, cntl_t* sub_node ); diff --git a/frame/1m/unpackm/bli_unpackm_int.c b/frame/1m/unpackm/bli_unpackm_int.c index f4c8ab82df..550a8fb870 100644 --- a/frame/1m/unpackm/bli_unpackm_int.c +++ b/frame/1m/unpackm/bli_unpackm_int.c @@ -73,6 +73,6 @@ void bli_unpackm_int } // Barrier so that unpacking is done before computation. - bli_thread_obarrier( thread ); + bli_thread_barrier( thread ); } diff --git a/frame/1m/unpackm/bli_unpackm_unb_var1.c b/frame/1m/unpackm/bli_unpackm_unb_var1.c deleted file mode 100644 index c1033c2cb9..0000000000 --- a/frame/1m/unpackm/bli_unpackm_unb_var1.c +++ /dev/null @@ -1,131 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#define FUNCPTR_T unpackm_fp - -typedef void (*FUNCPTR_T)( - doff_t diagoffp, - uplo_t uplop, - trans_t transp, - dim_t m, - dim_t n, - void* p, inc_t rs_p, inc_t cs_p, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx - ); - -static FUNCPTR_T GENARRAY(ftypes,unpackm_unb_var1); - - -void bli_unpackm_unb_var1 - ( - obj_t* p, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_pc = bli_obj_dt( p ); - - doff_t diagoffp = bli_obj_diag_offset( p ); - uplo_t uplop = bli_obj_uplo( p ); - trans_t transc = bli_obj_onlytrans_status( c ); - - dim_t m_c = bli_obj_length( c ); - dim_t n_c = bli_obj_width( c ); - - void* buf_p = bli_obj_buffer_at_off( p ); - inc_t rs_p = bli_obj_row_stride( p ); - inc_t cs_p = bli_obj_col_stride( p ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - FUNCPTR_T f; - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_pc]; - - // Invoke the function. - f( diagoffp, - uplop, - transc, - m_c, - n_c, - buf_p, rs_p, cs_p, - buf_c, rs_c, cs_c, - cntx - ); -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname, varname ) \ -\ -void PASTEMAC(ch,varname)( \ - doff_t diagoffp, \ - uplo_t uplop, \ - trans_t transp, \ - dim_t m, \ - dim_t n, \ - void* p, inc_t rs_p, inc_t cs_p, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx \ - ) \ -{ \ - ctype* p_cast = p; \ - ctype* c_cast = c; \ -\ - PASTEMAC2(ch,copym,BLIS_TAPI_EX_SUF) \ - ( \ - diagoffp,\ - BLIS_NONUNIT_DIAG, \ - uplop, \ - transp, \ - m, \ - n, \ - p_cast, rs_p, cs_p, \ - c_cast, rs_c, cs_c, \ - cntx, \ - NULL \ - ); \ -} - -INSERT_GENTFUNC_BASIC( unpackm, unpackm_unb_var1 ) - diff --git a/frame/2/bli_l2.h b/frame/2/bli_l2.h index 9415a0329b..ef4517c98d 100644 --- a/frame/2/bli_l2.h +++ b/frame/2/bli_l2.h @@ -40,18 +40,22 @@ // Prototype object APIs (expert and non-expert). #include "bli_oapi_ex.h" #include "bli_l2_oapi.h" +#include "bli_xapi_undef.h" #include "bli_oapi_ba.h" #include "bli_l2_oapi.h" +#include "bli_xapi_undef.h" // Prototype typed APIs (expert and non-expert). #include "bli_tapi_ex.h" #include "bli_l2_tapi.h" #include "bli_l2_ft.h" +#include "bli_xapi_undef.h" #include "bli_tapi_ba.h" #include "bli_l2_tapi.h" #include "bli_l2_ft.h" +#include "bli_xapi_undef.h" // Generate function pointer arrays for tapi functions (expert only). #include "bli_l2_fpa.h" diff --git a/frame/2/bli_l2_oapi.c b/frame/2/bli_l2_oapi.c index 25acb42076..cc32fb61e6 100644 --- a/frame/2/bli_l2_oapi.c +++ b/frame/2/bli_l2_oapi.c @@ -90,7 +90,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -157,7 +157,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -229,7 +229,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -293,7 +293,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -358,7 +358,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -422,7 +422,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ diff --git a/frame/2/gemv/amd/bli_gemv_unf_var2_amd.c b/frame/2/gemv/amd/bli_gemv_unf_var2_amd.c new file mode 100644 index 0000000000..8f0f31479f --- /dev/null +++ b/frame/2/gemv/amd/bli_gemv_unf_var2_amd.c @@ -0,0 +1,222 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname, scalvsuf, axpyfsuf, fusefac ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + trans_t transa, \ + conj_t conjx, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ + /*const num_t dt = PASTEMAC(ch,type);*/ \ +\ + ctype* A1; \ + ctype* x1; \ + ctype* y1; \ + dim_t i; \ + dim_t b_fuse, f; \ + dim_t n_elem, n_iter; \ + inc_t rs_at, cs_at; \ + conj_t conja; \ +\ + bli_set_dims_incs_with_trans( transa, \ + m, n, rs_a, cs_a, \ + &n_elem, &n_iter, &rs_at, &cs_at ); \ +\ + conja = bli_extract_conj( transa ); \ +\ + /* y = beta * y; */ \ + /* NOTE: We don't explicitly handle the case where beta == 0 here + since that behavior is handled within the scalv kernel itself. */ \ + PASTEMAC2(ch,scalv,scalvsuf) \ + ( \ + BLIS_NO_CONJUGATE, \ + n_elem, \ + beta, \ + y, incy, \ + cntx \ + ); \ +\ + /* If alpha == 0, then we are done. */ \ + if ( PASTEMAC(ch,eq0)( *alpha ) ) return; \ +\ + /*PASTECH(ch,axpyf_ker_ft) kfp_af;*/ \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + /*kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx );*/ \ + /*b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx );*/ \ + b_fuse = fusefac; \ +\ + for ( i = 0; i < n_iter; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); \ +\ + A1 = a + (0 )*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + y1 = y + (0 )*incy; \ +\ + /* y = y + alpha * A1 * x1; */ \ + /*kfp_af*/ \ + PASTEMAC2(ch,axpyf,axpyfsuf) \ + ( \ + conja, \ + conjx, \ + n_elem, \ + f, \ + alpha, \ + A1, rs_at, cs_at, \ + x1, incx, \ + y1, incy, \ + cntx \ + ); \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( gemv_unf_var2 ) +GENTFUNC( float, s, gemv_unf_var2, _zen_int10, _zen_int_5, 5 ) +GENTFUNC( double, d, gemv_unf_var2, _zen_int10, _zen_int_16x4, 4 ) +GENTFUNC( scomplex, c, gemv_unf_var2, _zen_int10, _zen_int_4, 4 ) +//GENTFUNC( dcomplex, z, gemv_unf_var2, _zen_int10, _ex, 1 ) + + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + trans_t transa, \ + conj_t conjx, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* zero = PASTEMAC(ch,0); \ + ctype* A1; \ + ctype* x1; \ + ctype* y1; \ + dim_t i; \ + dim_t b_fuse, f; \ + dim_t n_elem, n_iter; \ + inc_t rs_at, cs_at; \ + conj_t conja; \ +\ + bli_set_dims_incs_with_trans( transa, \ + m, n, rs_a, cs_a, \ + &n_elem, &n_iter, &rs_at, &cs_at ); \ +\ + conja = bli_extract_conj( transa ); \ +\ + /* If beta is zero, use setv. Otherwise, scale by beta. */ \ + if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + /* y = 0; */ \ + PASTEMAC2(ch,setv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n_elem, \ + zero, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ + else \ + { \ + /* y = beta * y; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n_elem, \ + beta, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ +\ + PASTECH(ch,axpyf_ker_ft) kfp_af; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); \ +\ + for ( i = 0; i < n_iter; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); \ +\ + A1 = a + (0 )*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + y1 = y + (0 )*incy; \ +\ + /* y = y + alpha * A1 * x1; */ \ + kfp_af \ + ( \ + conja, \ + conjx, \ + n_elem, \ + f, \ + alpha, \ + A1, rs_at, cs_at, \ + x1, incx, \ + y1, incy, \ + cntx \ + ); \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( gemv_unf_var2 ) +GENTFUNC( dcomplex, z, gemv_unf_var2 ) + diff --git a/frame/2/gemv/bli_gemv_var_oapi.c b/frame/2/gemv/bli_gemv_var_oapi.c index 2e746b417f..8657735340 100644 --- a/frame/2/gemv/bli_gemv_var_oapi.c +++ b/frame/2/gemv/bli_gemv_var_oapi.c @@ -72,7 +72,7 @@ void PASTEMAC0(varname) \ void* buf_beta = bli_obj_buffer_for_1x1( dt, beta ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/gemv/other/bli_gemv_front.c b/frame/2/gemv/other/bli_gemv_front.c index 3fd1c8cf7a..680ca1abe1 100644 --- a/frame/2/gemv/other/bli_gemv_front.c +++ b/frame/2/gemv/other/bli_gemv_front.c @@ -53,9 +53,9 @@ void bli_gemv_front num_t dt_targ_a; num_t dt_targ_x; num_t dt_targ_y; - bool_t a_has_unit_inc; - bool_t x_has_unit_inc; - bool_t y_has_unit_inc; + bool a_has_unit_inc; + bool x_has_unit_inc; + bool y_has_unit_inc; obj_t alpha_local; obj_t beta_local; num_t dt_alpha; diff --git a/frame/2/ger/bli_ger_var_oapi.c b/frame/2/ger/bli_ger_var_oapi.c index 3fd95e89fb..f125efdf83 100644 --- a/frame/2/ger/bli_ger_var_oapi.c +++ b/frame/2/ger/bli_ger_var_oapi.c @@ -70,7 +70,7 @@ void PASTEMAC0(varname) \ void* buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/ger/other/bli_ger_front.c b/frame/2/ger/other/bli_ger_front.c index 8f641fe2ea..e32ed0a78e 100644 --- a/frame/2/ger/other/bli_ger_front.c +++ b/frame/2/ger/other/bli_ger_front.c @@ -52,9 +52,9 @@ void bli_ger_front num_t dt_targ_x; num_t dt_targ_y; //num_t dt_targ_a; - bool_t x_has_unit_inc; - bool_t y_has_unit_inc; - bool_t a_has_unit_inc; + bool x_has_unit_inc; + bool y_has_unit_inc; + bool a_has_unit_inc; obj_t alpha_local; num_t dt_alpha; diff --git a/frame/2/hemv/bli_hemv_var_oapi.c b/frame/2/hemv/bli_hemv_var_oapi.c index 845f288c32..bf0e4b2022 100644 --- a/frame/2/hemv/bli_hemv_var_oapi.c +++ b/frame/2/hemv/bli_hemv_var_oapi.c @@ -73,7 +73,7 @@ void PASTEMAC0(varname) \ void* buf_beta = bli_obj_buffer_for_1x1( dt, beta ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/hemv/other/bli_hemv_front.c b/frame/2/hemv/other/bli_hemv_front.c index 1293f6b8e6..58bae5bf3b 100644 --- a/frame/2/hemv/other/bli_hemv_front.c +++ b/frame/2/hemv/other/bli_hemv_front.c @@ -53,9 +53,9 @@ void bli_hemv_front num_t dt_targ_a; num_t dt_targ_x; num_t dt_targ_y; - bool_t a_has_unit_inc; - bool_t x_has_unit_inc; - bool_t y_has_unit_inc; + bool a_has_unit_inc; + bool x_has_unit_inc; + bool y_has_unit_inc; obj_t alpha_local; obj_t beta_local; num_t dt_alpha; diff --git a/frame/2/her/bli_her_var_oapi.c b/frame/2/her/bli_her_var_oapi.c index ffca2e71e2..44c6d090d1 100644 --- a/frame/2/her/bli_her_var_oapi.c +++ b/frame/2/her/bli_her_var_oapi.c @@ -66,7 +66,7 @@ void PASTEMAC0(varname) \ void* buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/her/other/bli_her_front.c b/frame/2/her/other/bli_her_front.c index 7753b28cf1..afb97378e0 100644 --- a/frame/2/her/other/bli_her_front.c +++ b/frame/2/her/other/bli_her_front.c @@ -50,8 +50,8 @@ void bli_her_front her_t* her_cntl; num_t dt_targ_x; //num_t dt_targ_c; - bool_t x_has_unit_inc; - bool_t c_has_unit_inc; + bool x_has_unit_inc; + bool c_has_unit_inc; obj_t alpha_local; num_t dt_alpha; diff --git a/frame/2/her2/bli_her2_var_oapi.c b/frame/2/her2/bli_her2_var_oapi.c index 2b26e5476f..dce87a1cd8 100644 --- a/frame/2/her2/bli_her2_var_oapi.c +++ b/frame/2/her2/bli_her2_var_oapi.c @@ -72,7 +72,7 @@ void PASTEMAC0(varname) \ void* buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/her2/other/bli_her2_front.c b/frame/2/her2/other/bli_her2_front.c index 21c9027b55..cc7cbebc42 100644 --- a/frame/2/her2/other/bli_her2_front.c +++ b/frame/2/her2/other/bli_her2_front.c @@ -52,9 +52,9 @@ void bli_her2_front num_t dt_targ_x; num_t dt_targ_y; //num_t dt_targ_c; - bool_t x_has_unit_inc; - bool_t y_has_unit_inc; - bool_t c_has_unit_inc; + bool x_has_unit_inc; + bool y_has_unit_inc; + bool c_has_unit_inc; obj_t alpha_local; obj_t alpha_conj_local; num_t dt_alpha; diff --git a/frame/2/symv/other/bli_symv_front.c b/frame/2/symv/other/bli_symv_front.c index bac3f22435..ddccf645f7 100644 --- a/frame/2/symv/other/bli_symv_front.c +++ b/frame/2/symv/other/bli_symv_front.c @@ -53,9 +53,9 @@ void bli_symv_front num_t dt_targ_a; num_t dt_targ_x; num_t dt_targ_y; - bool_t a_has_unit_inc; - bool_t x_has_unit_inc; - bool_t y_has_unit_inc; + bool a_has_unit_inc; + bool x_has_unit_inc; + bool y_has_unit_inc; obj_t alpha_local; obj_t beta_local; num_t dt_alpha; diff --git a/frame/2/syr/other/bli_syr_front.c b/frame/2/syr/other/bli_syr_front.c index efbd24cf85..30a012c863 100644 --- a/frame/2/syr/other/bli_syr_front.c +++ b/frame/2/syr/other/bli_syr_front.c @@ -50,8 +50,8 @@ void bli_syr_front her_t* her_cntl; num_t dt_targ_x; num_t dt_targ_c; - bool_t x_has_unit_inc; - bool_t c_has_unit_inc; + bool x_has_unit_inc; + bool c_has_unit_inc; obj_t alpha_local; num_t dt_alpha; diff --git a/frame/2/syr2/other/bli_syr2_front.c b/frame/2/syr2/other/bli_syr2_front.c index 59a36f478e..f272ef24a1 100644 --- a/frame/2/syr2/other/bli_syr2_front.c +++ b/frame/2/syr2/other/bli_syr2_front.c @@ -52,9 +52,9 @@ void bli_syr2_front num_t dt_targ_x; num_t dt_targ_y; //num_t dt_targ_c; - bool_t x_has_unit_inc; - bool_t y_has_unit_inc; - bool_t c_has_unit_inc; + bool x_has_unit_inc; + bool y_has_unit_inc; + bool c_has_unit_inc; obj_t alpha_local; num_t dt_alpha; diff --git a/frame/2/trmv/bli_trmv_var_oapi.c b/frame/2/trmv/bli_trmv_var_oapi.c index 931eb2abbb..c74d312234 100644 --- a/frame/2/trmv/bli_trmv_var_oapi.c +++ b/frame/2/trmv/bli_trmv_var_oapi.c @@ -66,7 +66,7 @@ void PASTEMAC0(varname) \ void* buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/trmv/other/bli_trmv_front.c b/frame/2/trmv/other/bli_trmv_front.c index 698f487107..e6d8c826d3 100644 --- a/frame/2/trmv/other/bli_trmv_front.c +++ b/frame/2/trmv/other/bli_trmv_front.c @@ -50,8 +50,8 @@ void bli_trmv_front trmv_t* trmv_cntl; num_t dt_targ_a; num_t dt_targ_x; - bool_t a_has_unit_inc; - bool_t x_has_unit_inc; + bool a_has_unit_inc; + bool x_has_unit_inc; obj_t alpha_local; num_t dt_alpha; diff --git a/frame/2/trmv/other/bli_trmv_int.c b/frame/2/trmv/other/bli_trmv_int.c index ed3ebc40b4..0524517270 100644 --- a/frame/2/trmv/other/bli_trmv_int.c +++ b/frame/2/trmv/other/bli_trmv_int.c @@ -68,7 +68,7 @@ void bli_trmv_int( obj_t* alpha, { varnum_t n; impl_t i; - bool_t uplo; + bool uplo; FUNCPTR_T f; obj_t a_local; diff --git a/frame/2/trsv/bli_trsv_var_oapi.c b/frame/2/trsv/bli_trsv_var_oapi.c index 4cf346acf5..62ac33e454 100644 --- a/frame/2/trsv/bli_trsv_var_oapi.c +++ b/frame/2/trsv/bli_trsv_var_oapi.c @@ -66,7 +66,7 @@ void PASTEMAC0(varname) \ void* buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/trsv/other/bli_trsv_front.c b/frame/2/trsv/other/bli_trsv_front.c index 90a152e89a..9a70200a32 100644 --- a/frame/2/trsv/other/bli_trsv_front.c +++ b/frame/2/trsv/other/bli_trsv_front.c @@ -50,8 +50,8 @@ void bli_trsv_front trsv_t* trsv_cntl; num_t dt_targ_a; num_t dt_targ_x; - bool_t a_has_unit_inc; - bool_t x_has_unit_inc; + bool a_has_unit_inc; + bool x_has_unit_inc; obj_t alpha_local; num_t dt_alpha; diff --git a/frame/2/trsv/other/bli_trsv_int.c b/frame/2/trsv/other/bli_trsv_int.c index 0f2fe06b73..1dfcb75923 100644 --- a/frame/2/trsv/other/bli_trsv_int.c +++ b/frame/2/trsv/other/bli_trsv_int.c @@ -68,7 +68,7 @@ void bli_trsv_int( obj_t* alpha, { varnum_t n; impl_t i; - bool_t uplo; + bool uplo; FUNCPTR_T f; obj_t a_local; diff --git a/frame/3/bli_l3.h b/frame/3/bli_l3.h index 7f2879c027..4dc1a9d545 100644 --- a/frame/3/bli_l3.h +++ b/frame/3/bli_l3.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,9 +35,11 @@ #include "bli_l3_cntl.h" #include "bli_l3_check.h" +#include "bli_l3_int.h" +#include "bli_l3_packab.h" // Define function types. -#include "bli_l3_ft_ex.h" +//#include "bli_l3_ft_ex.h" #include "bli_l3_ft_ukr.h" #include "bli_l3_oft.h" #include "bli_l3_oft_var.h" @@ -44,38 +47,46 @@ #include "bli_l3_blocksize.h" #include "bli_l3_direct.h" #include "bli_l3_prune.h" -#include "bli_l3_packm.h" +#include "bli_l3_schema.h" -// Prototype object APIs (expert and non-expert). -#include "bli_oapi_ex.h" +// Prototype object APIs (basic and expert). #include "bli_l3_oapi.h" +#include "bli_l3_oapi_ex.h" -#include "bli_oapi_ba.h" -#include "bli_l3_oapi.h" - -// Prototype typed APIs (expert and non-expert). -#include "bli_tapi_ex.h" +// Prototype typed APIs (basic and expert). #include "bli_l3_tapi.h" +#include "bli_l3_tapi_ex.h" -#include "bli_tapi_ba.h" -#include "bli_l3_tapi.h" +// Define function types for small/unpacked handlers/kernels. +#include "bli_l3_sup_oft.h" +#include "bli_l3_sup_ft_ker.h" + +// Define static edge case logic for use in small/unpacked kernels. +//#include "bli_l3_sup_edge.h" -// Prototype microkernel wrapper APIs +// Prototype object API to small/unpacked matrix dispatcher. +#include "bli_l3_sup.h" + +// Prototype reference implementation of small/unpacked matrix handler. +#include "bli_l3_sup_ref.h" +#include "bli_l3_sup_int.h" +#include "bli_l3_sup_vars.h" +#include "bli_l3_sup_packm_a.h" +#include "bli_l3_sup_packm_b.h" +#include "bli_l3_sup_packm_var.h" + +// Prototype microkernel wrapper APIs. #include "bli_l3_ukr_oapi.h" #include "bli_l3_ukr_tapi.h" // Generate function pointer arrays for tapi microkernel functions. #include "bli_l3_ukr_fpa.h" -// Operation-specific headers +// Operation-specific headers. #include "bli_gemm.h" #include "bli_hemm.h" -#include "bli_herk.h" -#include "bli_her2k.h" #include "bli_symm.h" -#include "bli_syrk.h" -#include "bli_syr2k.h" #include "bli_trmm.h" #include "bli_trmm3.h" #include "bli_trsm.h" - +#include "bli_gemmt.h" diff --git a/frame/3/bli_l3_blocksize.c b/frame/3/bli_l3_blocksize.c index b5993054a4..1986b3b0f6 100644 --- a/frame/3/bli_l3_blocksize.c +++ b/frame/3/bli_l3_blocksize.c @@ -51,8 +51,8 @@ dim_t bli_l3_determine_kc if ( family == BLIS_GEMM ) return bli_gemm_determine_kc( direct, i, dim, a, b, bszid, cntx ); - else if ( family == BLIS_HERK ) - return bli_herk_determine_kc( direct, i, dim, a, b, bszid, cntx ); + else if ( family == BLIS_GEMMT ) + return bli_gemmt_determine_kc( direct, i, dim, a, b, bszid, cntx ); else if ( family == BLIS_TRMM ) return bli_trmm_determine_kc( direct, i, dim, a, b, bszid, cntx ); else if ( family == BLIS_TRSM ) @@ -91,7 +91,7 @@ dim_t PASTEMAC0(opname) \ } GENFRONT( gemm_determine_kc, gemm ) -GENFRONT( herk_determine_kc, trmm ) +GENFRONT( gemmt_determine_kc, gemmt ) GENFRONT( trmm_determine_kc, trmm ) GENFRONT( trsm_determine_kc, trsm ) @@ -201,7 +201,7 @@ dim_t PASTEMAC0(opname) \ b_alg = bli_blksz_get_def( dt, bsize ); \ b_max = bli_blksz_get_max( dt, bsize ); \ \ - /* Notice that for herk, we do not need to perform any special handling + /* Notice that for gemmt, we do not need to perform any special handling for the default and maximum kc blocksizes vis-a-vis MR or NR. */ \ \ /* Call the bli_determine_blocksize_[fb]_sub() helper routine defined @@ -211,8 +211,8 @@ dim_t PASTEMAC0(opname) \ return b_use; \ } -GENFRONT( herk_determine_kc_f, f ) -GENFRONT( herk_determine_kc_b, b ) +GENFRONT( gemmt_determine_kc_f, f ) +GENFRONT( gemmt_determine_kc_b, b ) // ----------------------------------------------------------------------------- diff --git a/frame/3/bli_l3_blocksize.h b/frame/3/bli_l3_blocksize.h index c3301ee13a..3ea3c5aa02 100644 --- a/frame/3/bli_l3_blocksize.h +++ b/frame/3/bli_l3_blocksize.h @@ -60,7 +60,7 @@ dim_t PASTEMAC0(opname) \ ); GENPROT( gemm_determine_kc ) -GENPROT( herk_determine_kc ) +GENPROT( gemmt_determine_kc ) GENPROT( trmm_determine_kc ) GENPROT( trsm_determine_kc ) @@ -81,8 +81,8 @@ dim_t PASTEMAC0(opname) \ GENPROT( gemm_determine_kc_f ) GENPROT( gemm_determine_kc_b ) -GENPROT( herk_determine_kc_f ) -GENPROT( herk_determine_kc_b ) +GENPROT( gemmt_determine_kc_f ) +GENPROT( gemmt_determine_kc_b ) GENPROT( trmm_determine_kc_f ) GENPROT( trmm_determine_kc_b ) diff --git a/frame/3/bli_l3_check.c b/frame/3/bli_l3_check.c index 932f0346e6..3e7882bc39 100644 --- a/frame/3/bli_l3_check.c +++ b/frame/3/bli_l3_check.c @@ -53,7 +53,7 @@ void bli_gemm_check // Check object structure. // NOTE: Can't perform these checks as long as bli_gemm_check() is called - // from bli_gemm_int(), which is in the execution path for structured + // from bli_l3_int(), which is in the execution path for structured // level-3 operations such as hemm. //e_val = bli_check_general_object( a ); @@ -63,6 +63,28 @@ void bli_gemm_check //bli_check_error_code( e_val ); } +void bli_gemmt_check + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx + ) +{ + err_t e_val; + + // Check basic properties of the operation. + + bli_gemmt_basic_check( alpha, a, b, beta, c, cntx ); + + // Check matrix squareness. + + e_val = bli_check_square_object( c ); + bli_check_error_code( e_val ); +} + void bli_hemm_check ( side_t side, @@ -76,7 +98,7 @@ void bli_hemm_check { err_t e_val; - // Perform checks common to hemm/symm. + // Perform checks common to hemm/symm/trmm/trsm. bli_hemm_basic_check( side, alpha, a, b, beta, c, cntx ); @@ -87,7 +109,7 @@ void bli_hemm_check } void bli_herk_check - ( + ( obj_t* alpha, obj_t* a, obj_t* beta, @@ -175,7 +197,7 @@ void bli_symm_check } void bli_syrk_check - ( + ( obj_t* alpha, obj_t* a, obj_t* beta, @@ -226,7 +248,7 @@ void bli_syr2k_check bli_check_error_code( e_val ); } -void bli_trmm_check +void bli_trmm3_check ( side_t side, obj_t* alpha, @@ -239,7 +261,7 @@ void bli_trmm_check { err_t e_val; - // Perform checks common to hemm/symm. + // Perform checks common to hemm/symm/trmm/trsm. bli_hemm_basic_check( side, alpha, a, b, beta, c, cntx ); @@ -249,22 +271,41 @@ void bli_trmm_check bli_check_error_code( e_val ); } +void bli_trmm_check + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx + ) +{ + err_t e_val; + + // Perform checks common to hemm/symm/trmm/trsm. + + bli_hemm_basic_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); + + // Check object structure. + + e_val = bli_check_triangular_object( a ); + bli_check_error_code( e_val ); +} + void bli_trsm_check ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, - obj_t* beta, - obj_t* c, cntx_t* cntx ) { err_t e_val; - // Perform checks common to hemm/symm. + // Perform checks common to hemm/symm/trmm/trsm. - bli_hemm_basic_check( side, alpha, a, b, beta, c, cntx ); + bli_hemm_basic_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); // Check object structure. @@ -324,6 +365,28 @@ void bli_gemm_basic_check #endif } +void bli_gemmt_basic_check + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx + ) +{ + err_t e_val; + + // Perform standard checks. + + bli_l3_basic_check( alpha, a, b, beta, c, cntx ); + + // Check object dimensions. + + e_val = bli_check_level3_dims( a, b, c ); + bli_check_error_code( e_val ); +} + void bli_hemm_basic_check ( side_t side, @@ -534,10 +597,5 @@ void bli_l3_basic_check e_val = bli_check_object_buffer( c ); bli_check_error_code( e_val ); - - // Check for sufficiently sized stack buffers - - e_val = bli_check_sufficient_stack_buf_size( bli_obj_dt( a ), cntx ); - bli_check_error_code( e_val ); } diff --git a/frame/3/bli_l3_check.h b/frame/3/bli_l3_check.h index 7d30bb1846..c600d60b9a 100644 --- a/frame/3/bli_l3_check.h +++ b/frame/3/bli_l3_check.h @@ -51,6 +51,7 @@ void PASTEMAC(opname,_check) \ ); GENPROT( gemm ) +GENPROT( gemmt ) GENPROT( her2k ) GENPROT( syr2k ) @@ -71,8 +72,7 @@ void PASTEMAC(opname,_check) \ GENPROT( hemm ) GENPROT( symm ) -GENPROT( trmm ) -GENPROT( trsm ) +GENPROT( trmm3 ) #undef GENPROT @@ -91,6 +91,22 @@ GENPROT( herk ) GENPROT( syrk ) +#undef GENPROT +#define GENPROT( opname ) \ +\ +void PASTEMAC(opname,_check) \ + ( \ + side_t side, \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + cntx_t* cntx \ + ); + +GENPROT( trmm ) +GENPROT( trsm ) + + // ----------------------------------------------------------------------------- void bli_gemm_basic_check @@ -103,6 +119,16 @@ void bli_gemm_basic_check cntx_t* cntx ); +void bli_gemmt_basic_check + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx + ); + void bli_hemm_basic_check ( side_t side, diff --git a/frame/3/bli_l3_cntl.c b/frame/3/bli_l3_cntl.c index efdca53dbd..83ff8e5af5 100644 --- a/frame/3/bli_l3_cntl.c +++ b/frame/3/bli_l3_cntl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -54,10 +54,17 @@ void bli_l3_cntl_create_if if ( cntl_orig == NULL ) { if ( family == BLIS_GEMM || - family == BLIS_HERK || + family == BLIS_GEMMT || family == BLIS_TRMM ) { - *cntl_use = bli_gemm_cntl_create( rntm, family, schema_a, schema_b ); + *cntl_use = bli_gemm_cntl_create + ( + rntm, + family, + schema_a, + schema_b, + bli_obj_ker_fn( c ) + ); } else // if ( family == BLIS_TRSM ) { @@ -66,7 +73,14 @@ void bli_l3_cntl_create_if if ( bli_obj_is_triangular( a ) ) side = BLIS_LEFT; else side = BLIS_RIGHT; - *cntl_use = bli_trsm_cntl_create( rntm, side, schema_a, schema_b ); + *cntl_use = bli_trsm_cntl_create + ( + rntm, + side, + schema_a, + schema_b, + bli_obj_ker_fn( c ) + ); } } else @@ -97,7 +111,7 @@ void bli_l3_cntl_free opid_t family = bli_cntl_family( cntl_use ); if ( family == BLIS_GEMM || - family == BLIS_HERK || + family == BLIS_GEMMT || family == BLIS_TRMM ) { bli_gemm_cntl_free( rntm, cntl_use, thread ); diff --git a/frame/3/bli_l3_cntl.h b/frame/3/bli_l3_cntl.h index 0c04f348cc..c308c8a964 100644 --- a/frame/3/bli_l3_cntl.h +++ b/frame/3/bli_l3_cntl.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/3/bli_l3_direct.c b/frame/3/bli_l3_direct.c index 7baf2d6ef5..0d0a719214 100644 --- a/frame/3/bli_l3_direct.c +++ b/frame/3/bli_l3_direct.c @@ -46,7 +46,7 @@ dir_t bli_l3_direct opid_t family = bli_cntl_family( cntl ); if ( family == BLIS_GEMM ) return bli_gemm_direct( a, b, c ); - else if ( family == BLIS_HERK ) return bli_herk_direct( a, b, c ); + else if ( family == BLIS_GEMMT ) return bli_gemmt_direct( a, b, c ); else if ( family == BLIS_TRMM ) return bli_trmm_direct( a, b, c ); else if ( family == BLIS_TRSM ) return bli_trsm_direct( a, b, c ); @@ -68,14 +68,14 @@ dir_t bli_gemm_direct return BLIS_FWD; } -dir_t bli_herk_direct +dir_t bli_gemmt_direct ( obj_t* a, obj_t* b, obj_t* c ) { - // For herk, movement may be forwards (or backwards). + // For gemmt, movement may be forwards (or backwards). return BLIS_FWD; } diff --git a/frame/3/bli_l3_direct.h b/frame/3/bli_l3_direct.h index 7383c4a9fb..39798407a2 100644 --- a/frame/3/bli_l3_direct.h +++ b/frame/3/bli_l3_direct.h @@ -53,7 +53,7 @@ dir_t PASTEMAC0(opname) \ ); GENPROT( gemm_direct ) -GENPROT( herk_direct ) +GENPROT( gemmt_direct ) GENPROT( trmm_direct ) GENPROT( trsm_direct ) diff --git a/frame/3/bli_l3_ft_ukr.h b/frame/3/bli_l3_ft_ukr.h index 4249dcbd6b..28065c208b 100644 --- a/frame/3/bli_l3_ft_ukr.h +++ b/frame/3/bli_l3_ft_ukr.h @@ -47,6 +47,8 @@ \ typedef void (*PASTECH3(ch,opname,_ukr,tsuf)) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -67,6 +69,8 @@ INSERT_GENTDEF( gemm ) \ typedef void (*PASTECH3(ch,opname,_ukr,tsuf)) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a1x, \ diff --git a/frame/ind/bli_l3_ind.c b/frame/3/bli_l3_ind.c similarity index 60% rename from frame/ind/bli_l3_ind.c rename to frame/3/bli_l3_ind.c index 10897c3491..fbf73be608 100644 --- a/frame/ind/bli_l3_ind.c +++ b/frame/3/bli_l3_ind.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,50 +35,32 @@ #include "blis.h" -static void* bli_l3_ind_oper_fp[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS] = +// This array tracks whether a particular operation is implemented for each of +// the induced methods. +static bool bli_l3_ind_oper_impl[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS] = { - /* gemm hemm herk her2k symm syrk, syr2k trmm3 trmm trsm */ -/* 3mh */ { bli_gemm3mh, bli_hemm3mh, bli_herk3mh, bli_her2k3mh, bli_symm3mh, - bli_syrk3mh, bli_syr2k3mh, bli_trmm33mh, NULL, NULL }, -/* 3m1 */ { bli_gemm3m1, bli_hemm3m1, bli_herk3m1, bli_her2k3m1, bli_symm3m1, - bli_syrk3m1, bli_syr2k3m1, bli_trmm33m1, bli_trmm3m1, bli_trsm3m1 }, -/* 4mh */ { bli_gemm4mh, bli_hemm4mh, bli_herk4mh, bli_her2k4mh, bli_symm4mh, - bli_syrk4mh, bli_syr2k4mh, bli_trmm34mh, NULL, NULL }, -/* 4mb */ { bli_gemm4mb, NULL, NULL, NULL, NULL, - NULL, NULL, NULL, NULL, NULL }, -/* 4m1 */ { bli_gemm4m1, bli_hemm4m1, bli_herk4m1, bli_her2k4m1, bli_symm4m1, - bli_syrk4m1, bli_syr2k4m1, bli_trmm34m1, bli_trmm4m1, bli_trsm4m1 }, -/* 1m */ { bli_gemm1m, bli_hemm1m, bli_herk1m, bli_her2k1m, bli_symm1m, - bli_syrk1m, bli_syr2k1m, bli_trmm31m, bli_trmm1m, bli_trsm1m }, -/* nat */ { bli_gemmnat, bli_hemmnat, bli_herknat, bli_her2knat, bli_symmnat, - bli_syrknat, bli_syr2knat, bli_trmm3nat, bli_trmmnat, bli_trsmnat }, + /* gemm gemmt hemm herk her2k symm syrk syr2k trmm3 trmm trsm */ +/* 1m */ { TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE }, +/* nat */ { TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE } }; // // NOTE: "2" is used instead of BLIS_NUM_FP_TYPES/2. // -// BLIS provides APIs to modify this state during runtime. So, one application thread -// can modify the state, before another starts the corresponding BLIS operation. -// This is solved by making the induced method status array local to threads. +// BLIS provides APIs to modify this state during runtime. So, it's possible for one +// application thread to modify the state before another starts the corresponding +// BLIS operation. This is solved by making the induced method status array local to +// threads. static BLIS_THREAD_LOCAL -bool_t bli_l3_ind_oper_st[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS][2] = +bool bli_l3_ind_oper_st[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS][2] = { - /* gemm hemm herk her2k symm syrk, syr2k trmm3 trmm trsm */ + /* gemm gemmt hemm herk her2k symm + syrk syr2k trmm3 trmm trsm */ /* c z */ -/* 3mh */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, +/* 1m */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} }, -/* 3m1 */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, - {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} }, -/* 4mh */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, - {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} }, -/* 4mb */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, - {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} }, -/* 4m1 */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, - {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} }, -/* 1m */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, - {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} }, -/* nat */ { {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, +/* nat */ { {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE} }, }; @@ -87,24 +69,19 @@ bool_t bli_l3_ind_oper_st[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS][2] = #undef GENFUNC #define GENFUNC( opname, optype ) \ \ -void* PASTEMAC(opname,ind_get_avail)( num_t dt ) \ +ind_t PASTEMAC(opname,ind_find_avail)( num_t dt ) \ { \ - return bli_ind_oper_get_avail( optype, dt ); \ + return bli_l3_ind_oper_find_avail( optype, dt ); \ } -/* -bool_t PASTEMAC(opname,ind_has_avail)( num_t dt ) -{ - return bli_ind_oper_has_avail( optype, dt ); -} -*/ +//bool PASTEMAC(opname,ind_has_avail)( num_t dt ) +//{ +// return bli_ind_oper_has_avail( optype, dt ); +//} GENFUNC( gemm, BLIS_GEMM ) +GENFUNC( gemmt, BLIS_GEMMT ) GENFUNC( hemm, BLIS_HEMM ) -GENFUNC( herk, BLIS_HERK ) -GENFUNC( her2k, BLIS_HER2K ) GENFUNC( symm, BLIS_SYMM ) -GENFUNC( syrk, BLIS_SYRK ) -GENFUNC( syr2k, BLIS_SYR2K ) GENFUNC( trmm3, BLIS_TRMM3 ) GENFUNC( trmm, BLIS_TRMM ) GENFUNC( trsm, BLIS_TRSM ) @@ -112,18 +89,18 @@ GENFUNC( trsm, BLIS_TRSM ) // ----------------------------------------------------------------------------- #if 0 -bool_t bli_l3_ind_oper_is_avail( opid_t oper, ind_t method, num_t dt ) +bool bli_l3_ind_oper_is_avail( opid_t oper, ind_t method, num_t dt ) { - void* func; - bool_t stat; + bool enabled; + bool stat; // If the datatype is real, it is never available. if ( !bli_is_complex( dt ) ) return FALSE; - func = bli_l3_ind_oper_get_func( oper, method ); - stat = bli_l3_ind_oper_get_enable( oper, method, dt ); + enabled = bli_l3_ind_oper_is_impl( oper, method ); + stat = bli_l3_ind_oper_get_enable( oper, method, dt ); - return ( func != NULL && stat == TRUE ); + return ( enabled == TRUE && stat == TRUE ); } #endif @@ -146,11 +123,11 @@ ind_t bli_l3_ind_oper_find_avail( opid_t oper, num_t dt ) // current operation and datatype. for ( im = 0; im < BLIS_NUM_IND_METHODS; ++im ) { - void* func = bli_l3_ind_oper_get_func( oper, im ); - bool_t stat = bli_l3_ind_oper_get_enable( oper, im, dt ); + bool enabled = bli_l3_ind_oper_is_impl( oper, im ); + bool stat = bli_l3_ind_oper_get_enable( oper, im, dt ); - if ( func != NULL && - stat == TRUE ) return im; + if ( enabled == TRUE && + stat == TRUE ) return im; } // This return statement should never execute since the native index @@ -161,7 +138,7 @@ ind_t bli_l3_ind_oper_find_avail( opid_t oper, num_t dt ) // ----------------------------------------------------------------------------- -void bli_l3_ind_set_enable_dt( ind_t method, num_t dt, bool_t status ) +void bli_l3_ind_set_enable_dt( ind_t method, num_t dt, bool status ) { opid_t iop; @@ -197,7 +174,7 @@ void bli_l3_ind_oper_enable_only( opid_t oper, ind_t method, num_t dt ) } } -void bli_l3_ind_oper_set_enable_all( opid_t oper, num_t dt, bool_t status ) +void bli_l3_ind_oper_set_enable_all( opid_t oper, num_t dt, bool status ) { ind_t im; @@ -217,7 +194,7 @@ void bli_l3_ind_oper_set_enable_all( opid_t oper, num_t dt, bool_t status ) // A mutex to allow synchronous access to the bli_l3_ind_oper_st array. static bli_pthread_mutex_t oper_st_mutex = BLIS_PTHREAD_MUTEX_INITIALIZER; -void bli_l3_ind_oper_set_enable( opid_t oper, ind_t method, num_t dt, bool_t status ) +void bli_l3_ind_oper_set_enable( opid_t oper, ind_t method, num_t dt, bool status ) { num_t idt; @@ -242,10 +219,10 @@ void bli_l3_ind_oper_set_enable( opid_t oper, ind_t method, num_t dt, bool_t sta bli_pthread_mutex_unlock( &oper_st_mutex ); } -bool_t bli_l3_ind_oper_get_enable( opid_t oper, ind_t method, num_t dt ) +bool bli_l3_ind_oper_get_enable( opid_t oper, ind_t method, num_t dt ) { - num_t idt = bli_ind_map_cdt_to_index( dt ); - bool_t r_val; + num_t idt = bli_ind_map_cdt_to_index( dt ); + bool r_val; { r_val = bli_l3_ind_oper_st[ method ][ oper ][ idt ]; @@ -256,8 +233,7 @@ bool_t bli_l3_ind_oper_get_enable( opid_t oper, ind_t method, num_t dt ) // ----------------------------------------------------------------------------- -void* bli_l3_ind_oper_get_func( opid_t oper, ind_t method ) +bool bli_l3_ind_oper_is_impl( opid_t oper, ind_t method ) { - return bli_l3_ind_oper_fp[ method ][ oper ]; + return bli_l3_ind_oper_impl[ method ][ oper ]; } - diff --git a/frame/ind/bli_l3_ind.h b/frame/3/bli_l3_ind.h similarity index 73% rename from frame/ind/bli_l3_ind.h rename to frame/3/bli_l3_ind.h index 693d3a4c2c..a14ad783c9 100644 --- a/frame/ind/bli_l3_ind.h +++ b/frame/3/bli_l3_ind.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,35 +41,32 @@ #undef GENPROT #define GENPROT( opname ) \ \ -void* PASTEMAC(opname,ind_get_avail)( num_t dt ); -/*bool_t PASTEMAC(opname,ind_has_avail)( num_t dt ); */ +ind_t PASTEMAC(opname,ind_find_avail)( num_t dt ); +/*bool PASTEMAC(opname,ind_has_avail)( num_t dt ); */ GENPROT( gemm ) +GENPROT( gemmt ) GENPROT( hemm ) -GENPROT( herk ) -GENPROT( her2k ) GENPROT( symm ) -GENPROT( syrk ) -GENPROT( syr2k ) GENPROT( trmm3 ) GENPROT( trmm ) GENPROT( trsm ) // ----------------------------------------------------------------------------- -//bool_t bli_l3_ind_oper_is_avail( opid_t oper, ind_t method, num_t dt ); +//bool bli_l3_ind_oper_is_avail( opid_t oper, ind_t method, num_t dt ); -ind_t bli_l3_ind_oper_find_avail( opid_t oper, num_t dt ); +ind_t bli_l3_ind_oper_find_avail( opid_t oper, num_t dt ); -void bli_l3_ind_set_enable_dt( ind_t method, num_t dt, bool_t status ); +void bli_l3_ind_set_enable_dt( ind_t method, num_t dt, bool status ); -void bli_l3_ind_oper_enable_only( opid_t oper, ind_t method, num_t dt ); -void bli_l3_ind_oper_set_enable_all( opid_t oper, num_t dt, bool_t status ); +void bli_l3_ind_oper_enable_only( opid_t oper, ind_t method, num_t dt ); +void bli_l3_ind_oper_set_enable_all( opid_t oper, num_t dt, bool status ); -void bli_l3_ind_oper_set_enable( opid_t oper, ind_t method, num_t dt, bool_t status ); -bool_t bli_l3_ind_oper_get_enable( opid_t oper, ind_t method, num_t dt ); +void bli_l3_ind_oper_set_enable( opid_t oper, ind_t method, num_t dt, bool status ); +bool bli_l3_ind_oper_get_enable( opid_t oper, ind_t method, num_t dt ); -void* bli_l3_ind_oper_get_func( opid_t oper, ind_t method ); +bool bli_l3_ind_oper_is_impl( opid_t oper, ind_t method ); #endif diff --git a/frame/ind/ukernels/bli_l3_ind_ukr.h b/frame/3/bli_l3_ind_ukr.h similarity index 84% rename from frame/ind/ukernels/bli_l3_ind_ukr.h rename to frame/3/bli_l3_ind_ukr.h index 53cb0b6f88..6f24e71fcf 100644 --- a/frame/ind/ukernels/bli_l3_ind_ukr.h +++ b/frame/3/bli_l3_ind_ukr.h @@ -43,6 +43,8 @@ \ void PASTEMAC(ch,opname) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -53,11 +55,6 @@ void PASTEMAC(ch,opname) \ cntx_t* restrict cntx \ ); -INSERT_GENTPROT_BASIC0( gemm3mh_ukr_name ) -INSERT_GENTPROT_BASIC0( gemm3m1_ukr_name ) -INSERT_GENTPROT_BASIC0( gemm4mh_ukr_name ) -INSERT_GENTPROT_BASIC0( gemm4mb_ukr_name ) -INSERT_GENTPROT_BASIC0( gemm4m1_ukr_name ) INSERT_GENTPROT_BASIC0( gemm1m_ukr_name ) @@ -66,6 +63,8 @@ INSERT_GENTPROT_BASIC0( gemm1m_ukr_name ) \ void PASTEMAC(ch,opname) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a1x, \ @@ -77,10 +76,6 @@ void PASTEMAC(ch,opname) \ cntx_t* restrict cntx \ ); -INSERT_GENTPROT_BASIC0( gemmtrsm3m1_l_ukr_name ) -INSERT_GENTPROT_BASIC0( gemmtrsm3m1_u_ukr_name ) -INSERT_GENTPROT_BASIC0( gemmtrsm4m1_l_ukr_name ) -INSERT_GENTPROT_BASIC0( gemmtrsm4m1_u_ukr_name ) INSERT_GENTPROT_BASIC0( gemmtrsm1m_l_ukr_name ) INSERT_GENTPROT_BASIC0( gemmtrsm1m_u_ukr_name ) @@ -97,10 +92,6 @@ void PASTEMAC(ch,opname) \ cntx_t* restrict cntx \ ); -INSERT_GENTPROT_BASIC0( trsm3m1_l_ukr_name ) -INSERT_GENTPROT_BASIC0( trsm3m1_u_ukr_name ) -INSERT_GENTPROT_BASIC0( trsm4m1_l_ukr_name ) -INSERT_GENTPROT_BASIC0( trsm4m1_u_ukr_name ) INSERT_GENTPROT_BASIC0( trsm1m_l_ukr_name ) INSERT_GENTPROT_BASIC0( trsm1m_u_ukr_name ) diff --git a/frame/3/trsm/bli_trsm_int.c b/frame/3/bli_l3_int.c similarity index 74% rename from frame/3/trsm/bli_trsm_int.c rename to frame/3/bli_l3_int.c index dc39e69e0d..d4b974030c 100644 --- a/frame/3/trsm/bli_trsm_int.c +++ b/frame/3/bli_l3_int.c @@ -34,7 +34,7 @@ #include "blis.h" -void bli_trsm_int +void bli_l3_int ( obj_t* alpha, obj_t* a, @@ -47,10 +47,9 @@ void bli_trsm_int thrinfo_t* thread ) { - obj_t a_local; - obj_t b_local; - obj_t c_local; - trsm_var_oft f; + obj_t a_local; + obj_t b_local; + obj_t c_local; // Return early if the current control tree node is NULL. if ( bli_cntl_is_null( cntl ) ) return; @@ -60,72 +59,82 @@ void bli_trsm_int bli_gemm_basic_check( alpha, a, b, beta, c, cntx ); // If C has a zero dimension, return early. - if ( bli_obj_has_zero_dim( c ) ) return; + if ( bli_obj_has_zero_dim( c ) ) + { + return; + } // If A or B has a zero dimension, scale C by beta and return early. if ( bli_obj_has_zero_dim( a ) || bli_obj_has_zero_dim( b ) ) { if ( bli_thread_am_ochief( thread ) ) - bli_scalm( beta, c ); - bli_thread_obarrier( thread ); + bli_scalm( beta, c ); + bli_thread_barrier( thread ); return; } - // Alias A and B in case we need to update attached scalars. + // If A or B is marked as being filled with zeros, scale C by beta and + // return early. + if ( bli_obj_is_zeros( a ) || + bli_obj_is_zeros( b ) ) + { + // This should never execute. + bli_abort(); + + if ( bli_thread_am_ochief( thread ) ) + bli_scalm( beta, c ); + bli_thread_barrier( thread ); + return; + } + + // Alias A, B, and C in case we need to update attached scalars. bli_obj_alias_to( a, &a_local ); bli_obj_alias_to( b, &b_local ); - - // Alias C in case we need to induce a transposition. bli_obj_alias_to( c, &c_local ); + // Ensure that a valid packing function is set on A and B. + if ( !bli_obj_pack_fn( &a_local ) ) + bli_obj_set_pack_fn( bli_packm_blk_var1, &a_local ); + + if ( !bli_obj_pack_fn( &b_local ) ) + bli_obj_set_pack_fn( bli_packm_blk_var1, &b_local ); + // If we are about to call a leaf-level implementation, and matrix C // still needs a transposition, then we must induce one by swapping the // strides and dimensions. Note that this transposition would normally // be handled explicitly in the packing of C, but if C is not being // packed, this is our last chance to handle the transposition. - if ( bli_cntl_is_leaf( cntl ) && bli_obj_has_trans( c ) ) + //if ( bli_cntl_is_leaf( cntl ) && bli_obj_has_trans( c ) ) + if ( bli_obj_has_trans( c ) ) { bli_obj_induce_trans( &c_local ); bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &c_local ); } - // If beta is non-unit, apply it to the scalar attached to C. - if ( !bli_obj_equals( beta, &BLIS_ONE ) ) - { - bli_obj_scalar_apply_scalar( beta, &c_local ); - } - - // Set two bools: one based on the implied side parameter (the structure - // of the root object) and one based on the uplo field of the triangular - // matrix's root object (whether that is matrix A or matrix B). - if ( bli_obj_root_is_triangular( a ) ) + // If alpha is non-unit, typecast and apply it to the scalar attached + // to B, unless it happens to be triangular. + if ( bli_obj_root_is_triangular( b ) ) { - // If alpha is non-unit, typecast and apply it to the scalar - // attached to B (the non-triangular matrix). if ( !bli_obj_equals( alpha, &BLIS_ONE ) ) - { - bli_obj_scalar_apply_scalar( alpha, &b_local ); - } + bli_obj_scalar_apply_scalar( alpha, &a_local ); } else // if ( bli_obj_root_is_triangular( b ) ) { - // If alpha is non-unit, typecast and apply it to the scalar - // attached to A (the non-triangular matrix). if ( !bli_obj_equals( alpha, &BLIS_ONE ) ) - { - bli_obj_scalar_apply_scalar( alpha, &a_local ); - } + bli_obj_scalar_apply_scalar( alpha, &b_local ); } - // FGVZ->TMS: Is this barrier still needed? - bli_thread_obarrier( thread ); + // If beta is non-unit, typecast and apply it to the scalar attached + // to C. + if ( !bli_obj_equals( beta, &BLIS_ONE ) ) + bli_obj_scalar_apply_scalar( beta, &c_local ); // Create the next node in the thrinfo_t structure. bli_thrinfo_grow( rntm, cntl, thread ); // Extract the function pointer from the current control tree node. - f = bli_cntl_var_func( cntl ); + l3_var_oft f = bli_cntl_var_func( cntl ); // Invoke the variant. f diff --git a/frame/3/trsm/bli_trsm_int.h b/frame/3/bli_l3_int.h similarity index 99% rename from frame/3/trsm/bli_trsm_int.h rename to frame/3/bli_l3_int.h index aabb2a8aa6..d76b0ac3e2 100644 --- a/frame/3/trsm/bli_trsm_int.h +++ b/frame/3/bli_l3_int.h @@ -32,7 +32,7 @@ */ -void bli_trsm_int +void bli_l3_int ( obj_t* alpha, obj_t* a, diff --git a/frame/3/bli_l3_oapi.c b/frame/3/bli_l3_oapi.c index d9ba273699..1df8e80123 100644 --- a/frame/3/bli_l3_oapi.c +++ b/frame/3/bli_l3_oapi.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Texas at Austin Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,55 +32,31 @@ */ -// Guard the function definitions so that they are only compiled when -// #included from files that define the object API macros. -#ifdef BLIS_ENABLE_OAPI +#include "blis.h" // -// Define object-based interfaces. +// Define object-based interfaces (basic). // #undef GENFRONT #define GENFRONT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +void PASTEMAC0(opname) \ ( \ obj_t* alpha, \ obj_t* a, \ obj_t* b, \ obj_t* beta, \ obj_t* c \ - BLIS_OAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_OAPI_EX_DECLS \ -\ - /* Only proceed with an induced method if each of the operands have a - complex storage datatype. NOTE: Allowing precisions to vary while - using 1m, which is what we do here, is unique to gemm; other level-3 - operations use 1m only if all storage datatypes are equal (including - the computation datatype). If any operands are real, skip the induced - method chooser function and proceed directly with native execution. */ \ - if ( bli_obj_is_complex( c ) && \ - bli_obj_is_complex( a ) && \ - bli_obj_is_complex( b ) ) \ - { \ - /* Invoke the operation's "ind" function--its induced method front-end. - For complex problems, it calls the highest priority induced method - that is available (ie: implemented and enabled), and if none are - enabled, it calls native execution. (For real problems, it calls - the operation's native execution interface.) */ \ - PASTEMAC(opname,ind)( alpha, a, b, beta, c, cntx, rntm ); \ - } \ - else \ - { \ - PASTEMAC(opname,nat)( alpha, a, b, beta, c, cntx, rntm ); \ - } \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC(opname,_ex)( alpha, a, b, beta, c, NULL, NULL ); \ } GENFRONT( gemm ) +GENFRONT( gemmt ) GENFRONT( her2k ) GENFRONT( syr2k ) @@ -88,7 +64,7 @@ GENFRONT( syr2k ) #undef GENFRONT #define GENFRONT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +void PASTEMAC0(opname) \ ( \ side_t side, \ obj_t* alpha, \ @@ -96,32 +72,11 @@ void PASTEMAC(opname,EX_SUF) \ obj_t* b, \ obj_t* beta, \ obj_t* c \ - BLIS_OAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_OAPI_EX_DECLS \ -\ - /* Only proceed with an induced method if all operands have the same - (complex) datatype. If any datatypes differ, skip the induced method - chooser function and proceed directly with native execution, which is - where mixed datatype support will be implemented (if at all). */ \ - if ( bli_obj_dt( a ) == bli_obj_dt( c ) && \ - bli_obj_dt( b ) == bli_obj_dt( c ) && \ - bli_obj_is_complex( c ) ) \ - { \ - /* Invoke the operation's "ind" function--its induced method front-end. - For complex problems, it calls the highest priority induced method - that is available (ie: implemented and enabled), and if none are - enabled, it calls native execution. (For real problems, it calls - the operation's native execution interface.) */ \ - PASTEMAC(opname,ind)( side, alpha, a, b, beta, c, cntx, rntm ); \ - } \ - else \ - { \ - PASTEMAC(opname,nat)( side, alpha, a, b, beta, c, cntx, rntm ); \ - } \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC(opname,_ex)( side, alpha, a, b, beta, c, NULL, NULL ); \ } GENFRONT( hemm ) @@ -132,37 +87,17 @@ GENFRONT( trmm3 ) #undef GENFRONT #define GENFRONT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +void PASTEMAC0(opname) \ ( \ obj_t* alpha, \ obj_t* a, \ obj_t* beta, \ obj_t* c \ - BLIS_OAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_OAPI_EX_DECLS \ -\ - /* Only proceed with an induced method if all operands have the same - (complex) datatype. If any datatypes differ, skip the induced method - chooser function and proceed directly with native execution, which is - where mixed datatype support will be implemented (if at all). */ \ - if ( bli_obj_dt( a ) == bli_obj_dt( c ) && \ - bli_obj_is_complex( c ) ) \ - { \ - /* Invoke the operation's "ind" function--its induced method front-end. - For complex problems, it calls the highest priority induced method - that is available (ie: implemented and enabled), and if none are - enabled, it calls native execution. (For real problems, it calls - the operation's native execution interface.) */ \ - PASTEMAC(opname,ind)( alpha, a, beta, c, cntx, rntm ); \ - } \ - else \ - { \ - PASTEMAC(opname,nat)( alpha, a, beta, c, cntx, rntm ); \ - } \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC(opname,_ex)( alpha, a, beta, c, NULL, NULL ); \ } GENFRONT( herk ) @@ -172,42 +107,19 @@ GENFRONT( syrk ) #undef GENFRONT #define GENFRONT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +void PASTEMAC0(opname) \ ( \ side_t side, \ obj_t* alpha, \ obj_t* a, \ obj_t* b \ - BLIS_OAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_OAPI_EX_DECLS \ -\ - /* Only proceed with an induced method if all operands have the same - (complex) datatype. If any datatypes differ, skip the induced method - chooser function and proceed directly with native execution, which is - where mixed datatype support will be implemented (if at all). */ \ - if ( bli_obj_dt( a ) == bli_obj_dt( b ) && \ - bli_obj_is_complex( b ) ) \ - { \ - /* Invoke the operation's "ind" function--its induced method front-end. - For complex problems, it calls the highest priority induced method - that is available (ie: implemented and enabled), and if none are - enabled, it calls native execution. (For real problems, it calls - the operation's native execution interface.) */ \ - PASTEMAC(opname,ind)( side, alpha, a, b, cntx, rntm ); \ - } \ - else \ - { \ - PASTEMAC(opname,nat)( side, alpha, a, b, cntx, rntm ); \ - } \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC(opname,_ex)( side, alpha, a, b, NULL, NULL ); \ } GENFRONT( trmm ) GENFRONT( trsm ) - -#endif - diff --git a/frame/3/bli_l3_oapi.h b/frame/3/bli_l3_oapi.h index 4f9f20608e..e00f238add 100644 --- a/frame/3/bli_l3_oapi.h +++ b/frame/3/bli_l3_oapi.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,23 +35,23 @@ // -// Prototype object-based interfaces. +// Prototype object-based interfaces (basic). // #undef GENPROT #define GENPROT( opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* alpha, \ obj_t* a, \ obj_t* b, \ obj_t* beta, \ obj_t* c \ - BLIS_OAPI_EX_PARAMS \ ); GENPROT( gemm ) +GENPROT( gemmt ) GENPROT( her2k ) GENPROT( syr2k ) @@ -58,7 +59,7 @@ GENPROT( syr2k ) #undef GENPROT #define GENPROT( opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ side_t side, \ obj_t* alpha, \ @@ -66,7 +67,6 @@ BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ obj_t* b, \ obj_t* beta, \ obj_t* c \ - BLIS_OAPI_EX_PARAMS \ ); GENPROT( hemm ) @@ -77,13 +77,12 @@ GENPROT( trmm3 ) #undef GENPROT #define GENPROT( opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* alpha, \ obj_t* a, \ obj_t* beta, \ obj_t* c \ - BLIS_OAPI_EX_PARAMS \ ); GENPROT( herk ) @@ -93,13 +92,12 @@ GENPROT( syrk ) #undef GENPROT #define GENPROT( opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ side_t side, \ obj_t* alpha, \ obj_t* a, \ obj_t* b \ - BLIS_OAPI_EX_PARAMS \ ); GENPROT( trmm ) diff --git a/frame/3/bli_l3_oapi_ex.c b/frame/3/bli_l3_oapi_ex.c index 76f4fe16ab..cd0df7017c 100644 --- a/frame/3/bli_l3_oapi_ex.c +++ b/frame/3/bli_l3_oapi_ex.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Texas at Austin Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,13 +34,512 @@ #include "blis.h" -// Include cpp macros that instantiate the API definition templates as -// having expert parameters. -#include "bli_oapi_ex.h" +// +// Define object-based interfaces (expert). +// -// Define the macro protecting the object API definitions. -#define BLIS_ENABLE_OAPI +// If a sandbox was enabled, we forgo defining bli_gemm_ex() since it will be +// defined in the sandbox environment. +#ifndef BLIS_ENABLE_SANDBOX -// Include the object API definitions here. -#include "bli_l3_oapi.c" +void PASTEMAC(gemm,BLIS_OAPI_EX_SUF) + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + // If the rntm is non-NULL, it may indicate that we should forgo sup + // handling altogether. + bool enable_sup = TRUE; + if ( rntm != NULL ) enable_sup = bli_rntm_l3_sup( rntm ); + + if ( enable_sup ) + { + // Execute the small/unpacked oapi handler. If it finds that the problem + // does not fall within the thresholds that define "small", or for some + // other reason decides not to use the small/unpacked implementation, + // the function returns with BLIS_FAILURE, which causes execution to + // proceed towards the conventional implementation. + err_t result = bli_gemmsup( alpha, a, b, beta, c, cntx, rntm ); + if ( result == BLIS_SUCCESS ) + { + return; + } + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Default to using native execution. + num_t dt = bli_obj_dt( c ); + ind_t im = BLIS_NAT; + + // If each matrix operand has a complex storage datatype, try to get an + // induced method (if one is available and enabled). NOTE: Allowing + // precisions to vary while using 1m, which is what we do here, is unique + // to gemm; other level-3 operations use 1m only if all storage datatypes + // are equal (and they ignore the computation precision). + if ( bli_obj_is_complex( c ) && + bli_obj_is_complex( a ) && + bli_obj_is_complex( b ) ) + { + // Find the highest priority induced method that is both enabled and + // available for the current operation. (If an induced method is + // available but not enabled, or simply unavailable, BLIS_NAT will + // be returned here.) + im = bli_gemmind_find_avail( dt ); + } + + // If necessary, obtain a valid context from the gks using the induced + // method id determined above. + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_gemm_check( alpha, a, b, beta, c, cntx ); + + // Invoke the operation's front-end and request the default control tree. + bli_gemm_front( alpha, a, b, beta, c, cntx, rntm, NULL ); +} + +#endif + + +void PASTEMAC(gemmt,BLIS_OAPI_EX_SUF) + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Default to using native execution. + num_t dt = bli_obj_dt( c ); + ind_t im = BLIS_NAT; + + // If all matrix operands are complex and of the same storage datatype, try + // to get an induced method (if one is available and enabled). + if ( bli_obj_dt( a ) == bli_obj_dt( c ) && + bli_obj_dt( b ) == bli_obj_dt( c ) && + bli_obj_is_complex( c ) ) + { + // Find the highest priority induced method that is both enabled and + // available for the current operation. (If an induced method is + // available but not enabled, or simply unavailable, BLIS_NAT will + // be returned here.) + im = bli_gemmtind_find_avail( dt ); + } + + // If necessary, obtain a valid context from the gks using the induced + // method id determined above. + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_gemmt_check( alpha, a, b, beta, c, cntx ); + + // Invoke the operation's front-end and request the default control tree. + bli_gemmt_front( alpha, a, b, beta, c, cntx, rntm, NULL ); +} + + +void PASTEMAC(her2k,BLIS_OAPI_EX_SUF) + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + obj_t ah; + obj_t bh; + obj_t alphah; + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_her2k_check( alpha, a, b, beta, c, cntx ); + + bli_obj_alias_to( alpha, &alphah ); + bli_obj_toggle_conj( &alphah ); + + bli_obj_alias_to( a, &ah ); + bli_obj_toggle_trans( &ah ); + bli_obj_toggle_conj( &ah ); + + bli_obj_alias_to( b, &bh ); + bli_obj_toggle_trans( &bh ); + bli_obj_toggle_conj( &bh ); + + // Invoke gemmt twice, using beta only the first time. + PASTEMAC(gemmt,BLIS_OAPI_EX_SUF)( alpha, a, &bh, beta, c, cntx, rntm ); + PASTEMAC(gemmt,BLIS_OAPI_EX_SUF)( &alphah, b, &ah, &BLIS_ONE, c, cntx, rntm ); + + // The Hermitian rank-2k product was computed as alpha*A*B'+alpha'*B*A', even for + // the diagonal elements. Mathematically, the imaginary components of + // diagonal elements of a Hermitian rank-2k product should always be + // zero. However, in practice, they sometimes accumulate meaningless + // non-zero values. To prevent this, we explicitly set those values + // to zero before returning. + bli_setid( &BLIS_ZERO, c ); +} + + +void PASTEMAC(syr2k,BLIS_OAPI_EX_SUF) + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + obj_t at; + obj_t bt; + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_syr2k_check( alpha, a, b, beta, c, cntx ); + + bli_obj_alias_to( b, &bt ); + bli_obj_toggle_trans( &bt ); + + bli_obj_alias_to( a, &at ); + bli_obj_toggle_trans( &at ); + + // Invoke gemmt twice, using beta only the first time. + PASTEMAC(gemmt,BLIS_OAPI_EX_SUF)( alpha, a, &bt, beta, c, cntx, rntm ); + PASTEMAC(gemmt,BLIS_OAPI_EX_SUF)( alpha, b, &at, &BLIS_ONE, c, cntx, rntm ); +} + + +void PASTEMAC(hemm,BLIS_OAPI_EX_SUF) + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Default to using native execution. + num_t dt = bli_obj_dt( c ); + ind_t im = BLIS_NAT; + + // If all matrix operands are complex and of the same storage datatype, try + // to get an induced method (if one is available and enabled). + if ( bli_obj_dt( a ) == bli_obj_dt( c ) && + bli_obj_dt( b ) == bli_obj_dt( c ) && + bli_obj_is_complex( c ) ) + { + // Find the highest priority induced method that is both enabled and + // available for the current operation. (If an induced method is + // available but not enabled, or simply unavailable, BLIS_NAT will + // be returned here.) + im = bli_hemmind_find_avail( dt ); + } + + // If necessary, obtain a valid context from the gks using the induced + // method id determined above. + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_hemm_check( side, alpha, a, b, beta, c, cntx ); + + // Invoke the operation's front-end and request the default control tree. + bli_hemm_front( side, alpha, a, b, beta, c, cntx, rntm, NULL ); +} + + +void PASTEMAC(symm,BLIS_OAPI_EX_SUF) + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Default to using native execution. + num_t dt = bli_obj_dt( c ); + ind_t im = BLIS_NAT; + + // If all matrix operands are complex and of the same storage datatype, try + // to get an induced method (if one is available and enabled). + if ( bli_obj_dt( a ) == bli_obj_dt( c ) && + bli_obj_dt( b ) == bli_obj_dt( c ) && + bli_obj_is_complex( c ) ) + { + // Find the highest priority induced method that is both enabled and + // available for the current operation. (If an induced method is + // available but not enabled, or simply unavailable, BLIS_NAT will + // be returned here.) + im = bli_symmind_find_avail( dt ); + } + + // If necessary, obtain a valid context from the gks using the induced + // method id determined above. + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_symm_check( side, alpha, a, b, beta, c, cntx ); + + // Invoke the operation's front-end and request the default control tree. + bli_symm_front( side, alpha, a, b, beta, c, cntx, rntm, NULL ); +} + + +void PASTEMAC(trmm3,BLIS_OAPI_EX_SUF) + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Default to using native execution. + num_t dt = bli_obj_dt( c ); + ind_t im = BLIS_NAT; + + // If all matrix operands are complex and of the same storage datatype, try + // to get an induced method (if one is available and enabled). + if ( bli_obj_dt( a ) == bli_obj_dt( c ) && + bli_obj_dt( b ) == bli_obj_dt( c ) && + bli_obj_is_complex( c ) ) + { + // Find the highest priority induced method that is both enabled and + // available for the current operation. (If an induced method is + // available but not enabled, or simply unavailable, BLIS_NAT will + // be returned here.) + im = bli_trmm3ind_find_avail( dt ); + } + + // If necessary, obtain a valid context from the gks using the induced + // method id determined above. + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_trmm3_check( side, alpha, a, b, beta, c, cntx ); + + // Invoke the operation's front-end and request the default control tree. + bli_trmm3_front( side, alpha, a, b, beta, c, cntx, rntm, NULL ); +} + + +void PASTEMAC(herk,BLIS_OAPI_EX_SUF) + ( + obj_t* alpha, + obj_t* a, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + obj_t ah; + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_herk_check( alpha, a, beta, c, cntx ); + + bli_obj_alias_to( a, &ah ); + bli_obj_toggle_trans( &ah ); + bli_obj_toggle_conj( &ah ); + + PASTEMAC(gemmt,BLIS_OAPI_EX_SUF)( alpha, a, &ah, beta, c, cntx, rntm ); + + // The Hermitian rank-k product was computed as Re(alpha)*A*A', even for the + // diagonal elements. Mathematically, the imaginary components of + // diagonal elements of a Hermitian rank-k product should always be + // zero. However, in practice, they sometimes accumulate meaningless + // non-zero values. To prevent this, we explicitly set those values + // to zero before returning. + bli_setid( &BLIS_ZERO, c ); +} + + +void PASTEMAC(syrk,BLIS_OAPI_EX_SUF) + ( + obj_t* alpha, + obj_t* a, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + obj_t at; + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_syrk_check( alpha, a, beta, c, cntx ); + + bli_obj_alias_to( a, &at ); + bli_obj_toggle_trans( &at ); + + PASTEMAC(gemmt,BLIS_OAPI_EX_SUF)( alpha, a, &at, beta, c, cntx, rntm ); +} + + +void PASTEMAC(trmm,BLIS_OAPI_EX_SUF) + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Default to using native execution. + num_t dt = bli_obj_dt( b ); + ind_t im = BLIS_NAT; + + // If all matrix operands are complex and of the same storage datatype, try + // to get an induced method (if one is available and enabled). + if ( bli_obj_dt( a ) == bli_obj_dt( b ) && + bli_obj_is_complex( b ) ) + { + // Find the highest priority induced method that is both enabled and + // available for the current operation. (If an induced method is + // available but not enabled, or simply unavailable, BLIS_NAT will + // be returned here.) + im = bli_trmmind_find_avail( dt ); + } + + // If necessary, obtain a valid context from the gks using the induced + // method id determined above. + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_trmm_check( side, alpha, a, b, cntx ); + + // Invoke the operation's front-end and request the default control tree. + bli_trmm_front( side, alpha, a, b, cntx, rntm, NULL ); +} + + +void PASTEMAC(trsm,BLIS_OAPI_EX_SUF) + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Default to using native execution. + num_t dt = bli_obj_dt( b ); + ind_t im = BLIS_NAT; + + // If all matrix operands are complex and of the same storage datatype, try + // to get an induced method (if one is available and enabled). + if ( bli_obj_dt( a ) == bli_obj_dt( b ) && + bli_obj_is_complex( b ) ) + { + // Find the highest priority induced method that is both enabled and + // available for the current operation. (If an induced method is + // available but not enabled, or simply unavailable, BLIS_NAT will + // be returned here.) + im = bli_trsmind_find_avail( dt ); + } + + // If necessary, obtain a valid context from the gks using the induced + // method id determined above. + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_trsm_check( side, alpha, a, b, cntx ); + + // Invoke the operation's front-end and request the default control tree. + bli_trsm_front( side, alpha, a, b, cntx, rntm, NULL ); +} diff --git a/frame/3/bli_l3_oapi_ex.h b/frame/3/bli_l3_oapi_ex.h new file mode 100644 index 0000000000..946a7aa175 --- /dev/null +++ b/frame/3/bli_l3_oapi_ex.h @@ -0,0 +1,113 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype object-based interfaces (expert). +// + +#undef GENPROT +#define GENPROT( opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ); + +GENPROT( gemm ) +GENPROT( gemmt ) +GENPROT( her2k ) +GENPROT( syr2k ) + + +#undef GENPROT +#define GENPROT( opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + side_t side, \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ); + +GENPROT( hemm ) +GENPROT( symm ) +GENPROT( trmm3 ) + + +#undef GENPROT +#define GENPROT( opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ); + +GENPROT( herk ) +GENPROT( syrk ) + + +#undef GENPROT +#define GENPROT( opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + side_t side, \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ); + +GENPROT( trmm ) +GENPROT( trsm ) + diff --git a/frame/3/bli_l3_oft.h b/frame/3/bli_l3_oft.h index c182ed56c8..e7c8dcca31 100644 --- a/frame/3/bli_l3_oft.h +++ b/frame/3/bli_l3_oft.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -57,6 +58,7 @@ typedef void (*PASTECH(opname,_oft)) \ ); GENTDEF( gemm ) +GENTDEF( gemmt ) GENTDEF( her2k ) GENTDEF( syr2k ) diff --git a/frame/3/bli_l3_oft_var.h b/frame/3/bli_l3_oft_var.h index 1456f8eff3..ea10d80904 100644 --- a/frame/3/bli_l3_oft_var.h +++ b/frame/3/bli_l3_oft_var.h @@ -54,24 +54,7 @@ typedef void (*PASTECH(opname,_var_oft)) \ thrinfo_t* thread \ ); -GENTDEF( gemm ) - - -#undef GENTDEF -#define GENTDEF( opname ) \ -\ -typedef void (*PASTECH(opname,_var_oft)) \ -( \ - obj_t* a, \ - obj_t* b, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - cntl_t* cntl, \ - thrinfo_t* thread \ -); - -GENTDEF( trsm ) +GENTDEF( l3 ) diff --git a/frame/3/trsm/bli_trsm_packab.c b/frame/3/bli_l3_packab.c similarity index 80% rename from frame/3/trsm/bli_trsm_packab.c rename to frame/3/bli_l3_packab.c index 841230d80d..d911819429 100644 --- a/frame/3/trsm/bli_trsm_packab.c +++ b/frame/3/bli_l3_packab.c @@ -34,7 +34,7 @@ #include "blis.h" -void bli_trsm_packa +void bli_l3_packa ( obj_t* a, obj_t* b, @@ -45,12 +45,19 @@ void bli_trsm_packa thrinfo_t* thread ) { - obj_t a_pack; + obj_t a_local, a_pack; + + bli_obj_alias_to( a, &a_local ); + if ( bli_obj_has_trans( a ) ) + { + bli_obj_induce_trans( &a_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &a_local ); + } // Pack matrix A according to the control tree node. - bli_l3_packm + bli_packm_int ( - a, + &a_local, &a_pack, cntx, rntm, @@ -59,7 +66,7 @@ void bli_trsm_packa ); // Proceed with execution using packed matrix A. - bli_trsm_int + bli_l3_int ( &BLIS_ONE, &a_pack, @@ -75,7 +82,7 @@ void bli_trsm_packa // ----------------------------------------------------------------------------- -void bli_trsm_packb +void bli_l3_packb ( obj_t* a, obj_t* b, @@ -86,25 +93,39 @@ void bli_trsm_packb thrinfo_t* thread ) { - obj_t b_pack; + obj_t bt_local, bt_pack; + + // We always pass B^T to bli_l3_packm. + bli_obj_alias_to( b, &bt_local ); + if ( bli_obj_has_trans( b ) ) + { + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &bt_local ); + } + else + { + bli_obj_induce_trans( &bt_local ); + } // Pack matrix B according to the control tree node. - bli_l3_packm + bli_packm_int ( - b, - &b_pack, + &bt_local, + &bt_pack, cntx, rntm, cntl, thread ); + // Transpose packed object back to B. + bli_obj_induce_trans( &bt_pack ); + // Proceed with execution using packed matrix B. - bli_trsm_int + bli_l3_int ( &BLIS_ONE, a, - &b_pack, + &bt_pack, &BLIS_ONE, c, cntx, diff --git a/frame/3/bli_l3_packab.h b/frame/3/bli_l3_packab.h new file mode 100644 index 0000000000..380ca72123 --- /dev/null +++ b/frame/3/bli_l3_packab.h @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +void bli_l3_packa + ( + obj_t* a, + obj_t* b, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ); + +void bli_l3_packb + ( + obj_t* a, + obj_t* b, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ); + diff --git a/frame/3/bli_l3_packm.c b/frame/3/bli_l3_packm.c deleted file mode 100644 index bfb066bfb5..0000000000 --- a/frame/3/bli_l3_packm.c +++ /dev/null @@ -1,187 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -void bli_l3_packm - ( - obj_t* x, - obj_t* x_pack, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - packbuf_t pack_buf_type; - mem_t* cntl_mem_p; - siz_t size_needed; - - // FGVZ: Not sure why we need this barrier, but we do. - bli_thread_obarrier( thread ); - - // Every thread initializes x_pack and determines the size of memory - // block needed (which gets embedded into the otherwise "blank" mem_t - // entry in the control tree node). - size_needed - = - bli_packm_init - ( - x, - x_pack, - cntx, - cntl - ); - - // If zero was returned, no memory needs to be allocated and so we can - // return early. - if ( size_needed == 0 ) return; - - // Query the pack buffer type from the control tree node. - pack_buf_type = bli_cntl_packm_params_pack_buf_type( cntl ); - - // Query the address of the mem_t entry within the control tree node. - cntl_mem_p = bli_cntl_pack_mem( cntl ); - - // Check the mem_t field in the control tree. If it is unallocated, then - // we need to acquire a block from the memory broker and broadcast it to - // all threads in the chief's thread group. - if ( bli_mem_is_unalloc( cntl_mem_p ) ) - { - mem_t* local_mem_p; - mem_t local_mem_s; - - if ( bli_thread_am_ochief( thread ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_l3_packm(): acquiring mem pool block\n" ); - #endif - - // The chief thread acquires a block from the memory broker - // and saves the associated mem_t entry to local_mem_s. - bli_membrk_acquire_m - ( - rntm, - size_needed, - pack_buf_type, - &local_mem_s - ); - } - - // Broadcast the address of the chief thread's local mem_t entry to - // all threads. - local_mem_p = bli_thread_obroadcast( thread, &local_mem_s ); - - // Save the contents of the chief thread's local mem_t entry to the - // mem_t field in this thread's control tree node. - *cntl_mem_p = *local_mem_p; - } - else // ( bli_mem_is_alloc( cntl_mem_p ) ) - { - mem_t* local_mem_p; - mem_t local_mem_s; - - // If the mem_t entry in the control tree does NOT contain a NULL - // buffer, then a block has already been acquired from the memory - // broker and cached in the control tree. - - // As a sanity check, we should make sure that the mem_t object isn't - // associated with a block that is too small compared to the size of - // the packed matrix buffer that is needed, according to the return - // value from packm_init(). - siz_t cntl_mem_size = bli_mem_size( cntl_mem_p ); - - if ( cntl_mem_size < size_needed ) - { - if ( bli_thread_am_ochief( thread ) ) - { - // The chief thread releases the existing block associated with - // the mem_t entry in the control tree, and then re-acquires a - // new block, saving the associated mem_t entry to local_mem_s. - bli_membrk_release - ( - rntm, - cntl_mem_p - ); - bli_membrk_acquire_m - ( - rntm, - size_needed, - pack_buf_type, - &local_mem_s - ); - } - - // Broadcast the address of the chief thread's local mem_t entry to - // all threads. - local_mem_p = bli_thread_obroadcast( thread, &local_mem_s ); - - // Save the chief thread's local mem_t entry to the mem_t field in - // this thread's control tree node. - *cntl_mem_p = *local_mem_p; - } - else - { - // If the mem_t entry is already allocated and sufficiently large, - // then we use it as-is. No action is needed, because all threads - // will already have the cached values in their local control - // trees' mem_t entries, currently pointed to by cntl_mem_p. - - bli_thread_obarrier( thread ); - } - } - - - // Update the buffer address in x_pack to point to the buffer associated - // with the mem_t entry acquired from the memory broker (now cached in - // the control tree node). - void* buf = bli_mem_buffer( cntl_mem_p ); - bli_obj_set_buffer( buf, x_pack ); - - - // Pack the contents of object x to object x_pack. - bli_packm_int - ( - x, - x_pack, - cntx, - cntl, - thread - ); - - // Barrier so that packing is done before computation. - bli_thread_obarrier( thread ); -} - diff --git a/frame/3/bli_l3_prune.c b/frame/3/bli_l3_prune.c index fa008fd15e..6ca8244cbb 100644 --- a/frame/3/bli_l3_prune.c +++ b/frame/3/bli_l3_prune.c @@ -47,7 +47,7 @@ void bli_l3_prune_unref_mparts_m opid_t family = bli_cntl_family( cntl ); if ( family == BLIS_GEMM ) return; // No pruning is necessary for gemm. - else if ( family == BLIS_HERK ) bli_herk_prune_unref_mparts_m( a, b, c ); + else if ( family == BLIS_GEMMT ) bli_gemmt_prune_unref_mparts_m( a, b, c ); else if ( family == BLIS_TRMM ) bli_trmm_prune_unref_mparts_m( a, b, c ); else if ( family == BLIS_TRSM ) bli_trsm_prune_unref_mparts_m( a, b, c ); } @@ -68,7 +68,7 @@ void PASTEMAC(l3_prune_unref_mparts_,dim) \ opid_t family = bli_cntl_family( cntl ); \ \ if ( family == BLIS_GEMM ) return; /* No pruning is necessary for gemm. */ \ - else if ( family == BLIS_HERK ) PASTEMAC(herk_prune_unref_mparts_,dim)( a, b, c ); \ + else if ( family == BLIS_GEMMT ) PASTEMAC(gemmt_prune_unref_mparts_,dim)( a, b, c ); \ else if ( family == BLIS_TRMM ) PASTEMAC(trmm_prune_unref_mparts_,dim)( a, b, c ); \ else if ( family == BLIS_TRSM ) PASTEMAC(trsm_prune_unref_mparts_,dim)( a, b, c ); \ } @@ -152,7 +152,7 @@ void PASTEMAC(opname,_prune_unref_mparts_k) \ for the k dimension. */ \ } -GENFRONT( herk ) +GENFRONT( gemmt ) // ----------------------------------------------------------------------------- diff --git a/frame/3/bli_l3_prune.h b/frame/3/bli_l3_prune.h index 340ecd4dbf..ad8f07dc43 100644 --- a/frame/3/bli_l3_prune.h +++ b/frame/3/bli_l3_prune.h @@ -64,9 +64,9 @@ GENPROT( gemm, m ) GENPROT( gemm, n ) GENPROT( gemm, k ) -GENPROT( herk, m ) -GENPROT( herk, n ) -GENPROT( herk, k ) +GENPROT( gemmt, m ) +GENPROT( gemmt, n ) +GENPROT( gemmt, k ) GENPROT( trmm, m ) GENPROT( trmm, n ) diff --git a/frame/3/bli_l3_schema.c b/frame/3/bli_l3_schema.c new file mode 100644 index 0000000000..bde30c5277 --- /dev/null +++ b/frame/3/bli_l3_schema.c @@ -0,0 +1,80 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_l3_set_schemas + ( + obj_t* a, + obj_t* b, + obj_t* c, + cntx_t* cntx + ) +{ + // Begin with pack schemas for native execution. + pack_t schema_a = BLIS_PACKED_ROW_PANELS; + pack_t schema_b = BLIS_PACKED_COL_PANELS; + + // When executing the 1m method, choose the appropriate pack schemas based + // on the microkernel preference encoded within the current cntx_t (which + // was presumably returned by the gks). + if ( bli_cntx_method( cntx ) == BLIS_1M ) + { + num_t dt = bli_obj_domain( c ) | bli_obj_comp_prec( c ); + + // Note that bli_cntx_l3_vir_ukr_prefers_cols_dt() will use the real + // projection of dt to query the preference of the corresponding native + // real-domain microkernel. This is what ultimately determines which + // variant of 1m is applicable. + if ( bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ) ) + { + schema_a = BLIS_PACKED_ROW_PANELS_1E; + schema_b = BLIS_PACKED_COL_PANELS_1R; + } + else + { + schema_a = BLIS_PACKED_ROW_PANELS_1R; + schema_b = BLIS_PACKED_COL_PANELS_1E; + } + } + + // Embed the schemas into the objects for A and B. This is a sort of hack + // for communicating the desired pack schemas to bli_gemm_cntl_create() + // (via bli_l3_thread_decorator() and bli_l3_cntl_create_if()). This allows + // us to subsequently access the schemas from the control tree, which + // hopefully reduces some confusion, particularly in bli_packm_init(). + bli_obj_set_pack_schema( schema_a, a ); + bli_obj_set_pack_schema( schema_b, b ); +} + diff --git a/frame/3/bli_l3_schema.h b/frame/3/bli_l3_schema.h new file mode 100644 index 0000000000..c6a12ce520 --- /dev/null +++ b/frame/3/bli_l3_schema.h @@ -0,0 +1,41 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +void bli_l3_set_schemas + ( + obj_t* a, + obj_t* b, + obj_t* c, + cntx_t* cntx + ); diff --git a/frame/3/bli_l3_sup.c b/frame/3/bli_l3_sup.c new file mode 100644 index 0000000000..72ec405ab0 --- /dev/null +++ b/frame/3/bli_l3_sup.c @@ -0,0 +1,203 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +err_t bli_gemmsup + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // Return early if small matrix handling is disabled at configure-time. + #ifdef BLIS_DISABLE_SUP_HANDLING + return BLIS_FAILURE; + #endif + + // Return early if this is a mixed-datatype computation. + if ( bli_obj_dt( c ) != bli_obj_dt( a ) || + bli_obj_dt( c ) != bli_obj_dt( b ) || + bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) return BLIS_FAILURE; + + // Obtain a valid (native) context from the gks if necessary. + // NOTE: This must be done before calling the _check() function, since + // that function assumes the context pointer is valid. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Return early if a microkernel preference-induced transposition would + // have been performed and shifted the dimensions outside of the space + // of sup-handled problems. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( c, BLIS_GEMM_UKR, cntx ) ) + { + const num_t dt = bli_obj_dt( c ); + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width_after_trans( a ); + + // Pass in m and n reversed, which simulates a transposition of the + // entire operation pursuant to the microkernel storage preference. + if ( !bli_cntx_l3_sup_thresh_is_met( dt, n, m, k, cntx ) ) + return BLIS_FAILURE; + } + else // ukr_prefers_storage_of( c, ... ) + { + const num_t dt = bli_obj_dt( c ); + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width_after_trans( a ); + + if ( !bli_cntx_l3_sup_thresh_is_met( dt, m, n, k, cntx ) ) + return BLIS_FAILURE; + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + +#if 0 +const num_t dt = bli_obj_dt( c ); +const dim_t m = bli_obj_length( c ); +const dim_t n = bli_obj_width( c ); +const dim_t k = bli_obj_width_after_trans( a ); +const dim_t tm = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ); +const dim_t tn = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ); +const dim_t tk = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ); + +printf( "dims: %d %d %d (threshs: %d %d %d)\n", + (int)m, (int)n, (int)k, (int)tm, (int)tn, (int)tk ); +#endif + + // We've now ruled out the following two possibilities: + // - the ukernel prefers the operation as-is, and the sup thresholds are + // unsatisfied. + // - the ukernel prefers a transposed operation, and the sup thresholds are + // unsatisfied after taking into account the transposition. + // This implies that the sup thresholds (at least one of them) are met. + // and the small/unpacked handler should be called. + // NOTE: The sup handler is free to enforce a stricter threshold regime + // if it so chooses, in which case it can/should return BLIS_FAILURE. + + // Query the small/unpacked handler from the context and invoke it. + gemmsup_oft gemmsup_fp = bli_cntx_get_l3_sup_handler( BLIS_GEMM, cntx ); + + return + gemmsup_fp + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm + ); +} + + +err_t bli_gemmtsup + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // Return early if small matrix handling is disabled at configure-time. + #ifdef BLIS_DISABLE_SUP_HANDLING + return BLIS_FAILURE; + #endif + + // Return early if this is a mixed-datatype computation. + if ( bli_obj_dt( c ) != bli_obj_dt( a ) || + bli_obj_dt( c ) != bli_obj_dt( b ) || + bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) return BLIS_FAILURE; + + // Obtain a valid (native) context from the gks if necessary. + // NOTE: This must be done before calling the _check() function, since + // that function assumes the context pointer is valid. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Return early if the problem dimensions exceed their sup thresholds. + // Notice that we do not bother to check whether the microkernel + // prefers or dislikes the storage of C, since the same check is called + // for either way. + { + const num_t dt = bli_obj_dt( c ); + const dim_t m = bli_obj_length( c ); + const dim_t k = bli_obj_width_after_trans( a ); + + if ( !bli_cntx_l3_sup_thresh_is_met( dt, m, m, k, cntx ) ) + return BLIS_FAILURE; + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // We've now ruled out the possibility that the sup thresholds are + // unsatisfied. + // This implies that the sup thresholds (at least one of them) are met. + // and the small/unpacked handler should be called. + // NOTE: The sup handler is free to enforce a stricter threshold regime + // if it so chooses, in which case it can/should return BLIS_FAILURE. + + // Query the small/unpacked handler from the context and invoke it. + gemmtsup_oft gemmtsup_fp = bli_cntx_get_l3_sup_handler( BLIS_GEMMT, cntx ); + + return + gemmtsup_fp + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm + ); +} + + diff --git a/frame/3/bli_l3_sup.h b/frame/3/bli_l3_sup.h new file mode 100644 index 0000000000..fe6d0483e7 --- /dev/null +++ b/frame/3/bli_l3_sup.h @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +err_t bli_gemmsup + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +err_t bli_gemmtsup + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + diff --git a/frame/3/bli_l3_sup_ft_ker.h b/frame/3/bli_l3_sup_ft_ker.h new file mode 100644 index 0000000000..5bb2218f3b --- /dev/null +++ b/frame/3/bli_l3_sup_ft_ker.h @@ -0,0 +1,68 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_L3_SUP_FT_KER_H +#define BLIS_L3_SUP_FT_KER_H + + +// +// -- Level-3 small/unpacked kernel function types ----------------------------- +// + +// gemmsup + +#undef GENTDEF +#define GENTDEF( ctype, ch, opname, tsuf ) \ +\ +typedef void (*PASTECH3(ch,opname,_ker,tsuf)) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ); + +INSERT_GENTDEF( gemmsup ) + + +#endif + diff --git a/frame/3/bli_l3_sup_int.c b/frame/3/bli_l3_sup_int.c new file mode 100644 index 0000000000..e54e01d7c7 --- /dev/null +++ b/frame/3/bli_l3_sup_int.c @@ -0,0 +1,420 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +err_t bli_gemmsup_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ +#if 0 + //bli_gemmsup_ref_var2 + //bli_gemmsup_ref_var1 + #if 0 + bli_gemmsup_ref_var1n + #else + #endif + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR ); + if ( is_rrr_rrc_rcr_crr ) + { + bli_gemmsup_ref_var2m + ( + BLIS_NO_TRANSPOSE, alpha, a, b, beta, c, stor_id, cntx, rntm + ); + } + else + { + bli_gemmsup_ref_var2m + ( + BLIS_TRANSPOSE, alpha, a, b, beta, c, stor_id, cntx, rntm + ); + } + + return BLIS_SUCCESS; +#endif + + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + + const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR ); + const bool is_rcc_crc_ccr_ccc = !is_rrr_rrc_rcr_crr; + + const num_t dt = bli_obj_dt( c ); + const bool row_pref = bli_cntx_l3_sup_ker_prefers_rows_dt( dt, stor_id, cntx ); + + const bool is_primary = ( row_pref ? is_rrr_rrc_rcr_crr + : is_rcc_crc_ccr_ccc ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); + const bool auto_factor = bli_rntm_auto_factor( rntm ); + const dim_t n_threads = bli_rntm_num_threads( rntm ); + bool use_bp = TRUE; + dim_t jc_new; + dim_t ic_new; + + + if ( is_primary ) + { + // This branch handles: + // - rrr rrc rcr crr for row-preferential kernels + // - rcc crc ccr ccc for column-preferential kernels + + const dim_t mu = m / MR; + const dim_t nu = n / NR; + + // Decide which algorithm to use (block-panel var2m or panel-block + // var1n) based on the number of micropanels in the m and n dimensions. + // Also, recalculate the automatic thread factorization. + if ( mu >= nu ) use_bp = TRUE; + else /* if ( mu < nu ) */ use_bp = FALSE; + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. + if ( auto_factor ) + { + if ( use_bp ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + } + else // if ( !use_bp ) + { + // In the panel-block algorithm, the m dimension is parallelized + // with jc_nt and the n dimension is parallelized with ic_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &jc_new, &ic_new ); + } + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + + if ( use_bp ) + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m primary\n" ); + #endif + // block-panel macrokernel; m -> mc, mr; n -> nc, nr: var2() + bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else // use_pb + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var1n primary\n" ); + #endif + // panel-block macrokernel; m -> nc*,mr; n -> mc*,nr: var1() + bli_gemmsup_ref_var1n( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + // *requires nudging of nc up to be a multiple of mr. + } + } + else + { + // This branch handles: + // - rrr rrc rcr crr for column-preferential kernels + // - rcc crc ccr ccc for row-preferential kernels + + const dim_t mu = n / MR; // the n becomes m after a transposition + const dim_t nu = m / NR; // the m becomes n after a transposition + + // Decide which algorithm to use (block-panel var2m or panel-block + // var1n) based on the number of micropanels in the m and n dimensions. + // Also, recalculate the automatic thread factorization. + if ( mu >= nu ) use_bp = TRUE; + else /* if ( mu < nu ) */ use_bp = FALSE; + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. + if ( auto_factor ) + { + if ( use_bp ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + } + else // if ( !use_bp ) + { + // In the panel-block algorithm, the m dimension is parallelized + // with jc_nt and the n dimension is parallelized with ic_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &jc_new, &ic_new ); + } + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + + if ( use_bp ) + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m non-primary\n" ); + #endif + // panel-block macrokernel; m -> nc, nr; n -> mc, mr: var2() + trans + bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else // use_pb + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var1n non-primary\n" ); + #endif + // block-panel macrokernel; m -> mc*,nr; n -> nc*,mr: var1() + trans + bli_gemmsup_ref_var1n( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + // *requires nudging of mc up to be a multiple of nr. + } + } + + // Return success so that the caller knows that we computed the solution. + return BLIS_SUCCESS; +} + +// ----------------------------------------------------------------------------- + +err_t bli_gemmtsup_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + + const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR ); + const bool is_rcc_crc_ccr_ccc = !is_rrr_rrc_rcr_crr; + + const num_t dt = bli_obj_dt( c ); + const bool row_pref = bli_cntx_l3_sup_ker_prefers_rows_dt( dt, stor_id, cntx ); + + const bool is_primary = ( row_pref ? is_rrr_rrc_rcr_crr + : is_rcc_crc_ccr_ccc ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = m; + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); + const bool auto_factor = bli_rntm_auto_factor( rntm ); + const dim_t n_threads = bli_rntm_num_threads( rntm ); + bool use_bp = TRUE; + dim_t jc_new; + dim_t ic_new; + + + if ( is_primary ) + { + // This branch handles: + // - rrr rrc rcr crr for row-preferential kernels + // - rcc crc ccr ccc for column-preferential kernels + + const dim_t mu = m / MR; + const dim_t nu = n / NR; + + // Decide which algorithm to use (block-panel var2m or panel-block + // var1n) based on the number of micropanels in the m and n dimensions. + // Also, recalculate the automatic thread factorization. + if ( mu >= nu ) use_bp = TRUE; + else /* if ( mu < nu ) */ use_bp = FALSE; + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. + if ( auto_factor ) + { + if ( use_bp ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + } + else // if ( !use_bp ) + { + // In the panel-block algorithm, the m dimension is parallelized + // with jc_nt and the n dimension is parallelized with ic_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &jc_new, &ic_new ); + } + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + + if ( use_bp ) + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m primary\n" ); + #endif + // block-panel macrokernel; m -> mc, mr; n -> nc, nr: var2() +#if 0 + bli_gemmtsup_ref_var2m( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); +#endif + } + else // use_pb + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var1n primary\n" ); + #endif + // panel-block macrokernel; m -> nc*,mr; n -> mc*,nr: var1() +#if 0 + bli_gemmtsup_ref_var1n( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); +#endif + // *requires nudging of nc up to be a multiple of mr. + } + } + else + { + // This branch handles: + // - rrr rrc rcr crr for column-preferential kernels + // - rcc crc ccr ccc for row-preferential kernels + + const dim_t mu = n / MR; // the n becomes m after a transposition + const dim_t nu = m / NR; // the m becomes n after a transposition + + // Decide which algorithm to use (block-panel var2m or panel-block + // var1n) based on the number of micropanels in the m and n dimensions. + // Also, recalculate the automatic thread factorization. + if ( mu >= nu ) use_bp = TRUE; + else /* if ( mu < nu ) */ use_bp = FALSE; + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. + if ( auto_factor ) + { + if ( use_bp ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + } + else // if ( !use_bp ) + { + // In the panel-block algorithm, the m dimension is parallelized + // with jc_nt and the n dimension is parallelized with ic_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &jc_new, &ic_new ); + } + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + + if ( use_bp ) + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m non-primary\n" ); + #endif + // panel-block macrokernel; m -> nc, nr; n -> mc, mr: var2() + trans +#if 0 + bli_gemmtsup_ref_var2m( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); +#endif + } + else // use_pb + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var1n non-primary\n" ); + #endif + // block-panel macrokernel; m -> mc*,nr; n -> nc*,mr: var1() + trans +#if 0 + bli_gemmtsup_ref_var1n( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); +#endif + // *requires nudging of mc up to be a multiple of nr. + } + } + + // Return success so that the caller knows that we computed the solution. + return BLIS_SUCCESS; +} + diff --git a/frame/3/bli_l3_sup_int.h b/frame/3/bli_l3_sup_int.h new file mode 100644 index 0000000000..c6cb88056e --- /dev/null +++ b/frame/3/bli_l3_sup_int.h @@ -0,0 +1,57 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019 - 2000, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +err_t bli_gemmsup_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +err_t bli_gemmtsup_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); diff --git a/frame/3/bli_l3_sup_ker.h b/frame/3/bli_l3_sup_ker.h new file mode 100644 index 0000000000..6c77fffe04 --- /dev/null +++ b/frame/3/bli_l3_sup_ker.h @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// +// Define template prototypes for level-3 kernels on small/unpacked matrices. +// + +// Note: Instead of defining function prototype macro templates and then +// instantiating those macros to define the individual function prototypes, +// we simply alias the official operations' prototypes as defined in +// bli_l3_ker_prot.h. + +#undef GENTPROT +#define GENTPROT GEMMSUP_KER_PROT + +INSERT_GENTPROT_BASIC0( gemmsup_rv_ukr_name ) +INSERT_GENTPROT_BASIC0( gemmsup_rg_ukr_name ) +INSERT_GENTPROT_BASIC0( gemmsup_cv_ukr_name ) +INSERT_GENTPROT_BASIC0( gemmsup_cg_ukr_name ) + +INSERT_GENTPROT_BASIC0( gemmsup_rd_ukr_name ) +INSERT_GENTPROT_BASIC0( gemmsup_cd_ukr_name ) + +INSERT_GENTPROT_BASIC0( gemmsup_gx_ukr_name ) + diff --git a/frame/3/bli_l3_sup_ker_prot.h b/frame/3/bli_l3_sup_ker_prot.h new file mode 100644 index 0000000000..899a47d3fa --- /dev/null +++ b/frame/3/bli_l3_sup_ker_prot.h @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// +// Define template prototypes for level-3 kernels on small/unpacked matrices. +// + +#define GEMMSUP_KER_PROT( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ); + diff --git a/frame/3/bli_l3_sup_oft.h b/frame/3/bli_l3_sup_oft.h new file mode 100644 index 0000000000..98a06cf57e --- /dev/null +++ b/frame/3/bli_l3_sup_oft.h @@ -0,0 +1,62 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019-20, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_L3_SUP_OFT_H +#define BLIS_L3_SUP_OFT_H + + +// +// -- Level-3 small/unpacked object function types ----------------------------- +// + +// gemm + +#undef GENTDEF +#define GENTDEF( opname ) \ +\ +typedef err_t (*PASTECH(opname,_oft)) \ +( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ +); + +GENTDEF( gemmsup ) +GENTDEF( gemmtsup ) +#endif + diff --git a/frame/3/bli_l3_sup_packm_a.c b/frame/3/bli_l3_sup_packm_a.c new file mode 100644 index 0000000000..56726c5f8c --- /dev/null +++ b/frame/3/bli_l3_sup_packm_a.c @@ -0,0 +1,430 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool will_pack, \ + packbuf_t pack_buf_type, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Inspect whether we are going to be packing matrix A. */ \ + if ( will_pack == FALSE ) \ + { \ + } \ + else /* if ( will_pack == TRUE ) */ \ + { \ + /* NOTE: This "rounding up" of the last upanel is actually optional + for the rrc/crc cases, but absolutely necessary for the other cases + since we NEED that last micropanel to have the same ldim (cs_p) as + the other micropanels. Why? So that millikernels can use the same + upanel ldim for all iterations of the ir loop. */ \ + const dim_t m_pack = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + const dim_t k_pack = k; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin + the packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. */ \ + siz_t size_needed = sizeof( ctype ) * m_pack * k_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the memory broker. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was + passed in. It needs to be that mem_t struct, and not a + local (temporary) mem_t, since there is no barrier until + after packing is finished, which could allow a race + condition whereby the chief thread exits the current + function before the other threads have a chance to copy + from it. (A barrier would fix that race condition, but + then again, I prefer to keep barriers to a minimum.) */ \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the memory + broker and cached by the caller. */ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. */ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC0( packm_sup_init_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool did_pack, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Inspect whether we previously packed matrix A. */ \ + if ( did_pack == FALSE ) \ + { \ + /* If we didn't pack matrix A, there's nothing to be done. */ \ + } \ + else /* if ( did_pack == TRUE ) */ \ + { \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. */ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC0( packm_sup_finalize_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool will_pack, \ + stor3_t stor_id, \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype* x, inc_t rs_x, inc_t cs_x, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Inspect whether we are going to be packing matrix A. */ \ + if ( will_pack == FALSE ) \ + { \ + *m_max = m; \ + *k_max = k; \ +\ + /* Set the parameters for use with no packing of A (ie: using the + source matrix A directly). */ \ + { \ + /* Use the strides of the source matrix as the final values. */ \ + *rs_p = rs_x; \ + *cs_p = cs_x; \ +\ + *pd_p = mr; \ + *ps_p = mr * rs_x; \ +\ + /* Set the schema to "not packed" to indicate that packing will be + skipped. */ \ + *schema = BLIS_NOT_PACKED; \ + } \ +\ + /* Since we won't be packing, simply update the buffer address provided + by the caller to point to source matrix. */ \ + *p = x; \ + } \ + else /* if ( will_pack == TRUE ) */ \ + { \ + /* NOTE: This is "rounding up" of the last upanel is actually optional + for the rrc/crc cases, but absolutely necessary for the other cases + since we NEED that last micropanel to have the same ldim (cs_p) as + the other micropanels. Why? So that millikernels can use the same + upanel ldim for all iterations of the ir loop. */ \ + *m_max = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + *k_max = k; \ +\ + /* Determine the dimensions and strides for the packed matrix A. */ \ + if ( stor_id == BLIS_RRC || \ + stor_id == BLIS_CRC ) \ + { \ + /* stor3_t id values _RRC and _CRC: pack A to plain row storage. */ \ + *rs_p = k; \ + *cs_p = 1; \ +\ + *pd_p = mr; \ + *ps_p = mr * k; \ +\ + /* Set the schema to "row packed" to indicate packing to plain + row storage. */ \ + *schema = BLIS_PACKED_ROWS; \ + } \ + else \ + { \ + /* All other stor3_t ids: pack A to column-stored row-panels. */ \ + *rs_p = 1; \ + *cs_p = mr; \ +\ + *pd_p = mr; \ + *ps_p = mr * k; \ +\ + /* Set the schema to "packed row panels" to indicate packing to + conventional column-stored row panels. */ \ + *schema = BLIS_PACKED_ROW_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the + memory associated with the mem_t entry acquired from the memory + broker. */ \ + *p = bli_mem_buffer( mem ); \ + } \ +} + +INSERT_GENTFUNC_BASIC0( packm_sup_init_a ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool will_pack, \ + packbuf_t pack_buf_type, \ + stor3_t stor_id, \ + trans_t transc, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t m_max; \ + dim_t k_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. If packing is not requested, + this function will reduce to a no-op. */ \ + PASTEMAC(ch,packm_sup_init_mem_a) \ + ( \ + will_pack, \ + pack_buf_type, \ + m_alloc, k_alloc, mr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix A. If A + will not be packed, then a_use will be set to point to a and the _a_use + strides will be set accordingly. */ \ + PASTEMAC(ch,packm_sup_init_a) \ + ( \ + will_pack, \ + stor_id, \ + &schema, \ + m, k, mr, \ + &m_max, &k_max, \ + a, rs_a, cs_a, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + cntx, \ + mem, \ + thread \ + ); \ +\ + /* Inspect whether we are going to be packing matrix A. */ \ + if ( will_pack == FALSE ) \ + { \ + /* If we aren't going to pack matrix A, then there's nothing to do. */ \ +\ + /* + printf( "blis_ packm_sup_a: not packing A.\n" ); \ + */ \ + } \ + else /* if ( will_pack == TRUE ) */ \ + { \ + if ( schema == BLIS_PACKED_ROWS ) \ + { \ + /* + printf( "blis_ packm_sup_a: packing A to rows.\n" ); \ + */ \ +\ + /* For plain packing by rows, use var2. */ \ + PASTEMAC(ch,packm_sup_var2) \ + ( \ + transc, \ + schema, \ + m, \ + k, \ + kappa, \ + a, rs_a, cs_a, \ + *p, *rs_p, *cs_p, \ + cntx, \ + thread \ + ); \ + } \ + else /* if ( schema == BLIS_PACKED_ROW_PANELS ) */ \ + { \ + /* + printf( "blis_ packm_sup_a: packing A to row panels.\n" ); \ + */ \ +\ + /* For packing to column-stored row panels, use var1. */ \ + PASTEMAC(ch,packm_sup_var1) \ + ( \ + transc, \ + schema, \ + m, \ + k, \ + m_max, \ + k_max, \ + kappa, \ + a, rs_a, cs_a, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ + } \ +\ + /* Barrier so that packing is done before computation. */ \ + bli_thread_barrier( thread ); \ + } \ +} + +INSERT_GENTFUNC_BASIC0( packm_sup_a ) + diff --git a/frame/3/bli_l3_sup_packm_a.h b/frame/3/bli_l3_sup_packm_a.h new file mode 100644 index 0000000000..95c9582e79 --- /dev/null +++ b/frame/3/bli_l3_sup_packm_a.h @@ -0,0 +1,118 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool will_pack, \ + packbuf_t pack_buf_type, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +INSERT_GENTPROT_BASIC0( packm_sup_init_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool did_pack, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +INSERT_GENTPROT_BASIC0( packm_sup_finalize_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool will_pack, \ + stor3_t stor_id, \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +INSERT_GENTPROT_BASIC0( packm_sup_init_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool will_pack, \ + packbuf_t pack_buf_type, \ + stor3_t stor_id, \ + trans_t transc, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +INSERT_GENTPROT_BASIC0( packm_sup_a ) + diff --git a/frame/3/bli_l3_sup_packm_b.c b/frame/3/bli_l3_sup_packm_b.c new file mode 100644 index 0000000000..32c14afe3d --- /dev/null +++ b/frame/3/bli_l3_sup_packm_b.c @@ -0,0 +1,430 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool will_pack, \ + packbuf_t pack_buf_type, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Inspect whether we are going to be packing matrix B. */ \ + if ( will_pack == FALSE ) \ + { \ + } \ + else /* if ( will_pack == TRUE ) */ \ + { \ + /* NOTE: This "rounding up" of the last upanel is actually optional + for the rrc/crc cases, but absolutely necessary for the other cases + since we NEED that last micropanel to have the same ldim (cs_p) as + the other micropanels. Why? So that millikernels can use the same + upanel ldim for all iterations of the ir loop. */ \ + const dim_t k_pack = k; \ + const dim_t n_pack = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin + the packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. */ \ + siz_t size_needed = sizeof( ctype ) * k_pack * n_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the memory broker. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was + passed in. It needs to be that mem_t struct, and not a + local (temporary) mem_t, since there is no barrier until + after packing is finished, which could allow a race + condition whereby the chief thread exits the current + function before the other threads have a chance to copy + from it. (A barrier would fix that race condition, but + then again, I prefer to keep barriers to a minimum.) */ \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the memory + broker and cached by the caller. */ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. */ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC0( packm_sup_init_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool did_pack, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Inspect whether we previously packed matrix A. */ \ + if ( did_pack == FALSE ) \ + { \ + /* If we didn't pack matrix A, there's nothing to be done. */ \ + } \ + else /* if ( did_pack == TRUE ) */ \ + { \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. */ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC0( packm_sup_finalize_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool will_pack, \ + stor3_t stor_id, \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype* x, inc_t rs_x, inc_t cs_x, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Inspect whether we are going to be packing matrix B. */ \ + if ( will_pack == FALSE ) \ + { \ + *k_max = k; \ + *n_max = n; \ +\ + /* Set the parameters for use with no packing of B (ie: using the + source matrix B directly). */ \ + { \ + /* Use the strides of the source matrix as the final values. */ \ + *rs_p = rs_x; \ + *cs_p = cs_x; \ +\ + *pd_p = nr; \ + *ps_p = nr * cs_x; \ +\ + /* Set the schema to "not packed" to indicate that packing will be + skipped. */ \ + *schema = BLIS_NOT_PACKED; \ + } \ +\ + /* Since we won't be packing, simply update the buffer address provided + by the caller to point to source matrix. */ \ + *p = x; \ + } \ + else /* if ( will_pack == TRUE ) */ \ + { \ + /* NOTE: This is "rounding up" of the last upanel is actually optional + for the rrc/crc cases, but absolutely necessary for the other cases + since we NEED that last micropanel to have the same ldim (cs_p) as + the other micropanels. Why? So that millikernels can use the same + upanel ldim for all iterations of the ir loop. */ \ + *k_max = k; \ + *n_max = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Determine the dimensions and strides for the packed matrix B. */ \ + if ( stor_id == BLIS_RRC || \ + stor_id == BLIS_CRC ) \ + { \ + /* stor3_t id values _RRC and _CRC: pack B to plain row storage. */ \ + *rs_p = 1; \ + *cs_p = k; \ +\ + *pd_p = nr; \ + *ps_p = k * nr; \ +\ + /* Set the schema to "column packed" to indicate packing to plain + column storage. */ \ + *schema = BLIS_PACKED_COLUMNS; \ + } \ + else \ + { \ + /* All other stor3_t ids: pack B to row-stored column-panels. */ \ + *rs_p = nr; \ + *cs_p = 1; \ +\ + *pd_p = nr; \ + *ps_p = k * nr; \ +\ + /* Set the schema to "packed column panels" to indicate packing to + conventional row-stored column panels. */ \ + *schema = BLIS_PACKED_COL_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the + memory associated with the mem_t entry acquired from the memory + broker. */ \ + *p = bli_mem_buffer( mem ); \ + } \ +} + +INSERT_GENTFUNC_BASIC0( packm_sup_init_b ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool will_pack, \ + packbuf_t pack_buf_type, \ + stor3_t stor_id, \ + trans_t transc, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t k_max; \ + dim_t n_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. If packing is not requested, + this function will reduce to a no-op. */ \ + PASTEMAC(ch,packm_sup_init_mem_b) \ + ( \ + will_pack, \ + pack_buf_type, \ + k_alloc, n_alloc, nr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix B. If B + will not be packed, then b_use will be set to point to b and the _b_use + strides will be set accordingly. */ \ + PASTEMAC(ch,packm_sup_init_b) \ + ( \ + will_pack, \ + stor_id, \ + &schema, \ + k, n, nr, \ + &k_max, &n_max, \ + b, rs_b, cs_b, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + cntx, \ + mem, \ + thread \ + ); \ +\ + /* Inspect whether we are going to be packing matrix B. */ \ + if ( will_pack == FALSE ) \ + { \ + /* If we aren't going to pack matrix B, then there's nothing to do. */ \ +\ + /* + printf( "blis_ packm_sup_b: not packing B.\n" ); \ + */ \ + } \ + else /* if ( will_pack == TRUE ) */ \ + { \ + if ( schema == BLIS_PACKED_COLUMNS ) \ + { \ + /* + printf( "blis_ packm_sup_b: packing B to columns.\n" ); \ + */ \ +\ + /* For plain packing by columns, use var2. */ \ + PASTEMAC(ch,packm_sup_var2) \ + ( \ + transc, \ + schema, \ + k, \ + n, \ + kappa, \ + b, rs_b, cs_b, \ + *p, *rs_p, *cs_p, \ + cntx, \ + thread \ + ); \ + } \ + else /* if ( schema == BLIS_PACKED_COL_PANELS ) */ \ + { \ + /* + printf( "blis_ packm_sup_b: packing B to col panels.\n" ); \ + */ \ +\ + /* For packing to row-stored column panels, use var1. */ \ + PASTEMAC(ch,packm_sup_var1) \ + ( \ + transc, \ + schema, \ + k, \ + n, \ + k_max, \ + n_max, \ + kappa, \ + b, rs_b, cs_b, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ + } \ +\ + /* Barrier so that packing is done before computation. */ \ + bli_thread_barrier( thread ); \ + } \ +} + +INSERT_GENTFUNC_BASIC0( packm_sup_b ) + diff --git a/frame/3/bli_l3_sup_packm_b.h b/frame/3/bli_l3_sup_packm_b.h new file mode 100644 index 0000000000..2965727d54 --- /dev/null +++ b/frame/3/bli_l3_sup_packm_b.h @@ -0,0 +1,118 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool will_pack, \ + packbuf_t pack_buf_type, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +INSERT_GENTPROT_BASIC0( packm_sup_init_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool did_pack, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +INSERT_GENTPROT_BASIC0( packm_sup_finalize_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool will_pack, \ + stor3_t stor_id, \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +INSERT_GENTPROT_BASIC0( packm_sup_init_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + bool will_pack, \ + packbuf_t pack_buf_type, \ + stor3_t stor_id, \ + trans_t transc, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +INSERT_GENTPROT_BASIC0( packm_sup_b ) + diff --git a/frame/3/bli_l3_sup_packm_var.c b/frame/3/bli_l3_sup_packm_var.c new file mode 100644 index 0000000000..85fb246f01 --- /dev/null +++ b/frame/3/bli_l3_sup_packm_var.c @@ -0,0 +1,440 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Define BLAS-like interfaces to the variants. +// + +#undef GENTFUNCR +#define GENTFUNCR( ctype, ctype_r, ch, chr, opname, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len_full; \ + dim_t panel_len_i; \ + dim_t panel_len_max; \ + dim_t panel_len_max_i; \ + dim_t panel_dim_i; \ + dim_t panel_dim_max; \ + inc_t vs_c; \ + inc_t ldc; \ + inc_t ldp, p_inc; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* If c needs a transposition, induce it so that we can more simply + express the remaining parameters and code. */ \ + if ( bli_does_trans( transc ) ) \ + { \ + bli_swap_incs( &rs_c, &cs_c ); \ + bli_toggle_trans( &transc ); \ + } \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len_full = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + vs_c = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. */ \ + iter_dim = m; \ + panel_len_full = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + vs_c = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + ctype* restrict p_begin = p_cast; \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t it_start, it_end, it_inc; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim_i = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_begin = c_cast + (ic )*vs_c; \ +\ + ctype* restrict c_use = c_begin; \ + ctype* restrict p_use = p_begin; \ +\ + { \ + panel_len_i = panel_len_full; \ + panel_len_max_i = panel_len_max; \ +\ + /* The definition of bli_packm_my_iter() will depend on whether slab + or round-robin partitioning was requested at configure-time. */ \ + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + { \ + PASTEMAC(ch,packm_cxk) \ + ( \ + conjc, \ + schema, \ + panel_dim_i, \ + panel_dim_max, \ + panel_len_i, \ + panel_len_max_i, \ + kappa_cast, \ + c_use, vs_c, ldc, \ + p_use, ldp, \ + cntx \ + ); \ + } \ +\ + /* NOTE: This value is equivalent to ps_p. */ \ + p_inc = ps_p; \ + } \ +\ + p_begin += p_inc; \ +\ +/* +if ( row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_sup_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_sup_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ + } \ +\ +} + +INSERT_GENTFUNCR_BASIC( packm, packm_sup_var1 ) + + + +/* +if ( row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var2: b", m, n, \ + c_cast, rs_c, cs_c, "%4.1f", "" ); \ +if ( col_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var2: a", m, n, \ + c_cast, rs_c, cs_c, "%4.1f", "" ); \ +*/ +/* +if ( row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: b packed", *m_panel_max, *n_panel_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: a packed", *m_panel_max, *n_panel_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ +\ +/* +if ( col_stored ) { \ + if ( bli_thread_work_id( thread ) == 0 ) \ + { \ + printf( "packm_blk_var1: thread %lu (a = %p, ap = %p)\n", bli_thread_work_id( thread ), c_use, p_use ); \ + fflush( stdout ); \ + PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: a", *m_panel_use, *n_panel_use, \ + ( ctype* )c_use, rs_c, cs_c, "%4.1f", "" ); \ + PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: ap", *m_panel_max, *n_panel_max, \ + ( ctype* )p_use, rs_p, cs_p, "%4.1f", "" ); \ + fflush( stdout ); \ + } \ +bli_thread_barrier( thread ); \ + if ( bli_thread_work_id( thread ) == 1 ) \ + { \ + printf( "packm_blk_var1: thread %lu (a = %p, ap = %p)\n", bli_thread_work_id( thread ), c_use, p_use ); \ + fflush( stdout ); \ + PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: a", *m_panel_use, *n_panel_use, \ + ( ctype* )c_use, rs_c, cs_c, "%4.1f", "" ); \ + PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: ap", *m_panel_max, *n_panel_max, \ + ( ctype* )p_use, rs_p, cs_p, "%4.1f", "" ); \ + fflush( stdout ); \ + } \ +bli_thread_barrier( thread ); \ +} \ +else { \ + if ( bli_thread_work_id( thread ) == 0 ) \ + { \ + printf( "packm_blk_var1: thread %lu (b = %p, bp = %p)\n", bli_thread_work_id( thread ), c_use, p_use ); \ + fflush( stdout ); \ + PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: b", *m_panel_use, *n_panel_use, \ + ( ctype* )c_use, rs_c, cs_c, "%4.1f", "" ); \ + PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: bp", *m_panel_max, *n_panel_max, \ + ( ctype* )p_use, rs_p, cs_p, "%4.1f", "" ); \ + fflush( stdout ); \ + } \ +bli_thread_barrier( thread ); \ + if ( bli_thread_work_id( thread ) == 1 ) \ + { \ + printf( "packm_blk_var1: thread %lu (b = %p, bp = %p)\n", bli_thread_work_id( thread ), c_use, p_use ); \ + fflush( stdout ); \ + PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: b", *m_panel_use, *n_panel_use, \ + ( ctype* )c_use, rs_c, cs_c, "%4.1f", "" ); \ + PASTEMAC(ch,fprintm)( stdout, "packm_blk_var1: bp", *m_panel_max, *n_panel_max, \ + ( ctype* )p_use, rs_p, cs_p, "%4.1f", "" ); \ + fflush( stdout ); \ + } \ +bli_thread_barrier( thread ); \ +} \ +*/ +/* + PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_rpi", *m_panel_max, *n_panel_max, \ + ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ +*/ +/* + if ( row_stored ) { \ + PASTEMAC(chr,fprintm)( stdout, "packm_var2: b_r", *m_panel_max, *n_panel_max, \ + ( ctype_r* )c_use, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ + PASTEMAC(chr,fprintm)( stdout, "packm_var2: b_i", *m_panel_max, *n_panel_max, \ + (( ctype_r* )c_use)+rs_c, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ + PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_r", *m_panel_max, *n_panel_max, \ + ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ + inc_t is_b = rs_p * *m_panel_max; \ + PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_i", *m_panel_max, *n_panel_max, \ + ( ctype_r* )p_use + is_b, rs_p, cs_p, "%4.1f", "" ); \ + } \ +*/ +/* + if ( col_stored ) { \ + PASTEMAC(chr,fprintm)( stdout, "packm_var2: a_r", *m_panel_max, *n_panel_max, \ + ( ctype_r* )c_use, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ + PASTEMAC(chr,fprintm)( stdout, "packm_var2: a_i", *m_panel_max, *n_panel_max, \ + (( ctype_r* )c_use)+rs_c, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ + PASTEMAC(chr,fprintm)( stdout, "packm_var2: ap_r", *m_panel_max, *n_panel_max, \ + ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ + PASTEMAC(chr,fprintm)( stdout, "packm_var2: ap_i", *m_panel_max, *n_panel_max, \ + ( ctype_r* )p_use + p_inc, rs_p, cs_p, "%4.1f", "" ); \ + } \ +*/ + +#undef GENTFUNCR +#define GENTFUNCR( ctype, ctype_r, ch, chr, opname, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it; \ + dim_t vector_len; \ + inc_t incc, ldc; \ + inc_t incp, ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* If c needs a transposition, induce it so that we can more simply + express the remaining parameters and code. */ \ + if ( bli_does_trans( transc ) ) \ + { \ + bli_swap_incs( &rs_c, &cs_c ); \ + bli_toggle_trans( &transc ); \ + } \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool col_stored = bli_is_col_packed( schema ); \ + /*bool row_stored = bli_is_row_packed( schema );*/ \ +\ + if ( col_stored ) \ + { \ + /* Prepare to pack to a column-stored matrix. */ \ + iter_dim = n; \ + vector_len = m; \ + incc = rs_c; \ + ldc = cs_c; \ + incp = 1; \ + ldp = cs_p; \ + } \ + else /* if ( row_stored ) */ \ + { \ + /* Prepare to pack to a row-stored matrix. */ \ + iter_dim = m; \ + vector_len = n; \ + incc = cs_c; \ + ldc = rs_c; \ + incp = 1; \ + ldp = rs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim; \ +\ +\ + ctype* restrict p_begin = p_cast; \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t it_start, it_end, it_inc; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( it = 0; it < n_iter; it += 1 ) \ + { \ + ctype* restrict c_begin = c_cast + (it )*ldc; \ +\ + ctype* restrict c_use = c_begin; \ + ctype* restrict p_use = p_begin; \ +\ + { \ + /* The definition of bli_packm_my_iter() will depend on whether slab + or round-robin partitioning was requested at configure-time. */ \ + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + { \ + PASTEMAC2(ch,scal2v,BLIS_TAPI_EX_SUF) \ + ( \ + conjc, \ + vector_len, \ + kappa_cast, \ + c_use, incc, \ + p_use, incp, \ + cntx, \ + NULL \ + ); \ + } \ +\ + } \ +\ + p_begin += ldp; \ +\ +/* +if ( row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_sup_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_sup_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ + } \ +} + +INSERT_GENTFUNCR_BASIC( packm, packm_sup_var2 ) + diff --git a/frame/3/bli_l3_sup_packm_var.h b/frame/3/bli_l3_sup_packm_var.h new file mode 100644 index 0000000000..5ccdd3b762 --- /dev/null +++ b/frame/3/bli_l3_sup_packm_var.h @@ -0,0 +1,78 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// +// Prototype BLAS-like interfaces to the variants. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ); + +INSERT_GENTPROT_BASIC0( packm_sup_var1 ) + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ); + +INSERT_GENTPROT_BASIC0( packm_sup_var2 ) + diff --git a/frame/3/bli_l3_sup_ref.c b/frame/3/bli_l3_sup_ref.c new file mode 100644 index 0000000000..f03ec1b18f --- /dev/null +++ b/frame/3/bli_l3_sup_ref.c @@ -0,0 +1,188 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +err_t bli_gemmsup_ref + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // This function implements the default gemmsup handler. If you are a + // BLIS developer and wish to use a different gemmsup handler, please + // register a different function pointer in the context in your + // sub-configuration's bli_cntx_init_*() function. + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_gemm_check( alpha, a, b, beta, c, cntx ); + +#if 0 + // NOTE: This special case handling is done within the variants. + + // If alpha is zero, scale by beta and return. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) + { + bli_scalm( beta, c ); + return; + } + + // If A or B has a zero dimension, scale C by beta and return early. + if ( bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) + { + bli_scalm( beta, c ); + return BLIS_SUCCESS; + } +#endif + + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + + // Don't use the small/unpacked implementation if one of the matrices + // uses general stride. NOTE: We check for this here, in bli_gemmsup_ref() + // (and not in the calling function, bli_gemmsup()), because we consider + // this way of handling general stride to be part of the implementation + // and not necessarily a general-purpose solution that would apply to all + // possible gemmsup handlers. Similarly, we check for it here (and not in + // the internal thread entry point, bli_gemmsup_int()) because we don't + // want to have to manage the multiple return values from the threads, + // which we would have to process into a single return value and then + // return from the parallel/threaded region. + if ( stor_id == BLIS_XXX ) return BLIS_FAILURE; + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop. + bli_rntm_set_ways_from_rntm_sup + ( + bli_obj_length( c ), + bli_obj_width( c ), + bli_obj_width( a ), + rntm + ); + +#if 0 + printf( "rntm.pack_a = %d\n", ( int )bli_rntm_pack_a( rntm ) ); + printf( "rntm.pack_b = %d\n", ( int )bli_rntm_pack_b( rntm ) ); + + //bli_rntm_set_pack_a( 0, rntm ); + //bli_rntm_set_pack_b( 0, rntm ); +#endif + + return + bli_l3_sup_thread_decorator + ( + bli_gemmsup_int, + BLIS_GEMM, // operation family id + alpha, + a, + b, + beta, + c, + cntx, + rntm + ); +} + +// ----------------------------------------------------------------------------- + +err_t bli_gemmtsup_ref + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // This function implements the default gemmtsup handler. If you are a + // BLIS developer and wish to use a different gemmtsup handler, please + // register a different function pointer in the context in your + // sub-configuration's bli_cntx_init_*() function. + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_gemmt_check( alpha, a, b, beta, c, cntx ); + +#if 0 + // NOTE: This special case handling is done within the variants. + + // If alpha is zero, scale by beta and return. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) + { + bli_scalm( beta, c ); + return; + } + + // If A or B has a zero dimension, scale C by beta and return early. + if ( bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) + { + bli_scalm( beta, c ); + return BLIS_SUCCESS; + } +#endif + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop. + bli_rntm_set_ways_from_rntm_sup + ( + bli_obj_length( c ), + bli_obj_width( c ), + bli_obj_width( a ), + rntm + ); + + return + bli_l3_sup_thread_decorator + ( + bli_gemmtsup_int, + BLIS_GEMMT, // operation family id + alpha, + a, + b, + beta, + c, + cntx, + rntm + ); +} + diff --git a/frame/3/bli_l3_sup_ref.h b/frame/3/bli_l3_sup_ref.h new file mode 100644 index 0000000000..bce4e1729e --- /dev/null +++ b/frame/3/bli_l3_sup_ref.h @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019 - 2000, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +err_t bli_gemmsup_ref + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +err_t bli_gemmtsup_ref + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + diff --git a/frame/3/bli_l3_sup_var12.c b/frame/3/bli_l3_sup_var12.c new file mode 100644 index 0000000000..106ad86e4d --- /dev/null +++ b/frame/3/bli_l3_sup_var12.c @@ -0,0 +1,735 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmsup_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + stor3_t eff_id, + cntx_t* restrict cntx, + rntm_t* restrict rntm + ); + +#if 0 +// +// -- var2 --------------------------------------------------------------------- +// + +static FUNCPTR_T GENARRAY(ftypes_var2,gemmsup_ref_var2); + +void bli_gemmsup_ref_var2 + ( + trans_t trans, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + stor3_t eff_id, + cntx_t* cntx, + rntm_t* rntm + ) +{ +#if 0 + obj_t at, bt; + + bli_obj_alias_to( a, &at ); + bli_obj_alias_to( b, &bt ); + + // Induce transpositions on A and/or B if either object is marked for + // transposition. We can induce "fast" transpositions since they objects + // are guaranteed to not have structure or be packed. + if ( bli_obj_has_trans( &at ) ) { bli_obj_induce_fast_trans( &at ); } + if ( bli_obj_has_trans( &bt ) ) { bli_obj_induce_fast_trans( &bt ); } + + const num_t dt_exec = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + + const dim_t k = bli_obj_width( &at ); + + void* restrict buf_a = bli_obj_buffer_at_off( &at ); + const inc_t rs_a = bli_obj_row_stride( &at ); + const inc_t cs_a = bli_obj_col_stride( &at ); + + void* restrict buf_b = bli_obj_buffer_at_off( &bt ); + const inc_t rs_b = bli_obj_row_stride( &bt ); + const inc_t cs_b = bli_obj_col_stride( &bt ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt_exec, beta ); + +#else + + const num_t dt_exec = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + dim_t k; + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + inc_t rs_a; + inc_t cs_a; + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + inc_t rs_b; + inc_t cs_b; + + if ( bli_obj_has_notrans( a ) ) + { + k = bli_obj_width( a ); + + rs_a = bli_obj_row_stride( a ); + cs_a = bli_obj_col_stride( a ); + } + else // if ( bli_obj_has_trans( a ) ) + { + // Assign the variables with an implicit transposition. + k = bli_obj_length( a ); + + rs_a = bli_obj_col_stride( a ); + cs_a = bli_obj_row_stride( a ); + } + + if ( bli_obj_has_notrans( b ) ) + { + rs_b = bli_obj_row_stride( b ); + cs_b = bli_obj_col_stride( b ); + } + else // if ( bli_obj_has_trans( b ) ) + { + // Assign the variables with an implicit transposition. + rs_b = bli_obj_col_stride( b ); + cs_b = bli_obj_row_stride( b ); + } + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt_exec, beta ); + +#endif + + // Index into the type combination array to extract the correct + // function pointer. + FUNCPTR_T f = ftypes_var2[dt_exec]; + + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + eff_id, + cntx, + rntm + ); +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t eff_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm \ + ) \ +{ \ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* If alpha is zero, scale by beta and return. */ \ + if ( PASTEMAC(ch,eq0)( *(( ctype* )alpha) ) ) \ + { \ + PASTEMAC(ch,scalm) \ + ( \ + BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m, n, \ + beta, \ + c, rs_c, cs_c \ + ); \ + return; \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); \ + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c * NC; \ + const inc_t jcstep_b = cs_b * NC; \ +\ + const inc_t pcstep_a = cs_a * KC; \ + const inc_t pcstep_b = rs_b * KC; \ +\ + const inc_t icstep_c = rs_c * MC; \ + const inc_t icstep_a = rs_a * MC; \ +\ + const inc_t jrstep_c = cs_c * NR; \ + const inc_t jrstep_b = cs_b * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ + const inc_t irstep_a = rs_a * MR; \ +\ + /* Query a stor3_t enum value to characterize the problem. + Examples: BLIS_RRR, BLIS_RRC, BLIS_RCR, BLIS_RCC, etc. + NOTE: If any matrix is general-stored, we use the all-purpose sup + microkernel corresponding to the stor3_t enum value BLIS_XXX. */ \ + const stor3_t stor_id = bli_stor3_from_strides( rs_c, cs_c, \ + rs_a, cs_a, rs_b, cs_b ); \ +\ + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemmsup_ker_ft) \ + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + ctype* restrict one = PASTEMAC(ch,1); \ +\ + auxinfo_t aux; \ +\ + /* Compute number of primary and leftover components of the outer + dimensions. + NOTE: Functionally speaking, we compute jc_iter as: + jc_iter = n / NC; if ( jc_left ) ++jc_iter; + However, this is implemented as: + jc_iter = ( n + NC - 1 ) / NC; + This avoids a branch at the cost of two additional integer instructions. + The pc_iter, mc_iter, nr_iter, and mr_iter variables are computed in + similar manner. */ \ + const dim_t jc_iter = ( n + NC - 1 ) / NC; \ + const dim_t jc_left = n % NC; \ +\ + const dim_t pc_iter = ( k + KC - 1 ) / KC; \ + const dim_t pc_left = k % KC; \ +\ + const dim_t ic_iter = ( m + MC - 1 ) / MC; \ + const dim_t ic_left = m % MC; \ +\ + const dim_t jc_inc = 1; \ + const dim_t pc_inc = 1; \ + const dim_t ic_inc = 1; \ + const dim_t jr_inc = 1; \ + const dim_t ir_inc = 1; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = 0; jj < jc_iter; jj += jc_inc ) \ + { \ + const dim_t nc_cur = ( bli_is_not_edge_f( jj, jc_iter, jc_left ) ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + const dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + const dim_t jr_left = nc_cur % NR; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = 0; pp < pc_iter; pp += pc_inc ) \ + { \ + const dim_t kc_cur = ( bli_is_not_edge_f( pp, pc_iter, pc_left ) ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? beta_cast : one ); \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = 0; ii < ic_iter; ii += ic_inc ) \ + { \ + const dim_t mc_cur = ( bli_is_not_edge_f( ii, ic_iter, ic_left ) ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + const dim_t ir_left = mc_cur % MR; \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = 0; j < jr_iter; j += jr_inc ) \ + { \ + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc + j * jrstep_b; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ +/* + ctype* restrict b2 = b_jr; \ +*/ \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = 0; i < ir_iter; i += ir_inc ) \ + { \ + const dim_t mr_cur = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic + i * irstep_a; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ +/* + ctype* restrict a2 = bli_gemm_get_next_a_upanel( a_ir, irstep_a, ir_inc ); \ + if ( bli_is_last_iter( i, ir_iter, 0, 1 ) ) \ + { \ + a2 = a_00; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, jrstep_b, jr_inc ); \ + if ( bli_is_last_iter( j, jr_iter, 0, 1 ) ) \ + b2 = b_00; \ + } \ +\ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +*/ \ +\ + /* Invoke the gemmsup micro-kernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a, cs_a, \ + b_jr, rs_b, cs_b, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ + } \ + } \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); \ +*/ \ +} + +INSERT_GENTFUNC_BASIC0( gemmsup_ref_var2 ) + + +// +// -- var1 --------------------------------------------------------------------- +// + +static FUNCPTR_T GENARRAY(ftypes_var1,gemmsup_ref_var1); + +void bli_gemmsup_ref_var1 + ( + trans_t trans, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + stor3_t eff_id, + cntx_t* cntx, + rntm_t* rntm + ) +{ +#if 0 + obj_t at, bt; + + bli_obj_alias_to( a, &at ); + bli_obj_alias_to( b, &bt ); + + // Induce transpositions on A and/or B if either object is marked for + // transposition. We can induce "fast" transpositions since they objects + // are guaranteed to not have structure or be packed. + if ( bli_obj_has_trans( &at ) ) { bli_obj_induce_fast_trans( &at ); } + if ( bli_obj_has_trans( &bt ) ) { bli_obj_induce_fast_trans( &bt ); } + + const num_t dt_exec = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + + const dim_t k = bli_obj_width( &at ); + + void* restrict buf_a = bli_obj_buffer_at_off( &at ); + const inc_t rs_a = bli_obj_row_stride( &at ); + const inc_t cs_a = bli_obj_col_stride( &at ); + + void* restrict buf_b = bli_obj_buffer_at_off( &bt ); + const inc_t rs_b = bli_obj_row_stride( &bt ); + const inc_t cs_b = bli_obj_col_stride( &bt ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt_exec, beta ); + +#else + + const num_t dt_exec = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + dim_t k; + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + inc_t rs_a; + inc_t cs_a; + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + inc_t rs_b; + inc_t cs_b; + + if ( bli_obj_has_notrans( a ) ) + { + k = bli_obj_width( a ); + + rs_a = bli_obj_row_stride( a ); + cs_a = bli_obj_col_stride( a ); + } + else // if ( bli_obj_has_trans( a ) ) + { + // Assign the variables with an implicit transposition. + k = bli_obj_length( a ); + + rs_a = bli_obj_col_stride( a ); + cs_a = bli_obj_row_stride( a ); + } + + if ( bli_obj_has_notrans( b ) ) + { + rs_b = bli_obj_row_stride( b ); + cs_b = bli_obj_col_stride( b ); + } + else // if ( bli_obj_has_trans( b ) ) + { + // Assign the variables with an implicit transposition. + rs_b = bli_obj_col_stride( b ); + cs_b = bli_obj_row_stride( b ); + } + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt_exec, beta ); + +#endif + + // Index into the type combination array to extract the correct + // function pointer. + FUNCPTR_T f = ftypes_var1[dt_exec]; + + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + eff_id, + cntx, + rntm + ); +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t eff_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm \ + ) \ +{ \ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* If alpha is zero, scale by beta and return. */ \ + if ( PASTEMAC(ch,eq0)( *(( ctype* )alpha) ) ) \ + { \ + PASTEMAC(ch,scalm) \ + ( \ + BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m, n, \ + beta, \ + c, rs_c, cs_c \ + ); \ + return; \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); \ + const dim_t MC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); \ +\ + /* Nudge NC up to a multiple of MR and MC up to a multiple of NR. */ \ + const dim_t NC = bli_align_dim_to_mult( NC0, MR ); \ + const dim_t MC = bli_align_dim_to_mult( MC0, NR ); \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = rs_c * NC; \ + const inc_t jcstep_a = rs_a * NC; \ +\ + const inc_t pcstep_a = cs_a * KC; \ + const inc_t pcstep_b = rs_b * KC; \ +\ + const inc_t icstep_c = cs_c * MC; \ + const inc_t icstep_b = cs_b * MC; \ +\ + const inc_t jrstep_c = rs_c * MR; \ + const inc_t jrstep_a = rs_a * MR; \ +\ + const inc_t irstep_c = cs_c * NR; \ + const inc_t irstep_b = cs_b * NR; \ +\ + /* Query a stor3_t enum value to characterize the problem. + Examples: BLIS_RRR, BLIS_RRC, BLIS_RCR, BLIS_RCC, etc. + NOTE: If any matrix is general-stored, we use the all-purpose sup + microkernel corresponding to the stor3_t enum value BLIS_XXX. */ \ + const stor3_t stor_id = bli_stor3_from_strides( rs_c, cs_c, \ + rs_a, cs_a, rs_b, cs_b ); \ +\ + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemmsup_ker_ft) \ + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + ctype* restrict one = PASTEMAC(ch,1); \ +\ + auxinfo_t aux; \ +\ + /* Compute number of primary and leftover components of the outer + dimensions. + NOTE: Functionally speaking, we compute jc_iter as: + jc_iter = m / NC; if ( jc_left ) ++jc_iter; + However, this is implemented as: + jc_iter = ( m + NC - 1 ) / NC; + This avoids a branch at the cost of two additional integer instructions. + The pc_iter, mc_iter, nr_iter, and mr_iter variables are computed in + similar manner. */ \ + const dim_t jc_iter = ( m + NC - 1 ) / NC; \ + const dim_t jc_left = m % NC; \ +\ + const dim_t pc_iter = ( k + KC - 1 ) / KC; \ + const dim_t pc_left = k % KC; \ +\ + const dim_t ic_iter = ( n + MC - 1 ) / MC; \ + const dim_t ic_left = n % MC; \ +\ + const dim_t jc_inc = 1; \ + const dim_t pc_inc = 1; \ + const dim_t ic_inc = 1; \ + const dim_t jr_inc = 1; \ + const dim_t ir_inc = 1; \ +\ + /* Loop over the m dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = 0; jj < jc_iter; jj += jc_inc ) \ + { \ + const dim_t nc_cur = ( bli_is_not_edge_f( jj, jc_iter, jc_left ) ? NC : jc_left ); \ +\ + ctype* restrict a_jc = a_00 + jj * jcstep_a; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + const dim_t jr_iter = ( nc_cur + MR - 1 ) / MR; \ + const dim_t jr_left = nc_cur % MR; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = 0; pp < pc_iter; pp += pc_inc ) \ + { \ + const dim_t kc_cur = ( bli_is_not_edge_f( pp, pc_iter, pc_left ) ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_jc + pp * pcstep_a; \ + ctype* restrict b_pc = b_00 + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? beta_cast : one ); \ +\ + /* Loop over the n dimension (MC rows at a time). */ \ + for ( dim_t ii = 0; ii < ic_iter; ii += ic_inc ) \ + { \ + const dim_t mc_cur = ( bli_is_not_edge_f( ii, ic_iter, ic_left ) ? MC : ic_left ); \ +\ + ctype* restrict b_ic = b_pc + ii * icstep_b; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + const dim_t ir_iter = ( mc_cur + NR - 1 ) / NR; \ + const dim_t ir_left = mc_cur % NR; \ +\ + /* Loop over the m dimension (NR columns at a time). */ \ + for ( dim_t j = 0; j < jr_iter; j += jr_inc ) \ + { \ + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict a_jr = a_pc + j * jrstep_a; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Loop over the n dimension (MR rows at a time). */ \ + for ( dim_t i = 0; i < ir_iter; i += ir_inc ) \ + { \ + const dim_t mr_cur = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict b_ir = b_ic + i * irstep_b; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + /* Invoke the gemmsup micro-kernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_jr, rs_a, cs_a, \ + b_ir, rs_b, cs_b, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ + } \ + } \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); \ +*/ \ +} + +INSERT_GENTFUNC_BASIC0( gemmsup_ref_var1 ) +#endif + + diff --git a/frame/3/bli_l3_sup_var1n2m.c b/frame/3/bli_l3_sup_var1n2m.c new file mode 100644 index 0000000000..acc4c30712 --- /dev/null +++ b/frame/3/bli_l3_sup_var1n2m.c @@ -0,0 +1,1323 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmsup_fp + +typedef void (*FUNCPTR_T) + ( + bool packa, + bool packb, + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + stor3_t eff_id, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- var1n -------------------------------------------------------------------- +// + +static FUNCPTR_T GENARRAY(ftypes_var1n,gemmsup_ref_var1n); + +void bli_gemmsup_ref_var1n + ( + trans_t trans, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + stor3_t eff_id, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ +#if 0 + obj_t at, bt; + + bli_obj_alias_to( a, &at ); + bli_obj_alias_to( b, &bt ); + + // Induce transpositions on A and/or B if either object is marked for + // transposition. We can induce "fast" transpositions since they objects + // are guaranteed to not have structure or be packed. + if ( bli_obj_has_trans( &at ) ) { bli_obj_induce_fast_trans( &at ); } + if ( bli_obj_has_trans( &bt ) ) { bli_obj_induce_fast_trans( &bt ); } + + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + + const dim_t k = bli_obj_width( &at ); + + void* restrict buf_a = bli_obj_buffer_at_off( &at ); + const inc_t rs_a = bli_obj_row_stride( &at ); + const inc_t cs_a = bli_obj_col_stride( &at ); + + void* restrict buf_b = bli_obj_buffer_at_off( &bt ); + const inc_t rs_b = bli_obj_row_stride( &bt ); + const inc_t cs_b = bli_obj_col_stride( &bt ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + +#else + const num_t dt = bli_obj_dt( c ); + + const bool packa = bli_rntm_pack_a( rntm ); + const bool packb = bli_rntm_pack_b( rntm ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + dim_t k; + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + inc_t rs_a; + inc_t cs_a; + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + inc_t rs_b; + inc_t cs_b; + + if ( bli_obj_has_notrans( a ) ) + { + k = bli_obj_width( a ); + + rs_a = bli_obj_row_stride( a ); + cs_a = bli_obj_col_stride( a ); + } + else // if ( bli_obj_has_trans( a ) ) + { + // Assign the variables with an implicit transposition. + k = bli_obj_length( a ); + + rs_a = bli_obj_col_stride( a ); + cs_a = bli_obj_row_stride( a ); + } + + if ( bli_obj_has_notrans( b ) ) + { + rs_b = bli_obj_row_stride( b ); + cs_b = bli_obj_col_stride( b ); + } + else // if ( bli_obj_has_trans( b ) ) + { + // Assign the variables with an implicit transposition. + rs_b = bli_obj_col_stride( b ); + cs_b = bli_obj_row_stride( b ); + } + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + +#endif + + // Index into the type combination array to extract the correct + // function pointer. + FUNCPTR_T f = ftypes_var1n[dt]; + +#if 1 + // Optimize some storage/packing cases by transforming them into others. + // These optimizations are expressed by changing trans and/or eff_id. + bli_gemmsup_ref_var1n2m_opt_cases( dt, &trans, packa, packb, &eff_id, cntx ); +#endif + + if ( bli_is_notrans( trans ) ) + { + // Invoke the function. + f + ( + packa, + packb, + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + eff_id, + cntx, + rntm, + thread + ); + } + else + { + // Invoke the function (transposing the operation). + f + ( + packb, + packa, + conjb, // swap the conj values. + conja, + n, // swap the m and n dimensions. + m, + k, + buf_alpha, + buf_b, cs_b, rs_b, // swap the positions of A and B. + buf_a, cs_a, rs_a, // swap the strides of A and B. + buf_beta, + buf_c, cs_c, rs_c, // swap the strides of C. + bli_stor3_trans( eff_id ), // transpose the stor3_t id. + cntx, + rntm, + thread + ); + } +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + bool packa, \ + bool packb, \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t stor_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* If m or n is zero, return immediately. */ \ + if ( bli_zero_dim2( m, n ) ) return; \ +\ + /* If k < 1 or alpha is zero, scale by beta and return. */ \ + if ( k < 1 || PASTEMAC(ch,eq0)( *(( ctype* )alpha) ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + PASTEMAC(ch,scalm) \ + ( \ + BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m, n, \ + beta, \ + c, rs_c, cs_c \ + ); \ + } \ + return; \ + } \ +\ + /* This transposition of the stor3_t id value is inherent to variant 1. + The reason: we assume that variant 2 is the "main" variant. The + consequence of this is that we assume that the millikernels that + iterate over m are registered to the "primary" kernel group associated + with the kernel IO preference; similarly, mkernels that iterate over + n are assumed to be registered to the "non-primary" group associated + with the ("non-primary") anti-preference. Note that this pattern holds + regardless of whether the mkernel set has a row or column preference.) + See bli_l3_sup_int.c for a higher-level view of how this choice is made. */ \ + stor_id = bli_stor3_trans( stor_id ); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + dim_t KC; \ + if ( packa && packb ) \ + { \ + KC = KC0; \ + } \ + else if ( packb ) \ + { \ + if ( stor_id == BLIS_RRR || \ + stor_id == BLIS_CCC ) KC = KC0; \ + else if ( stor_id == BLIS_RRC || \ + stor_id == BLIS_CRC ) KC = KC0; \ + else if ( stor_id == BLIS_RCR || \ + stor_id == BLIS_CCR ) KC = (( KC0 / 4 ) / 4 ) * 4; \ + else KC = KC0; \ + } \ + else if ( packa ) \ + { \ + if ( stor_id == BLIS_RRR || \ + stor_id == BLIS_CCC ) KC = (( KC0 / 2 ) / 2 ) * 2; \ + else if ( stor_id == BLIS_RRC || \ + stor_id == BLIS_CRC ) KC = KC0; \ + else if ( stor_id == BLIS_RCR || \ + stor_id == BLIS_CCR ) KC = (( KC0 / 4 ) / 4 ) * 4; \ + else KC = KC0; \ + } \ + else /* if ( !packa && !packb ) */ \ + { \ + if ( FALSE ) KC = KC0; \ + else if ( stor_id == BLIS_RRC || \ + stor_id == BLIS_CRC ) KC = KC0; \ + else if ( m <= MR && n <= NR ) KC = KC0; \ + else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; \ + else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; \ + else if ( m <= 4*MR && n <= 4*NR ) KC = KC0 / 4; \ + else KC = (( KC0 / 5 ) / 4 ) * 4; \ + } \ +\ + /* Nudge NC up to a multiple of MR and MC up to a multiple of NR. + NOTE: This is unique to variant 1 (ie: not performed in variant 2) + because MC % MR == 0 and NC % NR == 0 is already enforced at runtime. */ \ + const dim_t NC = bli_align_dim_to_mult( NC0, MR ); \ + const dim_t MC = bli_align_dim_to_mult( MC0, NR ); \ +\ + /* Query the maximum blocksize for MR, which implies a maximum blocksize + extension for the final iteration. */ \ + const dim_t MRM = bli_cntx_get_l3_sup_blksz_max_dt( dt, BLIS_MR, cntx ); \ + const dim_t MRE = MRM - MR; \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = rs_c; \ + const inc_t jcstep_a = rs_a; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = cs_c; \ + const inc_t icstep_b = cs_b; \ +\ + const inc_t jrstep_c = rs_c * MR; \ +\ + /* + const inc_t jrstep_a = rs_a * MR; \ +\ + const inc_t irstep_c = cs_c * NR; \ + const inc_t irstep_b = cs_b * NR; \ + */ \ +\ + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemmsup_ker_ft) \ + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of beta and one scalars to prevent any unnecessary + sharing of cache lines between the cores' caches. */ \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ +\ + auxinfo_t aux; \ +\ + /* Parse and interpret the contents of the rntm_t object to properly + set the ways of parallelism for each loop. */ \ + /*bli_rntm_set_ways_from_rntm_sup( m, n, k, rntm );*/ \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. An alternative way of initializing the + mem_t entries is: + + bli_mem_clear( &mem_a ); \ + bli_mem_clear( &mem_b ); \ + */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. + NOTE: These bszid_t values, and their order, match that of the bp + algorithm (variant 2) because they are not used to query actual + blocksizes but rather query the ways of parallelism for the various + loops. For example, the 2nd loop in variant 1 partitions in the m + dimension (in increments of MR), but parallelizes that m dimension + with BLIS_JR_NT. The only difference is that the _packa and _packb + arrays have been adjusted for the semantic difference in order in + which packa and packb nodes are encountered in the thrinfo tree. + That is, this panel-block algorithm partitions an NC x KC submatrix + of A to be packed in the 4th loop, and a KC x MC submatrix of B + to be packed in the 3rd loop. */ \ + /* 5thloop 4thloop packa 3rdloop packb 2ndloop 1stloop ukrloop */ \ + bszid_t bszids_nopack[6] = { BLIS_NC, BLIS_KC, BLIS_MC, BLIS_NR, BLIS_MR, BLIS_KR }; \ + bszid_t bszids_packa [7] = { BLIS_NC, BLIS_KC, BLIS_NO_PART, BLIS_MC, BLIS_NR, BLIS_MR, BLIS_KR }; \ + bszid_t bszids_packb [7] = { BLIS_NC, BLIS_KC, BLIS_MC, BLIS_NO_PART, BLIS_NR, BLIS_MR, BLIS_KR }; \ + bszid_t bszids_packab[8] = { BLIS_NC, BLIS_KC, BLIS_NO_PART, BLIS_MC, BLIS_NO_PART, BLIS_NR, BLIS_MR, BLIS_KR }; \ + bszid_t* restrict bszids; \ +\ + /* Set the bszids pointer to the correct bszids array above based on which + matrices (if any) are being packed. */ \ + if ( packa ) { if ( packb ) bszids = bszids_packab; \ + else bszids = bszids_packa; } \ + else { if ( packb ) bszids = bszids_packb; \ + else bszids = bszids_nopack; } \ +\ + /* Determine whether we are using more than one thread. */ \ + const bool is_mt = ( bli_rntm_calc_num_threads( rntm ) > 1 ); \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ +\ + /* Grow the thrinfo_t tree. */ \ + bszid_t* restrict bszids_jc = bszids; \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, m, MR, FALSE, &jc_start, &jc_end ); \ + const dim_t m_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( m_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = m_local % NC; \ +\ + /* Loop over the m dimension (NC rows/columns at a time). */ \ + /*for ( dim_t jj = 0; jj < jc_iter; jj += 1 )*/ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); \ +\ + ctype* restrict a_jc = a_00 + jj * jcstep_a; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Grow the thrinfo_t tree. */ \ + bszid_t* restrict bszids_pc = &bszids_jc[1]; \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. */ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + /*for ( dim_t pp = 0; pp < pc_iter; pp += 1 )*/ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_jc + pp * pcstep_a; \ + ctype* restrict b_pc = b_00 + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Set the bszid_t array and thrinfo_t pointer based on whether + we will be packing A. If we won't be packing A, we alias to + the _pc variables so that code further down can unconditionally + reference the _pa variables. Note that *if* we will be packing + A, the thrinfo_t node will have already been created by a + previous call to bli_thrinfo_grow(), since bszid values of + BLIS_NO_PART cause the tree to grow by two (e.g. to the next + bszid that is a normal bszid_t value). */ \ + bszid_t* restrict bszids_pa; \ + if ( packa ) { bszids_pa = &bszids_pc[1]; \ + thread_pa = bli_thrinfo_sub_node( thread_pc ); } \ + else { bszids_pa = &bszids_pc[0]; \ + thread_pa = thread_pc; } \ +\ + /* Determine the packing buffer and related parameters for matrix + A. (If A will not be packed, then a_use will be set to point to + a and the _a_use strides will be set accordingly.) Then call + the packm sup variant chooser, which will call the appropriate + implementation based on the schema deduced from the stor_id. + NOTE: packing matrix A in this panel-block algorithm corresponds + to packing matrix B in the block-panel algorithm. */ \ + PASTEMAC(ch,packm_sup_a) \ + ( \ + packa, \ + BLIS_BUFFER_FOR_B_PANEL, /* This algorithm packs matrix A to */ \ + stor_id, /* a "panel of B". */ \ + BLIS_NO_TRANSPOSE, \ + NC, KC, /* This "panel of B" is (at most) NC x KC. */ \ + nc_cur, kc_cur, MR, \ + &one_local, \ + a_pc, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. */ \ + ctype* restrict a_pc_use = a_use; \ +\ + /* We don't need to embed the panel stride of A within the auxinfo_t + object because this variant iterates through A in the jr loop, + which occurs here, within the macrokernel, not within the + millikernel. */ \ + /*bli_auxinfo_set_ps_a( ps_a_use, &aux );*/ \ +\ + /* Grow the thrinfo_t tree. */ \ + bszid_t* restrict bszids_ic = &bszids_pa[1]; \ + thread_ic = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, n, NR, FALSE, &ic_start, &ic_end ); \ + const dim_t n_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( n_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = n_local % MC; \ +\ + /* Loop over the n dimension (MC rows at a time). */ \ + /*for ( dim_t ii = 0; ii < ic_iter; ii += 1 )*/ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict b_ic = b_pc + ii * icstep_b; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Set the bszid_t array and thrinfo_t pointer based on whether + we will be packing A. If we won't be packing A, we alias to + the _pc variables so that code further down can unconditionally + reference the _pa variables. Note that *if* we will be packing + A, the thrinfo_t node will have already been created by a + previous call to bli_thrinfo_grow(), since bszid values of + BLIS_NO_PART cause the tree to grow by two (e.g. to the next + bszid that is a normal bszid_t value). */ \ + bszid_t* restrict bszids_pb; \ + if ( packb ) { bszids_pb = &bszids_ic[1]; \ + thread_pb = bli_thrinfo_sub_node( thread_ic ); } \ + else { bszids_pb = &bszids_ic[0]; \ + thread_pb = thread_ic; } \ +\ + /* Determine the packing buffer and related parameters for matrix + B. (If B will not be packed, then b_use will be set to point to + b and the _b_use strides will be set accordingly.) Then call + the packm sup variant chooser, which will call the appropriate + implementation based on the schema deduced from the stor_id. + NOTE: packing matrix B in this panel-block algorithm corresponds + to packing matrix A in the block-panel algorithm. */ \ + PASTEMAC(ch,packm_sup_b) \ + ( \ + packb, \ + BLIS_BUFFER_FOR_A_BLOCK, /* This algorithm packs matrix B to */ \ + stor_id, /* a "block of A". */ \ + BLIS_NO_TRANSPOSE, \ + KC, MC, /* This "block of A" is (at most) KC x MC. */ \ + kc_cur, mc_cur, NR, \ + &one_local, \ + b_ic, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_ic_use = b_use; \ +\ + /* Embed the panel stride of B within the auxinfo_t object. The + millikernel will query and use this to iterate through + micropanels of B. */ \ + bli_auxinfo_set_ps_b( ps_b_use, &aux ); \ +\ + /* Grow the thrinfo_t tree. */ \ + bszid_t* restrict bszids_jr = &bszids_pb[1]; \ + thread_jr = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + MR - 1 ) / MR; \ + dim_t jr_left = nc_cur % MR; \ +\ + /* An optimization: allow the last jr iteration to contain up to MRE + rows of C and A. (If MRE > MR, the mkernel has agreed to handle + these cases.) Note that this prevents us from declaring jr_iter and + jr_left as const. NOTE: We forgo this optimization when packing A + since packing an extended edge case is not yet supported. */ \ + if ( !packa && !is_mt ) \ + if ( MRE != 0 && 1 < jr_iter && jr_left != 0 && jr_left <= MRE ) \ + { \ + jr_iter--; jr_left += MR; \ + } \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the m dimension (NR columns at a time). */ \ + /*for ( dim_t j = 0; j < jr_iter; j += 1 )*/ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? MR : jr_left ); \ +\ + /* + ctype* restrict a_jr = a_pc + j * jrstep_a; \ + */ \ + ctype* restrict a_jr = a_pc_use + j * ps_a_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* + const dim_t ir_iter = ( mc_cur + NR - 1 ) / NR; \ + const dim_t ir_left = mc_cur % NR; \ + */ \ +\ + /* Loop over the n dimension (MR rows at a time). */ \ + { \ + /* Invoke the gemmsup millikernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + nr_cur, /* Notice: nr_cur <= MR. */ \ + mc_cur, /* Recall: mc_cur partitions the n dimension! */ \ + kc_cur, \ + alpha_cast, \ + a_jr, rs_a_use, cs_a_use, \ + b_ic_use, rs_b_use, cs_b_use, \ + beta_use, \ + c_jr, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +\ + /* NOTE: This barrier is only needed if we are packing A (since + that matrix is packed within the pc loop of this variant). */ \ + if ( packa ) bli_thread_barrier( thread_pa ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. */ \ + PASTEMAC(ch,packm_sup_finalize_mem_a) \ + ( \ + packa, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTEMAC(ch,packm_sup_finalize_mem_b) \ + ( \ + packb, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); \ +*/ \ +} + +INSERT_GENTFUNC_BASIC0( gemmsup_ref_var1n ) + + +// +// -- var2m -------------------------------------------------------------------- +// + +static FUNCPTR_T GENARRAY(ftypes_var2m,gemmsup_ref_var2m); + +void bli_gemmsup_ref_var2m + ( + trans_t trans, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + stor3_t eff_id, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ +#if 0 + obj_t at, bt; + + bli_obj_alias_to( a, &at ); + bli_obj_alias_to( b, &bt ); + + // Induce transpositions on A and/or B if either object is marked for + // transposition. We can induce "fast" transpositions since they objects + // are guaranteed to not have structure or be packed. + if ( bli_obj_has_trans( &at ) ) { bli_obj_induce_fast_trans( &at ); } + if ( bli_obj_has_trans( &bt ) ) { bli_obj_induce_fast_trans( &bt ); } + + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + + const dim_t k = bli_obj_width( &at ); + + void* restrict buf_a = bli_obj_buffer_at_off( &at ); + const inc_t rs_a = bli_obj_row_stride( &at ); + const inc_t cs_a = bli_obj_col_stride( &at ); + + void* restrict buf_b = bli_obj_buffer_at_off( &bt ); + const inc_t rs_b = bli_obj_row_stride( &bt ); + const inc_t cs_b = bli_obj_col_stride( &bt ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + +#else + const num_t dt = bli_obj_dt( c ); + + const bool packa = bli_rntm_pack_a( rntm ); + const bool packb = bli_rntm_pack_b( rntm ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + dim_t k; + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + inc_t rs_a; + inc_t cs_a; + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + inc_t rs_b; + inc_t cs_b; + + if ( bli_obj_has_notrans( a ) ) + { + k = bli_obj_width( a ); + + rs_a = bli_obj_row_stride( a ); + cs_a = bli_obj_col_stride( a ); + } + else // if ( bli_obj_has_trans( a ) ) + { + // Assign the variables with an implicit transposition. + k = bli_obj_length( a ); + + rs_a = bli_obj_col_stride( a ); + cs_a = bli_obj_row_stride( a ); + } + + if ( bli_obj_has_notrans( b ) ) + { + rs_b = bli_obj_row_stride( b ); + cs_b = bli_obj_col_stride( b ); + } + else // if ( bli_obj_has_trans( b ) ) + { + // Assign the variables with an implicit transposition. + rs_b = bli_obj_col_stride( b ); + cs_b = bli_obj_row_stride( b ); + } + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + +#endif + + // Index into the type combination array to extract the correct + // function pointer. + FUNCPTR_T f = ftypes_var2m[dt]; + +#if 1 + // Optimize some storage/packing cases by transforming them into others. + // These optimizations are expressed by changing trans and/or eff_id. + bli_gemmsup_ref_var1n2m_opt_cases( dt, &trans, packa, packb, &eff_id, cntx ); +#endif + + if ( bli_is_notrans( trans ) ) + { + // Invoke the function. + f + ( + packa, + packb, + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + eff_id, + cntx, + rntm, + thread + ); + } + else + { + // Invoke the function (transposing the operation). + f + ( + packb, // swap the pack values. + packa, + conjb, // swap the conj values. + conja, + n, // swap the m and n dimensions. + m, + k, + buf_alpha, + buf_b, cs_b, rs_b, // swap the positions of A and B. + buf_a, cs_a, rs_a, // swap the strides of A and B. + buf_beta, + buf_c, cs_c, rs_c, // swap the strides of C. + bli_stor3_trans( eff_id ), // transpose the stor3_t id. + cntx, + rntm, + thread + ); + } +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + bool packa, \ + bool packb, \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t stor_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* If m or n is zero, return immediately. */ \ + if ( bli_zero_dim2( m, n ) ) return; \ +\ + /* If k < 1 or alpha is zero, scale by beta and return. */ \ + if ( k < 1 || PASTEMAC(ch,eq0)( *(( ctype* )alpha) ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + PASTEMAC(ch,scalm) \ + ( \ + BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m, n, \ + beta, \ + c, rs_c, cs_c \ + ); \ + } \ + return; \ + } \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + dim_t KC; \ + if ( packa && packb ) \ + { \ + KC = KC0; \ + } \ + else if ( packb ) \ + { \ + if ( stor_id == BLIS_RRR || \ + stor_id == BLIS_CCC ) KC = KC0; \ + else if ( stor_id == BLIS_RRC || \ + stor_id == BLIS_CRC ) KC = KC0; \ + else if ( stor_id == BLIS_RCR || \ + stor_id == BLIS_CCR ) KC = (( KC0 / 4 ) / 4 ) * 4; \ + else KC = KC0; \ + } \ + else if ( packa ) \ + { \ + if ( stor_id == BLIS_RRR || \ + stor_id == BLIS_CCC ) KC = (( KC0 / 2 ) / 2 ) * 2; \ + else if ( stor_id == BLIS_RRC || \ + stor_id == BLIS_CRC ) KC = KC0; \ + else if ( stor_id == BLIS_RCR || \ + stor_id == BLIS_CCR ) KC = (( KC0 / 4 ) / 4 ) * 4; \ + else KC = KC0; \ + } \ + else /* if ( !packa && !packb ) */ \ + { \ + if ( stor_id == BLIS_RRR || \ + stor_id == BLIS_CCC ) KC = KC0; \ + else if ( stor_id == BLIS_RRC || \ + stor_id == BLIS_CRC ) KC = KC0; \ + else if ( m <= MR && n <= NR ) KC = KC0; \ + else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; \ + else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; \ + else if ( m <= 4*MR && n <= 4*NR ) KC = KC0 / 4; \ + else KC = (( KC0 / 5 ) / 4 ) * 4; \ + } \ +\ + /* Query the maximum blocksize for NR, which implies a maximum blocksize + extension for the final iteration. */ \ + const dim_t NRM = bli_cntx_get_l3_sup_blksz_max_dt( dt, BLIS_NR, cntx ); \ + const dim_t NRE = NRM - NR; \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ +\ + /* + const inc_t jrstep_b = cs_b * NR; \ + ( void )jrstep_b; \ +\ + const inc_t irstep_c = rs_c * MR; \ + const inc_t irstep_a = rs_a * MR; \ + */ \ +\ + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemmsup_ker_ft) \ + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of beta and one scalars to prevent any unnecessary + sharing of cache lines between the cores' caches. */ \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ +\ + auxinfo_t aux; \ +\ + /* Parse and interpret the contents of the rntm_t object to properly + set the ways of parallelism for each loop. */ \ + /*bli_rntm_set_ways_from_rntm_sup( m, n, k, rntm );*/ \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. An alternative way of initializing the + mem_t entries is: + + bli_mem_clear( &mem_a ); \ + bli_mem_clear( &mem_b ); \ + */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ \ + /* 5thloop 4thloop packb 3rdloop packa 2ndloop 1stloop ukrloop */ \ + bszid_t bszids_nopack[6] = { BLIS_NC, BLIS_KC, BLIS_MC, BLIS_NR, BLIS_MR, BLIS_KR }; \ + bszid_t bszids_packa [7] = { BLIS_NC, BLIS_KC, BLIS_MC, BLIS_NO_PART, BLIS_NR, BLIS_MR, BLIS_KR }; \ + bszid_t bszids_packb [7] = { BLIS_NC, BLIS_KC, BLIS_NO_PART, BLIS_MC, BLIS_NR, BLIS_MR, BLIS_KR }; \ + bszid_t bszids_packab[8] = { BLIS_NC, BLIS_KC, BLIS_NO_PART, BLIS_MC, BLIS_NO_PART, BLIS_NR, BLIS_MR, BLIS_KR }; \ + bszid_t* restrict bszids; \ +\ + /* Set the bszids pointer to the correct bszids array above based on which + matrices (if any) are being packed. */ \ + if ( packa ) { if ( packb ) bszids = bszids_packab; \ + else bszids = bszids_packa; } \ + else { if ( packb ) bszids = bszids_packb; \ + else bszids = bszids_nopack; } \ +\ + /* Determine whether we are using more than one thread. */ \ + const bool is_mt = ( bli_rntm_calc_num_threads( rntm ) > 1 ); \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ +\ + /* Grow the thrinfo_t tree. */ \ + bszid_t* restrict bszids_jc = bszids; \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + const dim_t n_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n_local % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + /*for ( dim_t jj = 0; jj < jc_iter; jj += 1 )*/ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Grow the thrinfo_t tree. */ \ + bszid_t* restrict bszids_pc = &bszids_jc[1]; \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. */ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + /*for ( dim_t pp = 0; pp < pc_iter; pp += 1 )*/ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Set the bszid_t array and thrinfo_t pointer based on whether + we will be packing B. If we won't be packing B, we alias to + the _pc variables so that code further down can unconditionally + reference the _pb variables. Note that *if* we will be packing + B, the thrinfo_t node will have already been created by a + previous call to bli_thrinfo_grow(), since bszid values of + BLIS_NO_PART cause the tree to grow by two (e.g. to the next + bszid that is a normal bszid_t value). */ \ + bszid_t* restrict bszids_pb; \ + if ( packb ) { bszids_pb = &bszids_pc[1]; \ + thread_pb = bli_thrinfo_sub_node( thread_pc ); } \ + else { bszids_pb = &bszids_pc[0]; \ + thread_pb = thread_pc; } \ +\ + /* Determine the packing buffer and related parameters for matrix + B. (If B will not be packed, then a_use will be set to point to + b and the _b_use strides will be set accordingly.) Then call + the packm sup variant chooser, which will call the appropriate + implementation based on the schema deduced from the stor_id. */ \ + PASTEMAC(ch,packm_sup_b) \ + ( \ + packb, \ + BLIS_BUFFER_FOR_B_PANEL, /* This algorithm packs matrix B to */ \ + stor_id, /* a "panel of B." */ \ + BLIS_NO_TRANSPOSE, \ + KC, NC, /* This "panel of B" is (at most) KC x NC. */ \ + kc_cur, nc_cur, NR, \ + &one_local, \ + b_pc, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_pc_use = b_use; \ +\ + /* We don't need to embed the panel stride of B within the auxinfo_t + object because this variant iterates through B in the jr loop, + which occurs here, within the macrokernel, not within the + millikernel. */ \ + /*bli_auxinfo_set_ps_b( ps_b_use, &aux );*/ \ +\ + /* Grow the thrinfo_t tree. */ \ + bszid_t* restrict bszids_ic = &bszids_pb[1]; \ + thread_ic = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + const dim_t m_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m_local % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + /*for ( dim_t ii = 0; ii < ic_iter; ii += 1 )*/ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Set the bszid_t array and thrinfo_t pointer based on whether + we will be packing B. If we won't be packing A, we alias to + the _ic variables so that code further down can unconditionally + reference the _pa variables. Note that *if* we will be packing + A, the thrinfo_t node will have already been created by a + previous call to bli_thrinfo_grow(), since bszid values of + BLIS_NO_PART cause the tree to grow by two (e.g. to the next + bszid that is a normal bszid_t value). */ \ + bszid_t* restrict bszids_pa; \ + if ( packa ) { bszids_pa = &bszids_ic[1]; \ + thread_pa = bli_thrinfo_sub_node( thread_ic ); } \ + else { bszids_pa = &bszids_ic[0]; \ + thread_pa = thread_ic; } \ +\ + /* Determine the packing buffer and related parameters for matrix + A. (If A will not be packed, then a_use will be set to point to + a and the _a_use strides will be set accordingly.) Then call + the packm sup variant chooser, which will call the appropriate + implementation based on the schema deduced from the stor_id. */ \ + PASTEMAC(ch,packm_sup_a) \ + ( \ + packa, \ + BLIS_BUFFER_FOR_A_BLOCK, /* This algorithm packs matrix A to */ \ + stor_id, /* a "block of A." */ \ + BLIS_NO_TRANSPOSE, \ + MC, KC, /* This "block of A" is (at most) MC x KC. */ \ + mc_cur, kc_cur, MR, \ + &one_local, \ + a_ic, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. */ \ + ctype* restrict a_ic_use = a_use; \ +\ + /* Embed the panel stride of A within the auxinfo_t object. The + millikernel will query and use this to iterate through + micropanels of A (if needed). */ \ + bli_auxinfo_set_ps_a( ps_a_use, &aux ); \ +\ + /* Grow the thrinfo_t tree. */ \ + bszid_t* restrict bszids_jr = &bszids_pa[1]; \ + thread_jr = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* An optimization: allow the last jr iteration to contain up to NRE + columns of C and B. (If NRE > NR, the mkernel has agreed to handle + these cases.) Note that this prevents us from declaring jr_iter and + jr_left as const. NOTE: We forgo this optimization when packing B + since packing an extended edge case is not yet supported. */ \ + if ( !packb && !is_mt ) \ + if ( NRE != 0 && 1 < jr_iter && jr_left != 0 && jr_left <= NRE ) \ + { \ + jr_iter--; jr_left += NR; \ + } \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + /*for ( dim_t j = 0; j < jr_iter; j += 1 )*/ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + /* + ctype* restrict b_jr = b_pc_use + j * jrstep_b; \ + */ \ + ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* + const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + const dim_t ir_left = mc_cur % MR; \ + */ \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + { \ + /* Invoke the gemmsup millikernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mc_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ic_use, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + beta_use, \ + c_jr, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +\ + /* NOTE: This barrier is only needed if we are packing B (since + that matrix is packed within the pc loop of this variant). */ \ + if ( packb ) bli_thread_barrier( thread_pb ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. */ \ + PASTEMAC(ch,packm_sup_finalize_mem_a) \ + ( \ + packa, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTEMAC(ch,packm_sup_finalize_mem_b) \ + ( \ + packb, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); \ +*/ \ +} + +INSERT_GENTFUNC_BASIC0( gemmsup_ref_var2m ) + diff --git a/frame/3/bli_l3_sup_vars.h b/frame/3/bli_l3_sup_vars.h new file mode 100644 index 0000000000..7c315192d5 --- /dev/null +++ b/frame/3/bli_l3_sup_vars.h @@ -0,0 +1,206 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype object-based interfaces. +// + +#undef GENPROT +#define GENPROT( opname ) \ +\ +void PASTEMAC0(opname) \ + ( \ + trans_t trans, \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + stor3_t eff_id, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ); + +GENPROT( gemmsup_ref_var1 ) +GENPROT( gemmsup_ref_var2 ) + +GENPROT( gemmsup_ref_var1n ) +GENPROT( gemmsup_ref_var2m ) + + +// +// Prototype BLAS-like interfaces with void pointer operands. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t eff_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ); + +INSERT_GENTPROT_BASIC0( gemmsup_ref_var1 ) +INSERT_GENTPROT_BASIC0( gemmsup_ref_var2 ) + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + bool packa, \ + bool packb, \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t eff_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ); + +INSERT_GENTPROT_BASIC0( gemmsup_ref_var1n ) +INSERT_GENTPROT_BASIC0( gemmsup_ref_var2m ) + +// ----------------------------------------------------------------------------- + +BLIS_INLINE void bli_gemmsup_ref_var1n2m_opt_cases + ( + num_t dt, + trans_t* trans, + bool packa, + bool packb, + stor3_t* eff_id, + cntx_t* cntx + ) +{ + const bool row_pref = bli_cntx_l3_sup_ker_prefers_rows_dt( dt, *eff_id, cntx ); + + // Handle row- and column-preferrential kernels separately. + if ( row_pref ) + { + if ( packa && packb ) + { + if ( *eff_id == BLIS_RRC ) + { + // Since C is already row-stored, we can use BLIS_RRR kernel instead. + *eff_id = BLIS_RRR; + } + else if ( *eff_id == BLIS_CRC ) + { + // BLIS_RRC when transposed below (both matrices still packed). + // This allows us to use the BLIS_RRR kernel instead. + *eff_id = BLIS_CCC; // BLIS_RRR when transposed below. + } + else if ( *eff_id == BLIS_CRR ) + { + // Induce a transpose to make C row-stored. + // BLIS_RCC when transposed below (both matrices still packed). + // This allows us to use the BLIS_RRR kernel instead. + *trans = bli_trans_toggled( *trans ); + *eff_id = BLIS_CCC; // BLIS_RRR when transposed below. + } + } + else if ( packb ) + { + if ( *eff_id == BLIS_RRC ) + { + // Since C is already row-stored, we can use BLIS_RRR kernel instead. + *eff_id = BLIS_RRR; + } + else if ( *eff_id == BLIS_CRC ) + { + // BLIS_RRC when transposed below (with packa instead of packb). + // No transformation is beneficial here. + } + else if ( *eff_id == BLIS_RCC ) + { + // C is already row-stored; cancel transposition and use BLIS_RCR + // kernel instead. + *trans = bli_trans_toggled( *trans ); + *eff_id = BLIS_RCR; + } + #if 0 + // This transformation performs poorly. Theory: packing A (formerly B) + // when eff_id == BLIS_RCC (formerly BLIS_CRR) to row storage is slow + // and kills the performance? + else if ( eff_id == BLIS_CRR ) + { + trans = bli_trans_toggled( trans ); + eff_id = BLIS_CRC; // BLIS_RRC when transposed below. + } + #endif + } + else if ( packa ) + { + if ( *eff_id == BLIS_CRR ) + { + // Induce a transpose to make C row-stored. + // BLIS_RCC when transposed below (both matrices still packed). + // This allows us to use the BLIS_RRR kernel instead. + *trans = bli_trans_toggled( *trans ); + *eff_id = BLIS_CCR; // BLIS_RCR when transposed below. + } + } + } + else + { + //bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); + printf( "libblis: sup var1n2m_opt_cases not yet implemented for column-preferential kernels.\n" ); + bli_abort(); + } +} + diff --git a/frame/3/bli_l3_tapi.c b/frame/3/bli_l3_tapi.c index 4eeba19713..afec5b677a 100644 --- a/frame/3/bli_l3_tapi.c +++ b/frame/3/bli_l3_tapi.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Texas at Austin Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,18 +32,16 @@ */ -// Guard the function definitions so that they are only compiled when -// #included from files that define the typed API macros. -#ifdef BLIS_ENABLE_TAPI +#include "blis.h" // -// Define BLAS-like interfaces with typed operands. +// Define BLAS-like interfaces with typed operands (basic). // #undef GENTFUNC #define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ trans_t transa, \ trans_t transb, \ @@ -55,52 +53,70 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t m_a, n_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ - bli_set_dims_with_trans( transb, k, n, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + transa, \ + transb, \ + m, n, k, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } INSERT_GENTFUNC_BASIC0( gemm ) +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + uplo_t uploc, \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ) \ +{ \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ + ( \ + uploc, \ + transa, \ + transb, \ + m, k, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ + ); \ +} + +INSERT_GENTFUNC_BASIC0( gemmt ) + + #undef GENTFUNC #define GENTFUNC( ctype, ch, opname, struca ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ side_t side, \ uplo_t uploa, \ @@ -113,46 +129,24 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t mn_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ - bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_conj( conja, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( struca, &ao ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ side, \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + uploa, \ + conja, \ + transb, \ + m, n, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } @@ -163,7 +157,7 @@ INSERT_GENTFUNC_BASIC( symm, BLIS_SYMMETRIC ) #undef GENTFUNCR #define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -173,41 +167,21 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* a, inc_t rs_a, inc_t cs_a, \ ctype_r* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt_r = PASTEMAC(chr,type); \ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, betao, co; \ -\ - dim_t m_a, n_a; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt_r, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt_r, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_HERMITIAN, &co ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ - &alphao, \ - &ao, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + uploc, \ + transa, \ + m, k, \ + alpha, \ + a, rs_a, cs_a, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } @@ -217,7 +191,7 @@ INSERT_GENTFUNCR_BASIC0( herk ) #undef GENTFUNCR #define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -229,46 +203,23 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype_r* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt_r = PASTEMAC(chr,type); \ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t m_a, n_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ - bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt_r, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_HERMITIAN, &co ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + uploc, \ + transa, \ + transb, \ + m, k, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } @@ -278,7 +229,7 @@ INSERT_GENTFUNCR_BASIC0( her2k ) #undef GENTFUNC #define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -288,40 +239,21 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* a, inc_t rs_a, inc_t cs_a, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, betao, co; \ -\ - dim_t m_a, n_a; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_SYMMETRIC, &co ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ - &alphao, \ - &ao, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + uploc, \ + transa, \ + m, k, \ + alpha, \ + a, rs_a, cs_a, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } @@ -331,7 +263,7 @@ INSERT_GENTFUNC_BASIC0( syrk ) #undef GENTFUNC #define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -343,45 +275,23 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t m_a, n_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ - bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_SYMMETRIC, &co ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + uploc, \ + transa, \ + transb, \ + m, k, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } @@ -391,7 +301,7 @@ INSERT_GENTFUNC_BASIC0( syr2k ) #undef GENTFUNC #define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ side_t side, \ uplo_t uploa, \ @@ -405,47 +315,25 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t mn_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ - bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_diag( diaga, &ao ); \ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ side, \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + uploa, \ + transa, \ + diaga, \ + transb, \ + m, n, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } @@ -455,7 +343,7 @@ INSERT_GENTFUNC_BASIC0( trmm3 ) #undef GENTFUNC #define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ side_t side, \ uplo_t uploa, \ @@ -466,46 +354,25 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* alpha, \ ctype* a, inc_t rs_a, inc_t cs_a, \ ctype* b, inc_t rs_b, inc_t cs_b \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo; \ -\ - dim_t mn_a; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ -\ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, n, b, rs_b, cs_b, &bo ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_diag( diaga, &ao ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ side, \ - &alphao, \ - &ao, \ - &bo, \ - cntx, \ - rntm \ + uploa, \ + transa, \ + diaga, \ + m, n, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + NULL, \ + NULL \ ); \ } INSERT_GENTFUNC_BASIC0( trmm ) INSERT_GENTFUNC_BASIC0( trsm ) - -#endif - diff --git a/frame/3/bli_l3_tapi.h b/frame/3/bli_l3_tapi.h index a809c2a68e..4b35040018 100644 --- a/frame/3/bli_l3_tapi.h +++ b/frame/3/bli_l3_tapi.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,13 +35,13 @@ // -// Prototype BLAS-like interfaces with typed operands. +// Prototype BLAS-like interfaces with typed operands (basic). // #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ trans_t transa, \ trans_t transb, \ @@ -52,16 +53,14 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROT_BASIC0( gemm ) - #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ side_t side, \ uplo_t uploa, \ @@ -74,7 +73,6 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROT_BASIC0( hemm ) @@ -84,7 +82,7 @@ INSERT_GENTPROT_BASIC0( symm ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -94,7 +92,6 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* a, inc_t rs_a, inc_t cs_a, \ ctype_r* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROTR_BASIC0( herk ) @@ -103,7 +100,7 @@ INSERT_GENTPROTR_BASIC0( herk ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -115,7 +112,6 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype_r* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROTR_BASIC0( her2k ) @@ -124,7 +120,7 @@ INSERT_GENTPROTR_BASIC0( her2k ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -134,7 +130,6 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* a, inc_t rs_a, inc_t cs_a, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROT_BASIC0( syrk ) @@ -143,7 +138,7 @@ INSERT_GENTPROT_BASIC0( syrk ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -155,16 +150,16 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); +INSERT_GENTPROT_BASIC0( gemmt ) INSERT_GENTPROT_BASIC0( syr2k ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ side_t side, \ uplo_t uploa, \ @@ -178,7 +173,6 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROT_BASIC0( trmm3 ) @@ -187,7 +181,7 @@ INSERT_GENTPROT_BASIC0( trmm3 ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ side_t side, \ uplo_t uploa, \ @@ -198,7 +192,6 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* alpha, \ ctype* a, inc_t rs_a, inc_t cs_a, \ ctype* b, inc_t rs_b, inc_t cs_b \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROT_BASIC0( trmm ) diff --git a/frame/3/bli_l3_tapi_ex.c b/frame/3/bli_l3_tapi_ex.c index 609bf8e78d..f6a52fb5e9 100644 --- a/frame/3/bli_l3_tapi_ex.c +++ b/frame/3/bli_l3_tapi_ex.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,13 +35,553 @@ #include "blis.h" -// Include cpp macros that instantiate the API definition templates as -// having expert parameters. -#include "bli_tapi_ex.h" +// +// Define BLAS-like interfaces with typed operands (expert). +// -// Define the macro protecting the typed API definitions. -#define BLIS_ENABLE_TAPI +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, k, n, &m_b, &n_b ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ +\ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, n, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} -// Include the typed API definitions here. -#include "bli_l3_tapi.c" +INSERT_GENTFUNC_BASIC0( gemm ) + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, struca ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + side_t side, \ + uplo_t uploa, \ + conj_t conja, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn_a; \ + dim_t m_b, n_b; \ +\ + bli_set_dim_with_side( side, m, n, &mn_a ); \ + bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ +\ + bli_obj_init_finish( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, n, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploa, &ao ); \ + bli_obj_set_conj( conja, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( struca, &ao ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + side, \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNC_BASIC( hemm, BLIS_HERMITIAN ) +INSERT_GENTFUNC_BASIC( symm, BLIS_SYMMETRIC ) + + +#undef GENTFUNCR +#define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + uplo_t uploc, \ + trans_t transa, \ + dim_t m, \ + dim_t k, \ + ctype_r* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype_r* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt_r = PASTEMAC(chr,type); \ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m_a, n_a; \ +\ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ +\ + bli_obj_init_finish_1x1( dt_r, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt_r, beta, &betao ); \ +\ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploc, &co ); \ + bli_obj_set_conjtrans( transa, &ao ); \ +\ + bli_obj_set_struc( BLIS_HERMITIAN, &co ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNCR_BASIC0( herk ) + + +#undef GENTFUNCR +#define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + uplo_t uploc, \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype_r* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt_r = PASTEMAC(chr,type); \ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt_r, beta, &betao ); \ +\ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploc, &co ); \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( BLIS_HERMITIAN, &co ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNCR_BASIC0( her2k ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + uplo_t uploc, \ + trans_t transa, \ + dim_t m, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m_a, n_a; \ +\ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ +\ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploc, &co ); \ + bli_obj_set_conjtrans( transa, &ao ); \ +\ + bli_obj_set_struc( BLIS_SYMMETRIC, &co ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNC_BASIC0( syrk ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + uplo_t uploc, \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ +\ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploc, &co ); \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( BLIS_SYMMETRIC, &co ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNC_BASIC0( syr2k ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + uplo_t uploc, \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, k, m, &m_b, &n_b ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ +\ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploc, &co ); \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNC_BASIC0( gemmt ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + side_t side, \ + uplo_t uploa, \ + trans_t transa, \ + diag_t diaga, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn_a; \ + dim_t m_b, n_b; \ +\ + bli_set_dim_with_side( side, m, n, &mn_a ); \ + bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ +\ + bli_obj_init_finish( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, n, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploa, &ao ); \ + bli_obj_set_diag( diaga, &ao ); \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + side, \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNC_BASIC0( trmm3 ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + side_t side, \ + uplo_t uploa, \ + trans_t transa, \ + diag_t diaga, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn_a; \ +\ + bli_set_dim_with_side( side, m, n, &mn_a ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ +\ + bli_obj_init_finish( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m, n, b, rs_b, cs_b, &bo ); \ +\ + bli_obj_set_uplo( uploa, &ao ); \ + bli_obj_set_diag( diaga, &ao ); \ + bli_obj_set_conjtrans( transa, &ao ); \ +\ + bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + side, \ + &alphao, \ + &ao, \ + &bo, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNC_BASIC0( trmm ) +INSERT_GENTFUNC_BASIC0( trsm ) diff --git a/frame/ind/tapi/bli_l3_ind_tapi.h b/frame/3/bli_l3_tapi_ex.h similarity index 63% rename from frame/ind/tapi/bli_l3_ind_tapi.h rename to frame/3/bli_l3_tapi_ex.h index 49ff6a8739..1ab0a8ff1a 100644 --- a/frame/ind/tapi/bli_l3_ind_tapi.h +++ b/frame/3/bli_l3_tapi_ex.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,10 +34,14 @@ */ +// +// Prototype BLAS-like interfaces with typed operands (expert). +// + #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ trans_t transa, \ trans_t transb, \ @@ -52,18 +57,12 @@ void PASTEMAC(ch,opname) \ rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( gemm3mh ) -INSERT_GENTPROT_BASIC0( gemm3m1 ) -INSERT_GENTPROT_BASIC0( gemm4mh ) -INSERT_GENTPROT_BASIC0( gemm4mb ) -INSERT_GENTPROT_BASIC0( gemm4m1 ) -INSERT_GENTPROT_BASIC0( gemm1m ) - +INSERT_GENTPROT_BASIC0( gemm ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ side_t side, \ uplo_t uploa, \ @@ -80,144 +79,99 @@ void PASTEMAC(ch,opname) \ rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( hemm3mh ) -INSERT_GENTPROT_BASIC0( hemm3m1 ) -INSERT_GENTPROT_BASIC0( hemm4mh ) -INSERT_GENTPROT_BASIC0( hemm4m1 ) -INSERT_GENTPROT_BASIC0( hemm1m ) +INSERT_GENTPROT_BASIC0( hemm ) +INSERT_GENTPROT_BASIC0( symm ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ uplo_t uploc, \ trans_t transa, \ - trans_t transb, \ dim_t m, \ dim_t k, \ - ctype* alpha, \ + ctype_r* alpha, \ ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ ctype_r* beta, \ ctype* c, inc_t rs_c, inc_t cs_c, \ cntx_t* cntx, \ - rntm_t* rntmx \ + rntm_t* rntm \ ); -INSERT_GENTPROTR_BASIC0( her2k3mh ) -INSERT_GENTPROTR_BASIC0( her2k3m1 ) -INSERT_GENTPROTR_BASIC0( her2k4mh ) -INSERT_GENTPROTR_BASIC0( her2k4m1 ) -INSERT_GENTPROTR_BASIC0( her2k1m ) +INSERT_GENTPROTR_BASIC0( herk ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ uplo_t uploc, \ trans_t transa, \ + trans_t transb, \ dim_t m, \ dim_t k, \ - ctype_r* alpha, \ + ctype* alpha, \ ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ ctype_r* beta, \ ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntmx \ - ); - -INSERT_GENTPROTR_BASIC0( herk3mh ) -INSERT_GENTPROTR_BASIC0( herk3m1 ) -INSERT_GENTPROTR_BASIC0( herk4mh ) -INSERT_GENTPROTR_BASIC0( herk4m1 ) -INSERT_GENTPROTR_BASIC0( herk1m ) - - -#undef GENTPROT -#define GENTPROT( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - conj_t conja, \ - trans_t transb, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ + cntx_t* cntx, \ + rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( symm3mh ) -INSERT_GENTPROT_BASIC0( symm3m1 ) -INSERT_GENTPROT_BASIC0( symm4mh ) -INSERT_GENTPROT_BASIC0( symm4m1 ) -INSERT_GENTPROT_BASIC0( symm1m ) +INSERT_GENTPROTR_BASIC0( her2k ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ uplo_t uploc, \ trans_t transa, \ - trans_t transb, \ dim_t m, \ dim_t k, \ ctype* alpha, \ ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c, \ cntx_t* cntx, \ rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( syr2k3mh ) -INSERT_GENTPROT_BASIC0( syr2k3m1 ) -INSERT_GENTPROT_BASIC0( syr2k4mh ) -INSERT_GENTPROT_BASIC0( syr2k4m1 ) -INSERT_GENTPROT_BASIC0( syr2k1m ) +INSERT_GENTPROT_BASIC0( syrk ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ uplo_t uploc, \ trans_t transa, \ + trans_t transb, \ dim_t m, \ dim_t k, \ ctype* alpha, \ ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c, \ cntx_t* cntx, \ rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( syrk3mh ) -INSERT_GENTPROT_BASIC0( syrk3m1 ) -INSERT_GENTPROT_BASIC0( syrk4mh ) -INSERT_GENTPROT_BASIC0( syrk4m1 ) -INSERT_GENTPROT_BASIC0( syrk1m ) +INSERT_GENTPROT_BASIC0( gemmt ) +INSERT_GENTPROT_BASIC0( syr2k ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ side_t side, \ uplo_t uploa, \ @@ -235,40 +189,13 @@ void PASTEMAC(ch,opname) \ rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( trmm33mh ) -INSERT_GENTPROT_BASIC0( trmm33m1 ) -INSERT_GENTPROT_BASIC0( trmm34mh ) -INSERT_GENTPROT_BASIC0( trmm34m1 ) -INSERT_GENTPROT_BASIC0( trmm31m ) - - -#undef GENTPROT -#define GENTPROT( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - trans_t transa, \ - diag_t diaga, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ); - -INSERT_GENTPROT_BASIC0( trmm3m1 ) -INSERT_GENTPROT_BASIC0( trmm4m1 ) -INSERT_GENTPROT_BASIC0( trmm1m ) +INSERT_GENTPROT_BASIC0( trmm3 ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ side_t side, \ uplo_t uploa, \ @@ -283,7 +210,6 @@ void PASTEMAC(ch,opname) \ rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( trsm3m1 ) -INSERT_GENTPROT_BASIC0( trsm4m1 ) -INSERT_GENTPROT_BASIC0( trsm1m ) +INSERT_GENTPROT_BASIC0( trmm ) +INSERT_GENTPROT_BASIC0( trsm ) diff --git a/frame/3/bli_l3_thrinfo.c b/frame/3/bli_l3_thrinfo.c index 1d876d50f1..f866cfd4c5 100644 --- a/frame/3/bli_l3_thrinfo.c +++ b/frame/3/bli_l3_thrinfo.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -53,6 +53,15 @@ void bli_l3_thrinfo_free bli_thrinfo_free( rntm, thread ); } +void bli_l3_sup_thrinfo_free + ( + rntm_t* rntm, + thrinfo_t* thread + ) +{ + bli_thrinfo_free( rntm, thread ); +} + // ----------------------------------------------------------------------------- void bli_l3_thrinfo_create_root @@ -94,40 +103,157 @@ void bli_l3_thrinfo_create_root // ----------------------------------------------------------------------------- +void bli_l3_sup_thrinfo_create_root + ( + dim_t id, + thrcomm_t* gl_comm, + rntm_t* rntm, + thrinfo_t** thread + ) +{ + // Query the global communicator for the total number of threads to use. + dim_t n_threads = bli_thrcomm_num_threads( gl_comm ); + + // Use the thread id passed in as the global communicator id. + dim_t gl_comm_id = id; + + // Use the BLIS_NC blocksize id to query the top-most ways of parallelism + // to obtain. Note that hard-coding BLIS_NC like this is a little bit of a + // hack, but it works fine since both of the sup algorithms (bp and pb) use + // the cache blocksizes down to the 3rd loop. (See the definitions of + // bli_rntm_calc_num_threads_bp() and bli_rntm_calc_num_threads_pb() for + // a concise enumeration of these bszid_t ids.) + const bszid_t bszid = BLIS_NC; + dim_t xx_way = bli_rntm_ways_for( BLIS_NC, rntm ); + + // Determine the work id for this thrinfo_t node. + dim_t work_id = gl_comm_id / ( n_threads / xx_way ); + + // Create the root thrinfo_t node. + *thread = bli_thrinfo_create + ( + rntm, + gl_comm, + gl_comm_id, + xx_way, + work_id, + TRUE, + bszid, + NULL + ); +} + +// ----------------------------------------------------------------------------- + +void bli_l3_sup_thrinfo_update_root + ( + rntm_t* rntm, + thrinfo_t* thread + ) +{ + // Query the current root for the total number of threads to use. + const dim_t n_threads = bli_thread_num_threads( thread ); + + // Query the current root for the (global) comm id. + const dim_t gl_comm_id = bli_thread_ocomm_id( thread ); + + // Query the rntm_t for the updated number of ways of parallelism. + const dim_t xx_way = bli_rntm_ways_for( BLIS_NC, rntm ); + + // Recompute the work id for this thrinfo_t node using the updated + // number of ways of parallelism. + dim_t work_id = gl_comm_id / ( n_threads / xx_way ); + + // Save the updated ways of parallelism and work id to the thrinfo_t node. + bli_thrinfo_set_n_way( xx_way, thread ); + bli_thrinfo_set_work_id( work_id, thread ); +} + +// ----------------------------------------------------------------------------- + void bli_l3_thrinfo_print_gemm_paths ( thrinfo_t** threads ) { + // In order to query the number of threads, we query the only thread we + // know exists: thread 0. dim_t n_threads = bli_thread_num_threads( threads[0] ); - dim_t gl_id; - - thrinfo_t* jc_info = threads[0]; - thrinfo_t* pc_info = bli_thrinfo_sub_node( jc_info ); - thrinfo_t* pb_info = bli_thrinfo_sub_node( pc_info ); - thrinfo_t* ic_info = bli_thrinfo_sub_node( pb_info ); - thrinfo_t* pa_info = bli_thrinfo_sub_node( ic_info ); - thrinfo_t* jr_info = bli_thrinfo_sub_node( pa_info ); - thrinfo_t* ir_info = bli_thrinfo_sub_node( jr_info ); - - dim_t jc_way = bli_thread_n_way( jc_info ); - dim_t pc_way = bli_thread_n_way( pc_info ); - dim_t pb_way = bli_thread_n_way( pb_info ); - dim_t ic_way = bli_thread_n_way( ic_info ); - dim_t pa_way = bli_thread_n_way( pa_info ); - dim_t jr_way = bli_thread_n_way( jr_info ); - dim_t ir_way = bli_thread_n_way( ir_info ); - - dim_t jc_nt = bli_thread_num_threads( jc_info ); - dim_t pc_nt = bli_thread_num_threads( pc_info ); - dim_t pb_nt = bli_thread_num_threads( pb_info ); - dim_t ic_nt = bli_thread_num_threads( ic_info ); - dim_t pa_nt = bli_thread_num_threads( pa_info ); - dim_t jr_nt = bli_thread_num_threads( jr_info ); - dim_t ir_nt = bli_thread_num_threads( ir_info ); + + // For the purposes of printing the "header" information that is common + // to the various instances of a thrinfo_t (ie: across all threads), we + // choose the last thread in case the problem is so small that there is + // only an "edge" case, which will always be assigned to the last thread + // (at least for higher levels of partitioning). + thrinfo_t* jc_info = threads[n_threads-1]; + thrinfo_t* pc_info = NULL; + thrinfo_t* pb_info = NULL; + thrinfo_t* ic_info = NULL; + thrinfo_t* pa_info = NULL; + thrinfo_t* jr_info = NULL; + thrinfo_t* ir_info = NULL; + + // Initialize the n_ways and n_threads fields of each thrinfo_t "level" + // to -1. More than likely, these will all be overwritten with meaningful + // values, but in case some thrinfo_t trees are not fully built (see + // next commnet), these will be the placeholder values. + dim_t jc_way = -1, pc_way = -1, pb_way = -1, ic_way = -1, + pa_way = -1, jr_way = -1, ir_way = -1; + + dim_t jc_nt = -1, pc_nt = -1, pb_nt = -1, ic_nt = -1, + pa_nt = -1, jr_nt = -1, ir_nt = -1; + + // NOTE: We must check each thrinfo_t pointer for NULLness. Certain threads + // may not fully build their thrinfo_t structures--specifically when the + // dimension being parallelized is not large enough for each thread to have + // even one unit of work (where as unit is usually a single micropanel's + // width, MR or NR). + + if ( !jc_info ) goto print_header; + + jc_way = bli_thread_n_way( jc_info ); + jc_nt = bli_thread_num_threads( jc_info ); + pc_info = bli_thrinfo_sub_node( jc_info ); + + if ( !pc_info ) goto print_header; + + pc_way = bli_thread_n_way( pc_info ); + pc_nt = bli_thread_num_threads( pc_info ); + pb_info = bli_thrinfo_sub_node( pc_info ); + + if ( !pb_info ) goto print_header; + + pb_way = bli_thread_n_way( pb_info ); + pb_nt = bli_thread_num_threads( pb_info ); + ic_info = bli_thrinfo_sub_node( pb_info ); + + if ( !ic_info ) goto print_header; + + ic_way = bli_thread_n_way( ic_info ); + ic_nt = bli_thread_num_threads( ic_info ); + pa_info = bli_thrinfo_sub_node( ic_info ); + + if ( !pa_info ) goto print_header; + + pa_way = bli_thread_n_way( pa_info ); + pa_nt = bli_thread_num_threads( pa_info ); + jr_info = bli_thrinfo_sub_node( pa_info ); + + if ( !jr_info ) goto print_header; + + jr_way = bli_thread_n_way( jr_info ); + jr_nt = bli_thread_num_threads( jr_info ); + ir_info = bli_thrinfo_sub_node( jr_info ); + + if ( !ir_info ) goto print_header; + + ir_way = bli_thread_n_way( ir_info ); + ir_nt = bli_thread_num_threads( ir_info ); + + print_header: printf( " jc kc pb ic pa jr ir\n" ); - printf( "xx_nt: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n", + printf( "xx_nt: %4ld %4ld %4ld %4ld %4ld %4ld %4ld\n", ( unsigned long )jc_nt, ( unsigned long )pc_nt, ( unsigned long )pb_nt, @@ -135,7 +261,7 @@ void bli_l3_thrinfo_print_gemm_paths ( unsigned long )pa_nt, ( unsigned long )jr_nt, ( unsigned long )ir_nt ); - printf( "xx_way: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n", + printf( "xx_way: %4ld %4ld %4ld %4ld %4ld %4ld %4ld\n", ( unsigned long )jc_way, ( unsigned long )pc_way, ( unsigned long )pb_way, @@ -145,116 +271,59 @@ void bli_l3_thrinfo_print_gemm_paths ( unsigned long )ir_way ); printf( "============================================\n" ); - dim_t jc_comm_id; - dim_t pc_comm_id; - dim_t pb_comm_id; - dim_t ic_comm_id; - dim_t pa_comm_id; - dim_t jr_comm_id; - dim_t ir_comm_id; - - dim_t jc_work_id; - dim_t pc_work_id; - dim_t pb_work_id; - dim_t ic_work_id; - dim_t pa_work_id; - dim_t jr_work_id; - dim_t ir_work_id; - - for ( gl_id = 0; gl_id < n_threads; ++gl_id ) + for ( dim_t gl_id = 0; gl_id < n_threads; ++gl_id ) { jc_info = threads[gl_id]; - // NOTE: We must check each thrinfo_t pointer for NULLness. Certain threads - // may not fully build their thrinfo_t structures--specifically when the - // dimension being parallelized is not large enough for each thread to have - // even one unit of work (where as unit is usually a single micropanel's - // width, MR or NR). - if ( !jc_info ) - { - jc_comm_id = pc_comm_id = pb_comm_id = ic_comm_id = pa_comm_id = jr_comm_id = ir_comm_id = -1; - jc_work_id = pc_work_id = pb_work_id = ic_work_id = pa_work_id = jr_work_id = ir_work_id = -1; - } - else - { - jc_comm_id = bli_thread_ocomm_id( jc_info ); - jc_work_id = bli_thread_work_id( jc_info ); - pc_info = bli_thrinfo_sub_node( jc_info ); + dim_t jc_comm_id = -1, pc_comm_id = -1, pb_comm_id = -1, ic_comm_id = -1, + pa_comm_id = -1, jr_comm_id = -1, ir_comm_id = -1; - if ( !pc_info ) - { - pc_comm_id = pb_comm_id = ic_comm_id = pa_comm_id = jr_comm_id = ir_comm_id = -1; - pc_work_id = pb_work_id = ic_work_id = pa_work_id = jr_work_id = ir_work_id = -1; - } - else - { - pc_comm_id = bli_thread_ocomm_id( pc_info ); - pc_work_id = bli_thread_work_id( pc_info ); - pb_info = bli_thrinfo_sub_node( pc_info ); + dim_t jc_work_id = -1, pc_work_id = -1, pb_work_id = -1, ic_work_id = -1, + pa_work_id = -1, jr_work_id = -1, ir_work_id = -1; - if ( !pb_info ) - { - pb_comm_id = ic_comm_id = pa_comm_id = jr_comm_id = ir_comm_id = -1; - pb_work_id = ic_work_id = pa_work_id = jr_work_id = ir_work_id = -1; - } - else - { - pb_comm_id = bli_thread_ocomm_id( pb_info ); - pb_work_id = bli_thread_work_id( pb_info ); - ic_info = bli_thrinfo_sub_node( pb_info ); + if ( !jc_info ) goto print_thrinfo; - if ( !ic_info ) - { - ic_comm_id = pa_comm_id = jr_comm_id = ir_comm_id = -1; - ic_work_id = pa_work_id = jr_work_id = ir_work_id = -1; - } - else - { - ic_comm_id = bli_thread_ocomm_id( ic_info ); - ic_work_id = bli_thread_work_id( ic_info ); - pa_info = bli_thrinfo_sub_node( ic_info ); + jc_comm_id = bli_thread_ocomm_id( jc_info ); + jc_work_id = bli_thread_work_id( jc_info ); + pc_info = bli_thrinfo_sub_node( jc_info ); - if ( !pa_info ) - { - pa_comm_id = jr_comm_id = ir_comm_id = -1; - pa_work_id = jr_work_id = ir_work_id = -1; - } - else - { - pa_comm_id = bli_thread_ocomm_id( pa_info ); - pa_work_id = bli_thread_work_id( pa_info ); - jr_info = bli_thrinfo_sub_node( pa_info ); + if ( !pc_info ) goto print_thrinfo; - if ( !jr_info ) - { - jr_comm_id = ir_comm_id = -1; - jr_work_id = ir_work_id = -1; - } - else - { - jr_comm_id = bli_thread_ocomm_id( jr_info ); - jr_work_id = bli_thread_work_id( jr_info ); - ir_info = bli_thrinfo_sub_node( jr_info ); + pc_comm_id = bli_thread_ocomm_id( pc_info ); + pc_work_id = bli_thread_work_id( pc_info ); + pb_info = bli_thrinfo_sub_node( pc_info ); - if ( !ir_info ) - { - ir_comm_id = -1; - ir_work_id = -1; - } - else - { - ir_comm_id = bli_thread_ocomm_id( ir_info ); - ir_work_id = bli_thread_work_id( ir_info ); - } - } - } - } - } - } - } + if ( !pb_info ) goto print_thrinfo; + + pb_comm_id = bli_thread_ocomm_id( pb_info ); + pb_work_id = bli_thread_work_id( pb_info ); + ic_info = bli_thrinfo_sub_node( pb_info ); + + if ( !ic_info ) goto print_thrinfo; + + ic_comm_id = bli_thread_ocomm_id( ic_info ); + ic_work_id = bli_thread_work_id( ic_info ); + pa_info = bli_thrinfo_sub_node( ic_info ); + + if ( !pa_info ) goto print_thrinfo; + + pa_comm_id = bli_thread_ocomm_id( pa_info ); + pa_work_id = bli_thread_work_id( pa_info ); + jr_info = bli_thrinfo_sub_node( pa_info ); + + if ( !jr_info ) goto print_thrinfo; + + jr_comm_id = bli_thread_ocomm_id( jr_info ); + jr_work_id = bli_thread_work_id( jr_info ); + ir_info = bli_thrinfo_sub_node( jr_info ); + + if ( !ir_info ) goto print_thrinfo; + + ir_comm_id = bli_thread_ocomm_id( ir_info ); + ir_work_id = bli_thread_work_id( ir_info ); + + print_thrinfo: - //printf( " gl jc pb kc pa ic jr \n" ); - //printf( " gl jc kc pb ic pa jr \n" ); printf( "comm ids: %4ld %4ld %4ld %4ld %4ld %4ld %4ld\n", ( long )jc_comm_id, ( long )pc_comm_id, @@ -285,44 +354,105 @@ void bli_l3_thrinfo_print_trsm_paths thrinfo_t** threads ) { + // In order to query the number of threads, we query the only thread we + // know exists: thread 0. dim_t n_threads = bli_thread_num_threads( threads[0] ); - dim_t gl_id; - - thrinfo_t* jc_info = threads[0]; - thrinfo_t* pc_info = bli_thrinfo_sub_node( jc_info ); - thrinfo_t* pb_info = bli_thrinfo_sub_node( pc_info ); - thrinfo_t* ic_info = bli_thrinfo_sub_node( pb_info ); - - thrinfo_t* pa_info = bli_thrinfo_sub_node( ic_info ); - thrinfo_t* jr_info = bli_thrinfo_sub_node( pa_info ); - thrinfo_t* ir_info = bli_thrinfo_sub_node( jr_info ); - thrinfo_t* pa_info0 = bli_thrinfo_sub_prenode( ic_info ); - thrinfo_t* jr_info0 = ( pa_info0 ? bli_thrinfo_sub_node( pa_info0 ) : NULL ); - thrinfo_t* ir_info0 = ( jr_info0 ? bli_thrinfo_sub_node( jr_info0 ) : NULL ); - - dim_t jc_way = bli_thread_n_way( jc_info ); - dim_t pc_way = bli_thread_n_way( pc_info ); - dim_t pb_way = bli_thread_n_way( pb_info ); - dim_t ic_way = bli_thread_n_way( ic_info ); - - dim_t pa_way = bli_thread_n_way( pa_info ); - dim_t jr_way = bli_thread_n_way( jr_info ); - dim_t ir_way = bli_thread_n_way( ir_info ); - dim_t pa_way0 = ( pa_info0 ? bli_thread_n_way( pa_info0 ) : -1 ); - dim_t jr_way0 = ( jr_info0 ? bli_thread_n_way( jr_info0 ) : -1 ); - dim_t ir_way0 = ( ir_info0 ? bli_thread_n_way( ir_info0 ) : -1 ); - - dim_t jc_nt = bli_thread_num_threads( jc_info ); - dim_t pc_nt = bli_thread_num_threads( pc_info ); - dim_t pb_nt = bli_thread_num_threads( pb_info ); - dim_t ic_nt = bli_thread_num_threads( ic_info ); - - dim_t pa_nt = bli_thread_num_threads( pa_info ); - dim_t jr_nt = bli_thread_num_threads( jr_info ); - dim_t ir_nt = bli_thread_num_threads( ir_info ); - dim_t pa_nt0 = ( pa_info0 ? bli_thread_num_threads( pa_info0 ) : -1 ); - dim_t jr_nt0 = ( jr_info0 ? bli_thread_num_threads( jr_info0 ) : -1 ); - dim_t ir_nt0 = ( ir_info0 ? bli_thread_num_threads( ir_info0 ) : -1 ); + + // For the purposes of printing the "header" information that is common + // to the various instances of a thrinfo_t (ie: across all threads), we + // choose the last thread in case the problem is so small that there is + // only an "edge" case, which will always be assigned to the last thread + // (at least for higher levels of partitioning). + thrinfo_t* jc_info = threads[n_threads-1]; + thrinfo_t* pc_info = NULL; + thrinfo_t* pb_info = NULL; + thrinfo_t* ic_info = NULL; + thrinfo_t* pa_info = NULL; thrinfo_t* pa_info0 = NULL; + thrinfo_t* jr_info = NULL; thrinfo_t* jr_info0 = NULL; + thrinfo_t* ir_info = NULL; thrinfo_t* ir_info0 = NULL; + + // Initialize the n_ways and n_threads fields of each thrinfo_t "level" + // to -1. More than likely, these will all be overwritten with meaningful + // values, but in case some thrinfo_t trees are not fully built (see + // next commnet), these will be the placeholder values. + dim_t jc_way = -1, pc_way = -1, pb_way = -1, ic_way = -1, + pa_way = -1, jr_way = -1, ir_way = -1, + pa_way0 = -1, jr_way0 = -1, ir_way0 = -1; + + dim_t jc_nt = -1, pc_nt = -1, pb_nt = -1, ic_nt = -1, + pa_nt = -1, jr_nt = -1, ir_nt = -1, + pa_nt0 = -1, jr_nt0 = -1, ir_nt0 = -1; + + // NOTE: We must check each thrinfo_t pointer for NULLness. Certain threads + // may not fully build their thrinfo_t structures--specifically when the + // dimension being parallelized is not large enough for each thread to have + // even one unit of work (where as unit is usually a single micropanel's + // width, MR or NR). + + if ( !jc_info ) goto print_header; + + jc_way = bli_thread_n_way( jc_info ); + jc_nt = bli_thread_num_threads( jc_info ); + pc_info = bli_thrinfo_sub_node( jc_info ); + + if ( !pc_info ) goto print_header; + + pc_way = bli_thread_n_way( pc_info ); + pc_nt = bli_thread_num_threads( pc_info ); + pb_info = bli_thrinfo_sub_node( pc_info ); + + if ( !pb_info ) goto print_header; + + pb_way = bli_thread_n_way( pb_info ); + pb_nt = bli_thread_num_threads( pb_info ); + ic_info = bli_thrinfo_sub_node( pb_info ); + + if ( !ic_info ) goto print_header; + + ic_way = bli_thread_n_way( ic_info ); + ic_nt = bli_thread_num_threads( ic_info ); + pa_info = bli_thrinfo_sub_node( ic_info ); + pa_info0 = bli_thrinfo_sub_prenode( ic_info ); + + // check_header_prenode: + + if ( !pa_info0 ) goto check_header_node; + + pa_way0 = bli_thread_n_way( pa_info0 ); + pa_nt0 = bli_thread_num_threads( pa_info0 ); + jr_info0 = bli_thrinfo_sub_node( pa_info0 ); + + if ( !jr_info0 ) goto check_header_node; + + jr_way0 = bli_thread_n_way( jr_info0 ); + jr_nt0 = bli_thread_num_threads( jr_info0 ); + ir_info0 = bli_thrinfo_sub_node( jr_info0 ); + + if ( !ir_info0 ) goto check_header_node; + + ir_way0 = bli_thread_n_way( ir_info0 ); + ir_nt0 = bli_thread_num_threads( ir_info0 ); + + check_header_node: + + if ( !pa_info ) goto print_header; + + pa_way = bli_thread_n_way( pa_info ); + pa_nt = bli_thread_num_threads( pa_info ); + jr_info = bli_thrinfo_sub_node( pa_info ); + + if ( !jr_info ) goto print_header; + + jr_way = bli_thread_n_way( jr_info ); + jr_nt = bli_thread_num_threads( jr_info ); + ir_info = bli_thrinfo_sub_node( jr_info ); + + if ( !ir_info ) goto print_header; + + ir_way = bli_thread_n_way( ir_info ); + ir_nt = bli_thread_num_threads( ir_info ); + + print_header: printf( " jc kc pb ic pa jr ir\n" ); printf( "xx_nt: %4ld %4ld %4ld %4ld %2ld|%2ld %2ld|%2ld %2ld|%2ld\n", @@ -343,26 +473,105 @@ void bli_l3_thrinfo_print_trsm_paths ( long )ir_way0, ( long )ir_way ); printf( "==================================================\n" ); - dim_t jc_comm_id; - dim_t pc_comm_id; - dim_t pb_comm_id; - dim_t ic_comm_id; - dim_t pa_comm_id0, pa_comm_id; - dim_t jr_comm_id0, jr_comm_id; - dim_t ir_comm_id0, ir_comm_id; - - dim_t jc_work_id; - dim_t pc_work_id; - dim_t pb_work_id; - dim_t ic_work_id; - dim_t pa_work_id0, pa_work_id; - dim_t jr_work_id0, jr_work_id; - dim_t ir_work_id0, ir_work_id; - - for ( gl_id = 0; gl_id < n_threads; ++gl_id ) + + for ( dim_t gl_id = 0; gl_id < n_threads; ++gl_id ) { jc_info = threads[gl_id]; +#if 1 + // NOTE: This cpp branch contains code that is safe to execute + // for small problems that are parallelized enough that one or + // more threads gets no work. + + dim_t jc_comm_id = -1, pc_comm_id = -1, pb_comm_id = -1, ic_comm_id = -1, + pa_comm_id = -1, jr_comm_id = -1, ir_comm_id = -1, + pa_comm_id0 = -1, jr_comm_id0 = -1, ir_comm_id0 = -1; + + dim_t jc_work_id = -1, pc_work_id = -1, pb_work_id = -1, ic_work_id = -1, + pa_work_id = -1, jr_work_id = -1, ir_work_id = -1, + pa_work_id0 = -1, jr_work_id0 = -1, ir_work_id0 = -1; + + if ( !jc_info ) goto print_thrinfo; + + jc_comm_id = bli_thread_ocomm_id( jc_info ); + jc_work_id = bli_thread_work_id( jc_info ); + pc_info = bli_thrinfo_sub_node( jc_info ); + + if ( !pc_info ) goto print_thrinfo; + + pc_comm_id = bli_thread_ocomm_id( pc_info ); + pc_work_id = bli_thread_work_id( pc_info ); + pb_info = bli_thrinfo_sub_node( pc_info ); + + if ( !pb_info ) goto print_thrinfo; + + pb_comm_id = bli_thread_ocomm_id( pb_info ); + pb_work_id = bli_thread_work_id( pb_info ); + ic_info = bli_thrinfo_sub_node( pb_info ); + + if ( !ic_info ) goto print_thrinfo; + + ic_comm_id = bli_thread_ocomm_id( ic_info ); + ic_work_id = bli_thread_work_id( ic_info ); + pa_info = bli_thrinfo_sub_node( ic_info ); + pa_info0 = bli_thrinfo_sub_prenode( ic_info ); + + // check_thrinfo_prenode: + + if ( !pa_info0 ) goto check_thrinfo_node; + + pa_comm_id0 = bli_thread_ocomm_id( pa_info0 ); + pa_work_id0 = bli_thread_work_id( pa_info0 ); + jr_info0 = bli_thrinfo_sub_node( pa_info0 ); + + if ( !jr_info0 ) goto check_thrinfo_node; + + jr_comm_id0 = bli_thread_ocomm_id( jr_info0 ); + jr_work_id0 = bli_thread_work_id( jr_info0 ); + ir_info0 = bli_thrinfo_sub_node( jr_info0 ); + + if ( !ir_info0 ) goto check_thrinfo_node; + + ir_comm_id0 = bli_thread_ocomm_id( ir_info0 ); + ir_work_id0 = bli_thread_work_id( ir_info0 ); + + check_thrinfo_node: + + if ( !pa_info ) goto print_thrinfo; + + pa_comm_id = bli_thread_ocomm_id( pa_info ); + pa_work_id = bli_thread_work_id( pa_info ); + jr_info = bli_thrinfo_sub_node( pa_info ); + + if ( !jr_info ) goto print_thrinfo; + + jr_comm_id = bli_thread_ocomm_id( jr_info ); + jr_work_id = bli_thread_work_id( jr_info ); + ir_info = bli_thrinfo_sub_node( jr_info ); + + if ( !ir_info ) goto print_thrinfo; + + ir_comm_id = bli_thread_ocomm_id( ir_info ); + ir_work_id = bli_thread_work_id( ir_info ); + + print_thrinfo: +#else + dim_t jc_comm_id; + dim_t pc_comm_id; + dim_t pb_comm_id; + dim_t ic_comm_id; + dim_t pa_comm_id0, pa_comm_id; + dim_t jr_comm_id0, jr_comm_id; + dim_t ir_comm_id0, ir_comm_id; + + dim_t jc_work_id; + dim_t pc_work_id; + dim_t pb_work_id; + dim_t ic_work_id; + dim_t pa_work_id0, pa_work_id; + dim_t jr_work_id0, jr_work_id; + dim_t ir_work_id0, ir_work_id; + // NOTE: We must check each thrinfo_t pointer for NULLness. Certain threads // may not fully build their thrinfo_t structures--specifically when the // dimension being parallelized is not large enough for each thread to have @@ -488,6 +697,7 @@ void bli_l3_thrinfo_print_trsm_paths } } } +#endif printf( "comm ids: %4ld %4ld %4ld %4ld %2ld|%2ld %2ld|%2ld %2ld|%2ld\n", ( long )jc_comm_id, diff --git a/frame/3/bli_l3_thrinfo.h b/frame/3/bli_l3_thrinfo.h index 15d8faed60..37a3909fd6 100644 --- a/frame/3/bli_l3_thrinfo.h +++ b/frame/3/bli_l3_thrinfo.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -44,12 +44,12 @@ #define bli_gemm_get_next_a_upanel( a1, step, inc ) ( a1 + step * inc ) #define bli_gemm_get_next_b_upanel( b1, step, inc ) ( b1 + step * inc ) -// herk +// gemmt -// NOTE: The definition of bli_herk_get_next_?_upanel() does not need to +// NOTE: The definition of bli_gemmt_get_next_?_upanel() does not need to // change depending on BLIS_ENABLE_JRIR_SLAB / BLIS_ENABLE_JRIR_RR. -#define bli_herk_get_next_a_upanel( a1, step, inc ) ( a1 + step * inc ) -#define bli_herk_get_next_b_upanel( b1, step, inc ) ( b1 + step * inc ) +#define bli_gemmt_get_next_a_upanel( a1, step, inc ) ( a1 + step * inc ) +#define bli_gemmt_get_next_b_upanel( b1, step, inc ) ( b1 + step * inc ) // trmm @@ -93,6 +93,12 @@ void bli_l3_thrinfo_free thrinfo_t* thread ); +void bli_l3_sup_thrinfo_free + ( + rntm_t* rntm, + thrinfo_t* thread + ); + // ----------------------------------------------------------------------------- void bli_l3_thrinfo_create_root @@ -104,6 +110,20 @@ void bli_l3_thrinfo_create_root thrinfo_t** thread ); +void bli_l3_sup_thrinfo_create_root + ( + dim_t id, + thrcomm_t* gl_comm, + rntm_t* rntm, + thrinfo_t** thread + ); + +void bli_l3_sup_thrinfo_update_root + ( + rntm_t* rntm, + thrinfo_t* thread + ); + void bli_l3_thrinfo_print_gemm_paths ( thrinfo_t** threads diff --git a/frame/3/bli_l3_ukr_oapi.c b/frame/3/bli_l3_ukr_oapi.c index a8191b1aa8..e500bab713 100644 --- a/frame/3/bli_l3_ukr_oapi.c +++ b/frame/3/bli_l3_ukr_oapi.c @@ -51,6 +51,8 @@ void PASTEMAC0(opname) \ \ num_t dt = bli_obj_dt( c ); \ \ + dim_t m = bli_obj_length( c ); \ + dim_t n = bli_obj_width( c ); \ dim_t k = bli_obj_width( a ); \ void* buf_a = bli_obj_buffer_at_off( a ); \ void* buf_b = bli_obj_buffer_at_off( b ); \ @@ -69,12 +71,14 @@ void PASTEMAC0(opname) \ bli_auxinfo_set_is_b( 1, &data ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(tname,_ukr,_vft) f = \ PASTEMAC(opname,_qfp)( dt ); \ \ f \ ( \ + m, \ + n, \ k, \ buf_alpha, \ buf_a, \ @@ -107,6 +111,8 @@ void PASTEMAC0(opname) \ \ num_t dt = bli_obj_dt( c11 ); \ \ + dim_t m = bli_obj_length( c11 ); \ + dim_t n = bli_obj_width( c11 ); \ dim_t k = bli_obj_width( a1x ); \ void* buf_a1x = bli_obj_buffer_at_off( a1x ); \ void* buf_a11 = bli_obj_buffer_at_off( a11 ); \ @@ -130,12 +136,14 @@ void PASTEMAC0(opname) \ if ( bli_obj_is_lower( a11 ) ) \ { \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(tname,_ukr,_vft) f = \ PASTEMAC(opnamel,_qfp)( dt ); \ \ f \ ( \ + m, \ + n, \ k, \ buf_alpha, \ buf_a1x, \ @@ -150,12 +158,14 @@ void PASTEMAC0(opname) \ else /* if ( bli_obj_is_upper( a11 ) ) */ \ { \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(tname,_ukr,_vft) f = \ PASTEMAC(opnameu,_qfp)( dt ); \ \ f \ ( \ + m, \ + n, \ k, \ buf_alpha, \ buf_a1x, \ @@ -205,7 +215,7 @@ void PASTEMAC0(opname) \ if ( bli_obj_is_lower( a ) ) \ { \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(tname,_ukr,_vft) f = \ PASTEMAC(opnamel,_qfp)( dt ); \ \ @@ -221,7 +231,7 @@ void PASTEMAC0(opname) \ else /* if ( bli_obj_is_upper( a ) ) */ \ { \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(tname,_ukr,_vft) f = \ PASTEMAC(opnameu,_qfp)( dt ); \ \ diff --git a/frame/3/bli_l3_ukr_prot.h b/frame/3/bli_l3_ukr_prot.h index 80733897b8..677afc0202 100644 --- a/frame/3/bli_l3_ukr_prot.h +++ b/frame/3/bli_l3_ukr_prot.h @@ -36,16 +36,20 @@ // Define template prototypes for level-3 micro-kernels. // -#define GEMM_UKR_PROT( ctype, ch, opname ) \ +#define GEMM_UKR_PROT( ctype, ch, opname ) GEMM_UKR_PROT2(ctype, ctype, ch, opname) + +#define GEMM_UKR_PROT2( ctype_in, ctype_out, ch, opname ) \ \ void PASTEMAC(ch,opname) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict beta, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype_out* restrict alpha, \ + ctype_in* restrict a, \ + ctype_in* restrict b, \ + ctype_out* restrict beta, \ + ctype_out* restrict c, inc_t rs_c, inc_t cs_c, \ auxinfo_t* restrict data, \ cntx_t* restrict cntx \ ); @@ -55,6 +59,8 @@ void PASTEMAC(ch,opname) \ \ void PASTEMAC(ch,opname) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a1x, \ diff --git a/frame/3/bli_l3_ukr_tapi.c b/frame/3/bli_l3_ukr_tapi.c index 67e33175b7..56eaf3f4ce 100644 --- a/frame/3/bli_l3_ukr_tapi.c +++ b/frame/3/bli_l3_ukr_tapi.c @@ -39,6 +39,8 @@ \ void PASTEMAC(ch,opname) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -58,16 +60,19 @@ void PASTEMAC(ch,opname) \ PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \ \ /* Invoke the typed function for the given datatype. */ \ - f( \ - k, \ - alpha, \ - a, \ - b, \ - beta, \ - c, rs_c, cs_c, \ - data, \ - cntx \ - ); \ + f \ + ( \ + m, \ + n, \ + k, \ + alpha, \ + a, \ + b, \ + beta, \ + c, rs_c, cs_c, \ + data, \ + cntx \ + ); \ } \ INSERT_GENTFUNC_BASIC2( gemm_ukernel, gemm, BLIS_GEMM_UKR ) @@ -78,6 +83,8 @@ INSERT_GENTFUNC_BASIC2( gemm_ukernel, gemm, BLIS_GEMM_UKR ) \ void PASTEMAC(ch,opname) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a1x, \ @@ -98,17 +105,20 @@ void PASTEMAC(ch,opname) \ PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \ \ /* Invoke the typed function for the given datatype. */ \ - f( \ - k, \ - alpha, \ - a1x, \ - a11, \ - bx1, \ - b11, \ - c11, rs_c, cs_c, \ - data, \ - cntx \ - ); \ + f \ + ( \ + m, \ + n, \ + k, \ + alpha, \ + a1x, \ + a11, \ + bx1, \ + b11, \ + c11, rs_c, cs_c, \ + data, \ + cntx \ + ); \ } \ INSERT_GENTFUNC_BASIC2( gemmtrsm_l_ukernel, gemmtrsm, BLIS_GEMMTRSM_L_UKR ) @@ -136,13 +146,14 @@ void PASTEMAC(ch,opname) \ PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \ \ /* Invoke the typed function for the given datatype. */ \ - f( \ - a, \ - b, \ - c, rs_c, cs_c, \ - data, \ - cntx \ - ); \ + f \ + ( \ + a, \ + b, \ + c, rs_c, cs_c, \ + data, \ + cntx \ + ); \ } \ INSERT_GENTFUNC_BASIC2( trsm_l_ukernel, trsm, BLIS_TRSM_L_UKR ) diff --git a/frame/3/gemm/bli_gemm.h b/frame/3/gemm/bli_gemm.h index a6f8b4e1e0..ddd88e1633 100644 --- a/frame/3/gemm/bli_gemm.h +++ b/frame/3/gemm/bli_gemm.h @@ -34,7 +34,6 @@ #include "bli_gemm_cntl.h" #include "bli_gemm_front.h" -#include "bli_gemm_int.h" #include "bli_gemm_var.h" diff --git a/frame/3/gemm/bli_gemm_blk_var1.c b/frame/3/gemm/bli_gemm_blk_var1.c index b537119016..de077e5adc 100644 --- a/frame/3/gemm/bli_gemm_blk_var1.c +++ b/frame/3/gemm/bli_gemm_blk_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -77,7 +77,7 @@ void bli_gemm_blk_var1 i, b_alg, c, &c1 ); // Perform gemm subproblem. - bli_gemm_int + bli_l3_int ( &BLIS_ONE, &a1, diff --git a/frame/3/gemm/bli_gemm_blk_var2.c b/frame/3/gemm/bli_gemm_blk_var2.c index cd5a833f60..53943e47cd 100644 --- a/frame/3/gemm/bli_gemm_blk_var2.c +++ b/frame/3/gemm/bli_gemm_blk_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -77,7 +77,7 @@ void bli_gemm_blk_var2 i, b_alg, c, &c1 ); // Perform gemm subproblem. - bli_gemm_int + bli_l3_int ( &BLIS_ONE, a, diff --git a/frame/3/gemm/bli_gemm_blk_var3.c b/frame/3/gemm/bli_gemm_blk_var3.c index 0c236f6d1f..28029777de 100644 --- a/frame/3/gemm/bli_gemm_blk_var3.c +++ b/frame/3/gemm/bli_gemm_blk_var3.c @@ -71,7 +71,7 @@ void bli_gemm_blk_var3 i, b_alg, b, &b1 ); // Perform gemm subproblem. - bli_gemm_int + bli_l3_int ( &BLIS_ONE, &a1, @@ -84,7 +84,7 @@ void bli_gemm_blk_var3 bli_thrinfo_sub_node( thread ) ); - bli_thread_obarrier( bli_thrinfo_sub_node( thread ) ); + bli_thread_barrier( bli_thrinfo_sub_node( thread ) ); // This variant executes multiple rank-k updates. Therefore, if the // internal beta scalar on matrix C is non-zero, we must use it @@ -93,7 +93,7 @@ void bli_gemm_blk_var3 // can simply overwrite the internal beta scalar with BLIS_ONE once // it has been used in the first iteration. However... - // Unlike variant 3 of gemm and herk, which reset the internal scalar + // Unlike variant 3 of gemm and gemmt, which reset the internal scalar // on C at the end of the first iteration so that subsequent iterations // do not erroneously apply beta more than once, it is important that // this behavior not be applied to trmm. That is because the order of diff --git a/frame/3/gemm/bli_gemm_cntl.c b/frame/3/gemm/bli_gemm_cntl.c index 67c71e798d..052c812a33 100644 --- a/frame/3/gemm/bli_gemm_cntl.c +++ b/frame/3/gemm/bli_gemm_cntl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,10 +40,11 @@ cntl_t* bli_gemm_cntl_create rntm_t* rntm, opid_t family, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { - return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b ); + return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b, ker ); } // ----------------------------------------------------------------------------- @@ -53,22 +54,21 @@ cntl_t* bli_gemmbp_cntl_create rntm_t* rntm, opid_t family, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { - void* macro_kernel_fp; - void* packa_fp; - void* packb_fp; + void_fp macro_kernel_fp; - // Use the function pointers to the macrokernels that use slab - // assignment of micropanels to threads in the jr and ir loops. + // Choose the default macrokernel based on the operation family... if ( family == BLIS_GEMM ) macro_kernel_fp = bli_gemm_ker_var2; - else if ( family == BLIS_HERK ) macro_kernel_fp = bli_herk_x_ker_var2; + else if ( family == BLIS_GEMMT ) macro_kernel_fp = bli_gemmt_x_ker_var2; else if ( family == BLIS_TRMM ) macro_kernel_fp = bli_trmm_xx_ker_var2; else /* should never execute */ macro_kernel_fp = NULL; - packa_fp = bli_packm_blk_var1; - packb_fp = bli_packm_blk_var1; + // ...unless a non-NULL kernel function pointer is passed in, in which + // case we use that instead. + if ( ker ) macro_kernel_fp = ker; // Create two nodes for the macro-kernel. cntl_t* gemm_cntl_bu_ke = bli_gemm_cntl_create_node @@ -93,8 +93,7 @@ cntl_t* bli_gemmbp_cntl_create cntl_t* gemm_cntl_packa = bli_packm_cntl_create_node ( rntm, - bli_gemm_packa, // pack the left-hand operand - packa_fp, + bli_l3_packa, // pack the left-hand operand BLIS_MR, BLIS_KR, FALSE, // do NOT invert diagonal @@ -119,10 +118,9 @@ cntl_t* bli_gemmbp_cntl_create cntl_t* gemm_cntl_packb = bli_packm_cntl_create_node ( rntm, - bli_gemm_packb, // pack the right-hand operand - packb_fp, - BLIS_KR, + bli_l3_packb, // pack the right-hand operand BLIS_NR, + BLIS_KR, FALSE, // do NOT invert diagonal FALSE, // reverse iteration if upper? FALSE, // reverse iteration if lower? @@ -165,10 +163,10 @@ cntl_t* bli_gemmpb_cntl_create opid_t family ) { - void* macro_kernel_p = bli_gemm_ker_var1; + void_fp macro_kernel_p = bli_gemm_ker_var1; - // Change the macro-kernel if the operation family is herk or trmm. - //if ( family == BLIS_HERK ) macro_kernel_p = bli_herk_x_ker_var2; + // Change the macro-kernel if the operation family is gemmt or trmm. + //if ( family == BLIS_GEMMT ) macro_kernel_p = bli_gemmt_x_ker_var2; //else if ( family == BLIS_TRMM ) macro_kernel_p = bli_trmm_xx_ker_var2; // Create two nodes for the macro-kernel. @@ -194,8 +192,8 @@ cntl_t* bli_gemmpb_cntl_create ( bli_gemm_packb, // pack the right-hand operand bli_packm_blk_var1, - BLIS_KR, BLIS_MR, + BLIS_KR, FALSE, // do NOT invert diagonal FALSE, // reverse iteration if upper? FALSE, // reverse iteration if lower? @@ -270,7 +268,7 @@ cntl_t* bli_gemm_cntl_create_node rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, cntl_t* sub_node ) { diff --git a/frame/3/gemm/bli_gemm_cntl.h b/frame/3/gemm/bli_gemm_cntl.h index e19384a51a..5fa213ac41 100644 --- a/frame/3/gemm/bli_gemm_cntl.h +++ b/frame/3/gemm/bli_gemm_cntl.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,7 +38,8 @@ cntl_t* bli_gemm_cntl_create rntm_t* rntm, opid_t family, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); // ----------------------------------------------------------------------------- @@ -48,7 +49,8 @@ cntl_t* bli_gemmbp_cntl_create rntm_t* rntm, opid_t family, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); #if 0 @@ -74,7 +76,7 @@ cntl_t* bli_gemm_cntl_create_node rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, cntl_t* sub_node ); diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index 97bc5c5d03..4ff45036fe 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -53,32 +53,64 @@ void bli_gemm_front obj_t b_local; obj_t c_local; -#ifdef BLIS_ENABLE_SMALL_MATRIX - // Only handle small problems separately for homogeneous datatypes. - if ( bli_obj_dt( a ) == bli_obj_dt( b ) && - bli_obj_dt( a ) == bli_obj_dt( c ) ) + // If C has a zero dimension, return early. + if ( bli_obj_has_zero_dim( c ) ) { - gint_t status = bli_gemm_small( alpha, a, b, beta, c, cntx, cntl ); - if ( status == BLIS_SUCCESS ) return; + return; } -#endif - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_gemm_check( alpha, a, b, beta, c, cntx ); - - // If alpha is zero, scale by beta and return. - if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) + // If alpha is zero, or if A or B has a zero dimension, scale C by beta + // and return early. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || + bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) { bli_scalm( beta, c ); return; } +#if 0 +#ifdef BLIS_ENABLE_SMALL_MATRIX + // Only handle small problems separately for homogeneous datatypes. + if ( bli_obj_dt( a ) == bli_obj_dt( b ) && + bli_obj_dt( a ) == bli_obj_dt( c ) && + bli_obj_comp_prec( c ) == bli_obj_prec( c ) ) + { + err_t status = bli_gemm_small( alpha, a, b, beta, c, cntx, cntl ); + if ( status == BLIS_SUCCESS ) return; + } +#endif +#endif + // Alias A, B, and C in case we need to apply transformations. bli_obj_alias_to( a, &a_local ); bli_obj_alias_to( b, &b_local ); bli_obj_alias_to( c, &c_local ); + // Set the obj_t buffer field to the location currently implied by the row + // and column offsets and then zero the offsets. If any of the original + // obj_t's were views into larger matrices, this step effectively makes + // those obj_t's "forget" their lineage. + bli_obj_reset_origin( &a_local ); + bli_obj_reset_origin( &b_local ); + bli_obj_reset_origin( &c_local ); + + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); + #ifdef BLIS_ENABLE_GEMM_MD cntx_t cntx_local; @@ -98,24 +130,8 @@ void bli_gemm_front // is adjusted to point to cntx_local.) bli_gemm_md( &a_local, &b_local, beta, &c_local, &cntx_local, &cntx ); } - //else // homogeneous datatypes #endif - // Load the pack schemas from the context and embed them into the objects - // for A and B. (Native contexts are initialized with the correct pack - // schemas, as are contexts for 1m, and if necessary bli_gemm_md() would - // have made a copy and modified the schemas, so reading them from the - // context should be a safe bet at this point.) This is a sort of hack for - // communicating the desired pack schemas for to bli_gemm_cntl_create() - // (via bli_l3_thread_decorator() and bli_l3_cntl_create_if()). This allows - // us to subsequently access the schemas from the control tree, which - // hopefully reduces some confusion, particularly in bli_packm_init(). - const pack_t schema_a = bli_cntx_schema_a_block( cntx ); - const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - // Next, we handle the possibility of needing to typecast alpha to the // computation datatype and/or beta to the storage datatype of C. @@ -134,30 +150,6 @@ void bli_gemm_front alpha = &BLIS_ONE; beta = &BLIS_ONE; -#ifdef BLIS_ENABLE_GEMM_MD - // Don't perform the following optimization for ccr or crc cases, as - // those cases are sensitive to the ukernel storage preference (ie: - // transposing the operation would break them). - if ( !bli_gemm_md_is_ccr( &a_local, &b_local, &c_local ) && - !bli_gemm_md_is_crc( &a_local, &b_local, &c_local ) ) -#endif - // An optimization: If C is stored by rows and the micro-kernel prefers - // contiguous columns, or if C is stored by columns and the micro-kernel - // prefers contiguous rows, transpose the entire operation to allow the - // micro-kernel to access elements of C in its preferred manner. - if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) - { - bli_obj_swap( &a_local, &b_local ); - - bli_obj_induce_trans( &a_local ); - bli_obj_induce_trans( &b_local ); - bli_obj_induce_trans( &c_local ); - - // We must also swap the pack schemas, which were set by bli_gemm_md() - // or the inlined code above. - bli_obj_swap_pack_schemas( &a_local, &b_local ); - } - // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any // additional modifications necessary for the current operation. @@ -184,15 +176,15 @@ void bli_gemm_front // of the ccr or crc cases. // Then, after the computation is complete, this matrix will be copied // or accumulated back to C. - const bool_t is_ccr_mismatch = + const bool is_ccr_mismatch = ( bli_gemm_md_is_ccr( &a_local, &b_local, &c_local ) && !bli_obj_is_col_stored( &c_local ) ); - const bool_t is_crc_mismatch = + const bool is_crc_mismatch = ( bli_gemm_md_is_crc( &a_local, &b_local, &c_local ) && !bli_obj_is_row_stored( &c_local ) ); - obj_t ct; - bool_t use_ct = FALSE; + obj_t ct; + bool use_ct = FALSE; // FGVZ: Consider adding another guard here that only creates and uses a // temporary matrix for accumulation if k < c * kc, where c is some small @@ -260,7 +252,7 @@ void bli_gemm_front // Invoke the internal back-end via the thread handler. bli_l3_thread_decorator ( - bli_gemm_int, + bli_l3_int, BLIS_GEMM, // operation family id alpha, &a_local, @@ -291,90 +283,3 @@ void bli_gemm_front #endif } -// ----------------------------------------------------------------------------- - -#if 0 - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - const bool_t a_is_real = bli_obj_is_real( a ); - const bool_t a_is_comp = bli_obj_is_complex( a ); - const bool_t b_is_real = bli_obj_is_real( b ); - const bool_t b_is_comp = bli_obj_is_complex( b ); - const bool_t c_is_real = bli_obj_is_real( c ); - const bool_t c_is_comp = bli_obj_is_complex( c ); - - const bool_t a_is_single = bli_obj_is_single_prec( a ); - const bool_t a_is_double = bli_obj_is_double_prec( a ); - const bool_t b_is_single = bli_obj_is_single_prec( b ); - const bool_t b_is_double = bli_obj_is_double_prec( b ); - const bool_t c_is_single = bli_obj_is_single_prec( c ); - const bool_t c_is_double = bli_obj_is_double_prec( c ); - - const bool_t comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC; - const bool_t comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC; - - const bool_t mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) || - bli_obj_domain( c ) != bli_obj_domain( b ); - - ( void )a_is_real; ( void )a_is_comp; - ( void )b_is_real; ( void )b_is_comp; - ( void )c_is_real; ( void )c_is_comp; - ( void )a_is_single; ( void )a_is_double; - ( void )b_is_single; ( void )b_is_double; - ( void )c_is_single; ( void )c_is_double; - ( void )comp_single; ( void )comp_double; - - if ( - //( c_is_comp && a_is_comp && b_is_real ) || - //( c_is_comp && a_is_real && b_is_comp ) || - //( c_is_real && a_is_comp && b_is_comp ) || - //( c_is_comp && a_is_real && b_is_real ) || - //( c_is_real && a_is_comp && b_is_real ) || - //( c_is_real && a_is_real && b_is_comp ) || - //FALSE - TRUE - ) - { - if ( - ( c_is_single && a_is_single && b_is_single && mixeddomain ) || - ( c_is_single && a_is_single && b_is_single && comp_single ) || - ( c_is_single && a_is_single && b_is_single && comp_double ) || - ( c_is_single && a_is_single && b_is_double ) || - ( c_is_single && a_is_double && b_is_single ) || - ( c_is_double && a_is_single && b_is_single ) || - ( c_is_single && a_is_double && b_is_double ) || - ( c_is_double && a_is_single && b_is_double ) || - ( c_is_double && a_is_double && b_is_single ) || - ( c_is_double && a_is_double && b_is_double && comp_single ) || - ( c_is_double && a_is_double && b_is_double && comp_double ) || - ( c_is_double && a_is_double && b_is_double && mixeddomain ) || - FALSE - ) - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - } - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#else -#if 0 - // If any of the storage datatypes differ, or if the execution precision - // differs from the storage precision of C, utilize the mixed datatype - // code path. - // NOTE: We could check the exec dt against the storage dt of C, but for - // now we don't support the caller setting the execution domain - // explicitly. - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#endif -#endif - diff --git a/frame/3/gemm/bli_gemm_front.h b/frame/3/gemm/bli_gemm_front.h index ba65bab8db..2728ce8f7f 100644 --- a/frame/3/gemm/bli_gemm_front.h +++ b/frame/3/gemm/bli_gemm_front.h @@ -44,6 +44,7 @@ void bli_gemm_front cntl_t* cntl ); +#ifdef BLIS_ENABLE_SMALL_MATRIX err_t bli_gemm_small ( obj_t* alpha, @@ -54,3 +55,5 @@ err_t bli_gemm_small cntx_t* cntx, cntl_t* cntl ); +#endif + diff --git a/frame/3/gemm/bli_gemm_int.c b/frame/3/gemm/bli_gemm_int.c deleted file mode 100644 index 25a6215dfd..0000000000 --- a/frame/3/gemm/bli_gemm_int.c +++ /dev/null @@ -1,135 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -void bli_gemm_int - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - obj_t a_local; - obj_t b_local; - obj_t c_local; - gemm_var_oft f; - - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_gemm_basic_check( alpha, a, b, beta, c, cntx ); - - // If C has a zero dimension, return early. - if ( bli_obj_has_zero_dim( c ) ) return; - - // If A or B has a zero dimension, scale C by beta and return early. - if ( bli_obj_has_zero_dim( a ) || - bli_obj_has_zero_dim( b ) ) - { - if ( bli_thread_am_ochief( thread ) ) - bli_scalm( beta, c ); - bli_thread_obarrier( thread ); - return; - } - - // If A or B is marked as being filled with zeros, scale C by beta and - // return early. - if ( bli_obj_is_zeros( a ) || - bli_obj_is_zeros( b ) ) - { - // This should never execute. - bli_abort(); - - if ( bli_thread_am_ochief( thread ) ) - bli_scalm( beta, c ); - bli_thread_obarrier( thread ); - return; - } - - // Alias A, B, and C in case we need to update attached scalars. - bli_obj_alias_to( a, &a_local ); - bli_obj_alias_to( b, &b_local ); - bli_obj_alias_to( c, &c_local ); - - // If alpha is non-unit, typecast and apply it to the scalar attached - // to B. - if ( !bli_obj_equals( alpha, &BLIS_ONE ) ) - { - bli_obj_scalar_apply_scalar( alpha, &b_local ); - } - - // If beta is non-unit, typecast and apply it to the scalar attached - // to C. - if ( !bli_obj_equals( beta, &BLIS_ONE ) ) - { - bli_obj_scalar_apply_scalar( beta, &c_local ); - } - - // Create the next node in the thrinfo_t structure. - bli_thrinfo_grow( rntm, cntl, thread ); - - // Extract the function pointer from the current control tree node. - f = bli_cntl_var_func( cntl ); - - // Somewhat hackish support for 4m1b method implementation. - { - ind_t im = bli_cntx_method( cntx ); - - if ( im != BLIS_NAT ) - { - if ( im == BLIS_4M1B ) - if ( f == bli_gemm_ker_var2 ) f = bli_gemm4mb_ker_var2; - } - } - - // Invoke the variant. - f - ( - &a_local, - &b_local, - &c_local, - cntx, - rntm, - cntl, - thread - ); -} - diff --git a/frame/3/gemm/bli_gemm_ker_var1.c b/frame/3/gemm/bli_gemm_ker_var1.c index 4dcffd279e..096091e765 100644 --- a/frame/3/gemm/bli_gemm_ker_var1.c +++ b/frame/3/gemm/bli_gemm_ker_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/3/gemm/bli_gemm_ker_var2.c b/frame/3/gemm/bli_gemm_ker_var2.c index 41bb3f4552..6de361194d 100644 --- a/frame/3/gemm/bli_gemm_ker_var2.c +++ b/frame/3/gemm/bli_gemm_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,28 +35,44 @@ #include "blis.h" -#define FUNCPTR_T gemm_fp +typedef void (*xpbys_mxn_vft) + ( + dim_t m, + dim_t n, + void* x, inc_t rs_x, inc_t cs_x, + void* b, + void* y, inc_t rs_y, inc_t cs_y + ); -typedef void (*FUNCPTR_T) - ( - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); +#undef GENTFUNC2 +#define GENTFUNC2(ctypex,ctypey,chx,chy,op) \ +\ +void PASTEMAC2(chx,chy,op) \ + ( \ + dim_t m, \ + dim_t n, \ + void* x, inc_t rs_x, inc_t cs_x, \ + void* b, \ + void* y, inc_t rs_y, inc_t cs_y \ + ) \ +{ \ + ctypex* restrict x_cast = x; \ + ctypey* restrict b_cast = b; \ + ctypey* restrict y_cast = y; \ +\ + PASTEMAC3(chx,chy,chy,xpbys_mxn) \ + ( \ + m, n, \ + x_cast, rs_x, cs_x, \ + b_cast, \ + y_cast, rs_y, cs_y \ + ); \ +} -static FUNCPTR_T GENARRAY(ftypes,gemm_ker_var2); +INSERT_GENTFUNC2_BASIC0(xbpys_mxn_fn); +INSERT_GENTFUNC2_MIXDP0(xbpys_mxn_fn); + +static xpbys_mxn_vft GENARRAY2_ALL(xbpys_mxn, xbpys_mxn_fn); void bli_gemm_ker_var2 @@ -70,23 +86,8 @@ void bli_gemm_ker_var2 thrinfo_t* thread ) { -#ifdef BLIS_ENABLE_GEMM_MD - // By now, A and B have been packed and cast to the execution precision. - // In most cases, such as when storage precision of C differs from the - // execution precision, we utilize the mixed datatype code path. However, - // a few cases still fall within this kernel, such as mixed domain with - // equal precision (ccr, crc, rcc), hence those expressions being disabled - // in the conditional below. - if ( //( bli_obj_domain( c ) != bli_obj_domain( a ) ) || - //( bli_obj_domain( c ) != bli_obj_domain( b ) ) || - ( bli_obj_dt( c ) != bli_obj_exec_dt( c ) ) ) - { - bli_gemm_ker_var2_md( a, b, c, cntx, rntm, cntl, thread ); - return; - } -#endif - num_t dt_exec = bli_obj_exec_dt( c ); + num_t dt_c = bli_obj_dt( c ); pack_t schema_a = bli_obj_pack_schema( a ); pack_t schema_b = bli_obj_pack_schema( b ); @@ -95,50 +96,55 @@ void bli_gemm_ker_var2 dim_t n = bli_obj_width( c ); dim_t k = bli_obj_width( a ); - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); + char* a_cast = bli_obj_buffer_at_off( a ); inc_t is_a = bli_obj_imag_stride( a ); dim_t pd_a = bli_obj_panel_dim( a ); inc_t ps_a = bli_obj_panel_stride( a ); - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); + char* b_cast = bli_obj_buffer_at_off( b ); inc_t is_b = bli_obj_imag_stride( b ); dim_t pd_b = bli_obj_panel_dim( b ); inc_t ps_b = bli_obj_panel_stride( b ); - void* buf_c = bli_obj_buffer_at_off( c ); + char* c_cast = bli_obj_buffer_at_off( c ); inc_t rs_c = bli_obj_row_stride( c ); inc_t cs_c = bli_obj_col_stride( c ); - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; + // If any dimension is zero, return immediately. + if ( bli_zero_dim3( m, n, k ) ) return; // Detach and multiply the scalars attached to A and B. + // NOTE: We know that the internal scalars of A and B are already of the + // target datatypes because the necessary typecasting would have already + // taken place during bli_packm_init(). + obj_t scalar_a; + obj_t scalar_b; bli_obj_scalar_detach( a, &scalar_a ); bli_obj_scalar_detach( b, &scalar_b ); bli_mulsc( &scalar_a, &scalar_b ); // Grab the addresses of the internal scalar buffers for the scalar // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); + // NOTE: We know that scalar_b is of type dt_exec due to the above code + // that casts the scalars of A and B to dt_exec via scalar_a and scalar_b, + // and we know that the internal scalar in C is already of the type dt_c + // due to the casting in the implementation of bli_obj_scalar_attach(). + char* alpha_cast = bli_obj_internal_scalar_buffer( &scalar_b ); + char* beta_cast = bli_obj_internal_scalar_buffer( c ); // If 1m is being employed on a column- or row-stored matrix with a // real-valued beta, we can use the real domain macro-kernel, which // eliminates a little overhead associated with the 1m virtual // micro-kernel. + // Only employ this optimization if the storage datatype of C is + // equal to the execution/computation datatype. #if 1 if ( bli_cntx_method( cntx ) == BLIS_1M ) { bli_gemm_ind_recast_1m_params ( &dt_exec, + &dt_c, schema_a, c, &m, &n, &k, @@ -151,263 +157,211 @@ void bli_gemm_ker_var2 #ifdef BLIS_ENABLE_GEMM_MD // Tweak parameters in select mixed domain cases (rcc, crc, ccr). - bli_gemm_md_ker_var2_recast - ( - &dt_exec, - bli_obj_dt( a ), - bli_obj_dt( b ), - bli_obj_dt( c ), - &m, &n, &k, - &pd_a, &ps_a, - &pd_b, &ps_b, - c, - &rs_c, &cs_c - ); + if ( bli_cntx_method( cntx ) == BLIS_NAT ) + { + bli_gemm_md_ker_var2_recast + ( + &dt_exec, + bli_obj_dt( a ), + bli_obj_dt( b ), + &dt_c, + &m, &n, &k, + &pd_a, &ps_a, + &pd_b, &ps_b, + c, + &rs_c, &cs_c + ); + } #endif - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. - f( schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} + siz_t dt_size = bli_dt_size( dt_exec ); + siz_t dt_c_size = bli_dt_size( dt_c ); + // Alias some constants to simpler names. + const dim_t MR = pd_a; + const dim_t NR = pd_b; + //const dim_t PACKMR = cs_a; + //const dim_t PACKNR = rs_b; + + // Query the context for the micro-kernel address and cast it to its + // function pointer type. + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt_exec, BLIS_GEMM_UKR, cntx ); + + // Query the params field from the obj_t. If it is non-NULL, grab the ukr + // field of the params struct. If that function pointer is non-NULL, use it + // as our microkernel instead of the default microkernel queried from the + // cntx above. + gemm_ker_params_t* params = bli_obj_ker_params( c ); + gemm_ukr_vft user_ukr = params ? params->ukr : NULL; + if ( user_ukr ) gemm_ukr = user_ukr; + + // Temporary C buffer for edge cases. Note that the strides of this + // temporary buffer are set so that they match the storage of the + // original C matrix. For example, if C is column-stored, ct will be + // column-stored as well. + char ct[ BLIS_STACK_BUF_MAX_SIZE ] + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt_exec, BLIS_GEMM_UKR, cntx ); + const inc_t rs_ct = ( col_pref ? 1 : NR ); + const inc_t cs_ct = ( col_pref ? MR : 1 ); + char* zero = bli_obj_buffer_for_const( dt_exec, &BLIS_ZERO ); + + // + // Assumptions/assertions: + // rs_a == 1 + // cs_a == PACKMR + // pd_a == MR + // ps_a == stride to next micro-panel of A + // rs_b == PACKNR + // cs_b == 1 + // pd_b == NR + // ps_b == stride to next micro-panel of B + // rs_c == (no assumptions) + // cs_c == (no assumptions) + // + + // Compute number of primary and leftover components of the m and n + // dimensions. + dim_t n_iter = n / NR; + dim_t n_left = n % NR; + + dim_t m_iter = m / MR; + dim_t m_left = m % MR; + + if ( n_left ) ++n_iter; + if ( m_left ) ++m_iter; + + // Determine some increments used to step through A, B, and C. + inc_t rstep_a = ps_a * dt_size; + + inc_t cstep_b = ps_b * dt_size; + + inc_t rstep_c = rs_c * MR * dt_c_size; + inc_t cstep_c = cs_c * NR * dt_c_size; + + auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. + bli_auxinfo_set_schema_a( schema_a, &aux ); + bli_auxinfo_set_schema_b( schema_b, &aux ); + + // Save the imaginary stride of A and B to the auxinfo_t object. + bli_auxinfo_set_is_a( is_a, &aux ); + bli_auxinfo_set_is_b( is_b, &aux ); + + // Save the virtual microkernel address and the params. + bli_auxinfo_set_ukr( gemm_ukr, &aux ); + bli_auxinfo_set_params( params, &aux ); + + // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + // loop around the microkernel. Here we query the thrinfo_t node for the + // 1st (ir) loop around the microkernel. + thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); + + // Query the number of threads and thread ids for each loop. + dim_t jr_nt = bli_thread_n_way( thread ); + dim_t jr_tid = bli_thread_work_id( thread ); + dim_t ir_nt = bli_thread_n_way( caucus ); + dim_t ir_tid = bli_thread_work_id( caucus ); + + dim_t jr_start, jr_end; + dim_t ir_start, ir_end; + dim_t jr_inc, ir_inc; + + // Determine the thread range and increment for the 2nd and 1st loops. + // NOTE: The definition of bli_thread_range_jrir() will depend on whether + // slab or round-robin partitioning was requested at configure-time. + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); + bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); + + // Loop over the n dimension (NR columns at a time). + for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) + { + char* b1 = b_cast + j * cstep_b; + char* c1 = c_cast + j * cstep_c; + + dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + + // Initialize our next panel of B to be the current panel of B. + char* b2 = b1; + + // Loop over the m dimension (MR rows at a time). + for ( dim_t i = ir_start; i < ir_end; i += ir_inc ) + { + char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + + // Compute the addresses of the next panels of A and B. + char* a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) + { + a2 = a_cast; + b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) + b2 = b_cast; + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + bli_auxinfo_set_next_b( b2, &aux ); + + // Edge case handling now occurs within the microkernel itself, but + // we must still explicitly accumulate to a temporary microtile in + // situations where a virtual microkernel is being used, such as + // during the 1m method or some cases of mixed datatypes. + if ( dt_exec == dt_c ) + { + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k, + alpha_cast, + a1, + b1, + beta_cast, + c11, rs_c, cs_c, + &aux, + cntx + ); + } + else + { + // Invoke the gemm micro-kernel. + gemm_ukr + ( + MR, + NR, + k, + alpha_cast, + a1, + b1, + zero, + &ct, rs_ct, cs_ct, + &aux, + cntx + ); + + // Accumulate to C with type-casting. + xbpys_mxn[ dt_exec ][ dt_c ] + ( + m_cur, n_cur, + &ct, rs_ct, cs_ct, + beta_cast, + c11, rs_c, cs_c + ); + } + } + } -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t i, j; \ - dim_t m_cur; \ - dim_t n_cur; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) - loop around the microkernel. Here we query the thrinfo_t node for the - 1st (ir) loop around the microkernel. */ \ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ -\ - /* Query the number of threads and thread ids for each loop. */ \ - dim_t jr_nt = bli_thread_n_way( thread ); \ - dim_t jr_tid = bli_thread_work_id( thread ); \ - dim_t ir_nt = bli_thread_n_way( caucus ); \ - dim_t ir_tid = bli_thread_work_id( caucus ); \ -\ - dim_t jr_start, jr_end; \ - dim_t ir_start, ir_end; \ - dim_t jr_inc, ir_inc; \ -\ - /* Determine the thread range and increment for the 2nd and 1st loops. - NOTE: The definition of bli_thread_range_jrir() will depend on whether - slab or round-robin partitioning was requested at configure-time. */ \ - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ - bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the bottom edge of C and add the result from above. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ -\ /* -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \ -*/ \ +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); +*/ } -INSERT_GENTFUNC_BASIC0( gemm_ker_var2 ) - diff --git a/frame/3/gemm/bli_gemm_ker_var2_md.c b/frame/3/gemm/bli_gemm_ker_var2_md.c deleted file mode 100644 index 3428be9b4f..0000000000 --- a/frame/3/gemm/bli_gemm_ker_var2_md.c +++ /dev/null @@ -1,450 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#ifdef BLIS_ENABLE_GEMM_MD - -#define FUNCPTR_T gemm_fp - -typedef void (*FUNCPTR_T) - ( - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY2_ALL(ftypes,gemm_ker_var2_md); - - -void bli_gemm_ker_var2_md - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - num_t dt_c = bli_obj_dt( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - // NOTE: We know that the internal scalars of A and B are already of the - // target datatypes because the necessary typecasting would have already - // taken place during bli_packm_init(). - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - // NOTE: We know that scalar_b is of type dt_exec due to the above code - // that casts the scalars of A and B to dt_exec via scalar_a and scalar_b, - // and we know that the internal scalar in C is already of the type dt_c - // due to the casting in the implementation of bli_obj_scalar_attach(). - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - -#if 0 - // NOTE: Turns out that this optimization will never be employed since - // currently bli_gemm_ker_var2_md() is only called when the storage - // datatype of C differs from the execution/computation datatype, and - // this optimization would only make sense if they are equal. - - // If 1m is being employed on a column- or row-stored matrix with a - // real-valued beta, we can use the real domain macro-kernel, which - // eliminates a little overhead associated with the 1m virtual - // micro-kernel. - if ( bli_cntx_method( cntx ) == BLIS_1M ) - { - // Only employ this optimization if the storage datatype of C is - // equal to the execution/computation datatype. - if ( dt_c == dt_exec ) - { - bli_gemm_ind_recast_1m_params - ( - &dt_exec, - schema_a, - c, - &m, &n, &k, - &pd_a, &ps_a, - &pd_b, &ps_b, - &rs_c, &cs_c - ); - } - } -#endif - - // Tweak parameters in select mixed domain cases (rcc, crc, ccr). - bli_gemm_md_ker_var2_recast - ( - &dt_exec, - bli_obj_dt( a ), - bli_obj_dt( b ), - bli_obj_dt( c ), - &m, &n, &k, - &pd_a, &ps_a, - &pd_b, &ps_b, - c, - &rs_c, &cs_c - ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_c][dt_exec]; - - // Invoke the function. - f( schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} - - -#undef GENTFUNC2 -#define GENTFUNC2( ctype_c, ctype_e, chc, che, varname ) \ -\ -void PASTEMAC2(chc,che,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dte = PASTEMAC(che,type); \ - /*const num_t dtc = PASTEMAC(chc,type);*/ \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(che,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dte, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype_e ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_e ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dte, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype_e* restrict zero = PASTEMAC(che,0); \ - ctype_e* restrict a_cast = a; \ - ctype_e* restrict b_cast = b; \ - ctype_c* restrict c_cast = c; \ - ctype_e* restrict alpha_cast = alpha; \ - ctype_c* restrict beta_cast = beta; \ - ctype_e* restrict b1; \ - ctype_c* restrict c1; \ -\ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t i, j; \ - dim_t m_cur; \ - dim_t n_cur; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(che,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) - loop around the microkernel. Here we query the thrinfo_t node for the - 1st (ir) loop around the microkernel. */ \ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ -\ - /* Query the number of threads and thread ids for each loop. */ \ - dim_t jr_nt = bli_thread_n_way( thread ); \ - dim_t jr_tid = bli_thread_work_id( thread ); \ - dim_t ir_nt = bli_thread_n_way( caucus ); \ - dim_t ir_tid = bli_thread_work_id( caucus ); \ -\ - dim_t jr_start, jr_end; \ - dim_t ir_start, ir_end; \ - dim_t jr_inc, ir_inc; \ -\ - /* Determine the thread range and increment for the 2nd and 1st loops. - NOTE: The definition of bli_thread_range_jrir() will depend on whether - slab or round-robin partitioning was requested at configure-time. */ \ - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ - bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype_e* restrict a1; \ - ctype_c* restrict c11; \ - ctype_e* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype_e* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Always save the micropanel product to the local microtile and - then accumulate it into C via the xpbys_mxn macro. */ \ - /*if ( 1 )*/ \ - { \ - /*bli_auxinfo_set_dt_on_output( dte, &aux );*/ \ -\ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the microtile of C and add the result from above. */ \ - PASTEMAC3(che,chc,chc,xpbys_mxn) \ - ( \ - m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c \ - ); \ - } \ -/* - else if ( m_cur == MR && n_cur == NR ) \ - { \ - bli_auxinfo_set_dt_on_output( dtc, &aux ); \ -\ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - ( ctype_e* )beta_cast, \ - ( ctype_e* )c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - bli_auxinfo_set_dt_on_output( dte, &aux ); \ -\ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - PASTEMAC3(che,chc,chc,xpbys_mxn) \ - ( \ - m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c \ - ); \ - } \ -*/ \ - } \ - } \ -\ -/* -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \ -*/ \ -} - -INSERT_GENTFUNC2_BASIC0( gemm_ker_var2_md ) -INSERT_GENTFUNC2_MIXDP0( gemm_ker_var2_md ) - -#endif diff --git a/frame/3/gemm/bli_gemm_md.c b/frame/3/gemm/bli_gemm_md.c index b08bea2c4e..e257cdf287 100644 --- a/frame/3/gemm/bli_gemm_md.c +++ b/frame/3/gemm/bli_gemm_md.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -49,12 +49,12 @@ void bli_gemm_md { mddm_t doms; - const bool_t a_is_real = bli_obj_is_real( a ); - const bool_t a_is_comp = bli_obj_is_complex( a ); - const bool_t b_is_real = bli_obj_is_real( b ); - const bool_t b_is_comp = bli_obj_is_complex( b ); - const bool_t c_is_real = bli_obj_is_real( c ); - const bool_t c_is_comp = bli_obj_is_complex( c ); + const bool a_is_real = bli_obj_is_real( a ); + const bool a_is_comp = bli_obj_is_complex( a ); + const bool b_is_real = bli_obj_is_real( b ); + const bool b_is_comp = bli_obj_is_complex( b ); + const bool c_is_real = bli_obj_is_real( c ); + const bool c_is_comp = bli_obj_is_complex( c ); if ( c_is_real && a_is_real && b_is_real ) { @@ -171,8 +171,8 @@ mddm_t bli_gemm_md_ccr // is equal to the real projection of the execution datatype, and use // that computation datatype to query the corresponding ukernel output // preference. - const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); - const bool_t row_pref + const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); + const bool row_pref = bli_cntx_l3_nat_ukr_prefers_rows_dt( dt, BLIS_GEMM_UKR, *cntx ); // We can only perform this case of mixed-domain gemm, C += A*B where @@ -187,6 +187,10 @@ mddm_t bli_gemm_md_ccr bli_obj_induce_trans( b ); bli_obj_induce_trans( c ); + // We must swap the pack schemas because the schemas were set before + // the objects were swapped. + bli_obj_swap_pack_schemas( a, b ); + return bli_gemm_md_crc( a, b, beta, c, cntx_local, cntx ); } @@ -230,7 +234,7 @@ mddm_t bli_gemm_md_ccr bli_blksz_scale_def_max( 1, 2, BLIS_SCOMPLEX, blksz_mc ); bli_blksz_scale_def_max( 1, 2, BLIS_DCOMPLEX, blksz_mc ); - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // static func_t* bli_cntx_get_l3_vir_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) func_t* l3_vir_ukrs = bli_cntx_get_l3_vir_ukrs( BLIS_GEMM_UKR, *cntx ); @@ -272,8 +276,8 @@ mddm_t bli_gemm_md_crc // is equal to the real projection of the execution datatype, and use // that computation datatype to query the corresponding ukernel output // preference. - const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); - const bool_t col_pref + const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, *cntx ); // We can only perform this case of mixed-domain gemm, C += A*B where @@ -288,6 +292,10 @@ mddm_t bli_gemm_md_crc bli_obj_induce_trans( b ); bli_obj_induce_trans( c ); + // We must swap the pack schemas because the schemas were set before + // the objects were swapped. + bli_obj_swap_pack_schemas( a, b ); + return bli_gemm_md_ccr( a, b, beta, c, cntx_local, cntx ); } @@ -331,7 +339,7 @@ mddm_t bli_gemm_md_crc bli_blksz_scale_def_max( 1, 2, BLIS_SCOMPLEX, blksz_nc ); bli_blksz_scale_def_max( 1, 2, BLIS_DCOMPLEX, blksz_nc ); - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // static func_t* bli_cntx_get_l3_vir_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) func_t* l3_vir_ukrs = bli_cntx_get_l3_vir_ukrs( BLIS_GEMM_UKR, *cntx ); @@ -405,8 +413,8 @@ mddm_t bli_gemm_md_rcc // Use the 1r pack schema for both A and B with the conjugation // of A or B toggled (to produce ar * br - ai * bi). - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_1R, *cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_1R, *cntx ); + bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS_1R, a ); + bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS_1R, b ); bli_obj_toggle_conj( b ); @@ -485,7 +493,7 @@ mddm_t bli_gemm_md_crr } #endif - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // Return the computation and execution domains. return doms; @@ -523,7 +531,7 @@ mddm_t bli_gemm_md_rcr // Overwrite the complex obj_t with its real-only alias. *a = a_real; - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // Return the computation and execution domains. return doms; @@ -561,7 +569,7 @@ mddm_t bli_gemm_md_rrc // Overwrite the complex obj_t with its real-only alias. *b = b_real; - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // Return the computation and execution domains. return doms; @@ -591,7 +599,7 @@ mddm_t bli_gemm_md_rrr doms.comp = BLIS_REAL; doms.exec = BLIS_REAL; - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // Return the computation and execution domains. return doms; @@ -621,248 +629,10 @@ mddm_t bli_gemm_md_ccc doms.comp = BLIS_COMPLEX; doms.exec = BLIS_COMPLEX; - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // Return the computation and execution domains. return doms; } -// ----------------------------------------------------------------------------- - -#if 0 -void bli_gemm_md_front - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl - ) -{ - bli_init_once(); - - obj_t a_local; - obj_t b_local; - obj_t c_local; - - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_gemm_check( alpha, a, b, beta, c, cntx ); - - // If alpha is zero, scale by beta and return. - if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) - { - bli_scalm( beta, c ); - return; - } - - // Alias A, B, and C in case we need to apply transformations. - bli_obj_alias_to( a, &a_local ); - bli_obj_alias_to( b, &b_local ); - bli_obj_alias_to( c, &c_local ); - - // An optimization: If C is stored by rows and the micro-kernel prefers - // contiguous columns, or if C is stored by columns and the micro-kernel - // prefers contiguous rows, transpose the entire operation to allow the - // micro-kernel to access elements of C in its preferred manner. - if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) - { - bli_obj_swap( &a_local, &b_local ); - - bli_obj_induce_trans( &a_local ); - bli_obj_induce_trans( &b_local ); - bli_obj_induce_trans( &c_local ); - } - - cntx_t cntx_local; - - // Handle mixed domain cases in bli_gemm_md(), which may modify - // the objects or the context. (If the context is modified, cntx - // is adjusted to point to cntx_local.) - bli_gemm_md( &a_local, &b_local, beta, &c_local, &cntx_local, &cntx ); - - // Record the threading for each level within the context. - bli_rntm_set_ways_for_op - ( - BLIS_GEMM, - BLIS_LEFT, // ignored for gemm/hemm/symm - bli_obj_length( &c_local ), - bli_obj_width( &c_local ), - bli_obj_width( &a_local ), - rntm - ); - - // Invoke the internal back-end via the thread handler. - bli_l3_thread_decorator - ( - bli_gemm_int, - BLIS_GEMM, // operation family id - alpha, - &a_local, - &b_local, - beta, - &c_local, - cntx, - rntm, - cntl - ); -} - -// ----------------------------------------------------------------------------- - -void bli_gemm_md_zgemm - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl - ) -{ - bli_init_once(); - - obj_t a_local; - obj_t b_local; - obj_t c_local; - -#if 1 - obj_t am, bm, cm; - obj_t* c_orig; - - //if ( is_md == TRUE ) - { - //num_t dt_c2 = bli_obj_dt( c ); - //num_t dt_c1 = bli_dt_proj_to_complex( dt_c2 ); - //num_t dt_c = bli_dt_proj_to_double_prec( dt_c1 ); - //num_t dt_c = bli_obj_dt_proj_to_complex( c ); - num_t dt_c = BLIS_DCOMPLEX; - - if ( bli_obj_is_single_prec( c ) ) dt_c = BLIS_SCOMPLEX; - else dt_c = BLIS_DCOMPLEX; - - if ( bli_obj_is_real( a ) && - bli_obj_is_real( b ) && - bli_obj_is_real( c ) ) dt_c = bli_dt_proj_to_real( dt_c ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width_after_trans( a ); - - bli_obj_create( dt_c, m, k, 0, 0, &am ); - bli_obj_create( dt_c, k, n, 0, 0, &bm ); - bli_obj_create( dt_c, m, n, 0, 0, &cm ); - - //bli_projm( a, &am ); - //bli_projm( b, &bm ); - //bli_projm( c, &cm ); - bli_castm( a, &am ); - bli_castm( b, &bm ); - bli_castm( c, &cm ); - - c_orig = c; - - a = &am; - b = &bm; - c = &cm; - } -#endif - - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_gemm_check( alpha, a, b, beta, c, cntx ); - - // If alpha is zero, scale by beta and return. - if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) - { - bli_scalm( beta, c ); - return; - } - - // Alias A, B, and C in case we need to apply transformations. - bli_obj_alias_to( a, &a_local ); - bli_obj_alias_to( b, &b_local ); - bli_obj_alias_to( c, &c_local ); - - // An optimization: If C is stored by rows and the micro-kernel prefers - // contiguous columns, or if C is stored by columns and the micro-kernel - // prefers contiguous rows, transpose the entire operation to allow the - // micro-kernel to access elements of C in its preferred manner. - if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) - { - bli_obj_swap( &a_local, &b_local ); - - bli_obj_induce_trans( &a_local ); - bli_obj_induce_trans( &b_local ); - bli_obj_induce_trans( &c_local ); - } - - { - // A sort of hack for communicating the desired pach schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - if ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &b_local ); - } - else // if ( bli_cntx_method( cntx ) != BLIS_NAT ) - { - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - } - } - - // Parse and interpret the contents of the rntm_t object to properly - // set the ways of parallelism for each loop, and then make any - // additional modifications necessary for the current operation. - bli_rntm_set_ways_for_op - ( - BLIS_GEMM, - BLIS_LEFT, // ignored for gemm/hemm/symm - bli_obj_length( &c_local ), - bli_obj_width( &c_local ), - bli_obj_width( &a_local ), - rntm - ); - - // Invoke the internal back-end via the thread handler. - bli_l3_thread_decorator - ( - bli_gemm_int, - BLIS_GEMM, // operation family id - alpha, - &a_local, - &b_local, - beta, - &c_local, - cntx, - rntm, - cntl - ); - -#if 1 - //if ( is_md == TRUE ) - { - //bli_projm( &cm, c_orig ); - bli_castm( &cm, c_orig ); - - bli_obj_free( &am ); - bli_obj_free( &bm ); - bli_obj_free( &cm ); - } -#endif -} -#endif - #endif diff --git a/frame/3/gemm/bli_gemm_md.h b/frame/3/gemm/bli_gemm_md.h index 057eb0a1d8..751e271eaf 100644 --- a/frame/3/gemm/bli_gemm_md.h +++ b/frame/3/gemm/bli_gemm_md.h @@ -87,9 +87,9 @@ void bli_gemm_md_zgemm // ----------------------------------------------------------------------------- -static bool_t bli_gemm_md_is_crr( obj_t* a, obj_t* b, obj_t* c ) +BLIS_INLINE bool bli_gemm_md_is_crr( obj_t* a, obj_t* b, obj_t* c ) { - bool_t r_val = FALSE; + bool r_val = FALSE; // NOTE: The last conditional subexpression is necessary if/when we // allow the user to specify the computation domain. (The computation @@ -107,9 +107,9 @@ static bool_t bli_gemm_md_is_crr( obj_t* a, obj_t* b, obj_t* c ) return r_val; } -static bool_t bli_gemm_md_is_ccr( obj_t* a, obj_t* b, obj_t* c ) +BLIS_INLINE bool bli_gemm_md_is_ccr( obj_t* a, obj_t* b, obj_t* c ) { - bool_t r_val = FALSE; + bool r_val = FALSE; // NOTE: The last conditional subexpression is necessary if/when we // allow the user to specify the computation domain. (The computation @@ -127,9 +127,9 @@ static bool_t bli_gemm_md_is_ccr( obj_t* a, obj_t* b, obj_t* c ) return r_val; } -static bool_t bli_gemm_md_is_crc( obj_t* a, obj_t* b, obj_t* c ) +BLIS_INLINE bool bli_gemm_md_is_crc( obj_t* a, obj_t* b, obj_t* c ) { - bool_t r_val = FALSE; + bool r_val = FALSE; // NOTE: The last conditional subexpression is necessary if/when we // allow the user to specify the computation domain. (The computation @@ -149,12 +149,12 @@ static bool_t bli_gemm_md_is_crc( obj_t* a, obj_t* b, obj_t* c ) // ----------------------------------------------------------------------------- -static void bli_gemm_md_ker_var2_recast +BLIS_INLINE void bli_gemm_md_ker_var2_recast ( num_t* dt_comp, num_t dt_a, num_t dt_b, - num_t dt_c, + num_t* dt_c, dim_t* m, dim_t* n, dim_t* k, @@ -164,7 +164,7 @@ static void bli_gemm_md_ker_var2_recast inc_t* rs_c, inc_t* cs_c ) { - if ( bli_is_real( dt_c ) && + if ( bli_is_real( *dt_c ) && bli_is_complex( dt_a ) && bli_is_complex( dt_b ) ) { @@ -177,7 +177,7 @@ static void bli_gemm_md_ker_var2_recast *ps_a *= 2; *ps_b *= 2; } - else if ( bli_is_complex( dt_c ) && + else if ( bli_is_complex( *dt_c ) && bli_is_real( dt_a ) && bli_is_complex( dt_b ) ) { @@ -197,6 +197,7 @@ static void bli_gemm_md_ker_var2_recast // to the real virtual microkernel slots of the context) instead of // the complex macrokernel and c2r virtual microkernel. *dt_comp = bli_dt_proj_to_real( *dt_comp ); + *dt_c = bli_dt_proj_to_real( *dt_c ); *n *= 2; *pd_b *= 2; *ps_b *= 2; *rs_c *= 2; @@ -211,7 +212,7 @@ static void bli_gemm_md_ker_var2_recast *ps_a /= 2; } } - else if ( bli_is_complex( dt_c ) && + else if ( bli_is_complex( *dt_c ) && bli_is_complex( dt_a ) && bli_is_real( dt_b ) ) { @@ -231,6 +232,7 @@ static void bli_gemm_md_ker_var2_recast // to the real virtual microkernel slots of the context) instead of // the complex macrokernel and c2r virtual microkernel. *dt_comp = bli_dt_proj_to_real( *dt_comp ); + *dt_c = bli_dt_proj_to_real( *dt_c ); *m *= 2; *pd_a *= 2; *ps_a *= 2; *cs_c *= 2; @@ -274,54 +276,3 @@ static void bli_gemm_md_ker_var2_recast #endif } -// ----------------------------------------------------------------------------- - -// -// Prototype object-based interfaces. -// - -#undef GENPROT -#define GENPROT( opname ) \ -\ -void PASTEMAC0(opname) \ - ( \ - obj_t* a, \ - obj_t* b, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - cntl_t* cntl, \ - thrinfo_t* thread \ - ); - -GENPROT( gemm_ker_var2_md ) - -// -// Prototype BLAS-like interfaces with void pointer operands. -// - -#undef GENTPROT2 -#define GENTPROT2( ctype_c, ctype_e, chc, che, varname ) \ -\ -void PASTEMAC2(chc,che,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ); - -INSERT_GENTPROT2_BASIC0( gemm_ker_var2_md ) -INSERT_GENTPROT2_MIXDP0( gemm_ker_var2_md ) - diff --git a/frame/3/gemm/bli_gemm_md_c2r_ref.c b/frame/3/gemm/bli_gemm_md_c2r_ref.c index 6198d85b27..bbd9190a9a 100644 --- a/frame/3/gemm/bli_gemm_md_c2r_ref.c +++ b/frame/3/gemm/bli_gemm_md_c2r_ref.c @@ -41,6 +41,8 @@ \ void PASTEMAC2(ch,opname,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -56,11 +58,14 @@ void PASTEMAC2(ch,opname,suf) \ \ PASTECH(chr,gemm_ukr_ft) \ rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ - const bool_t col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ - const bool_t row_pref = !col_pref; \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ + const bool row_pref = !col_pref; \ \ const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ +\ + dim_t mr_r = mr; \ + dim_t nr_r = nr; \ \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype_r ) ] \ @@ -81,12 +86,15 @@ void PASTEMAC2(ch,opname,suf) \ \ ctype_r* restrict beta_r = &PASTEMAC(ch,real)( *beta ); \ ctype_r* restrict beta_i = &PASTEMAC(ch,imag)( *beta ); \ +\ + dim_t m_use; \ + dim_t n_use; \ \ ctype_r* c_use; \ inc_t rs_c_use; \ inc_t cs_c_use; \ \ - bool_t using_ct; \ + bool using_ct; \ \ /* This virtual microkernel is used by ccr and crc mixed-domain cases when any of the following conditions are met: @@ -146,17 +154,16 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ rs_c_use = rs_ct; \ cs_c_use = cs_ct; \ \ - /* Convert the strides from being in units of complex elements to - be in units of real elements. Note that we don't need to check for - general storage here because that case corresponds to the scenario - where we are using the ct buffer and its rs_ct/cs_ct strides. */ \ - if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) cs_c_use *= 2; \ - else rs_c_use *= 2; \ -\ + /* Convert the strides and corresponding microtile dimension from being + in units of complex elements to be in units of real elements. */ \ + if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; mr_r *= 2; } \ + else { rs_c_use *= 2; nr_r *= 2; }\ \ /* c = beta * c + alpha_r * a * b; */ \ rgemm_ukr \ ( \ + mr_r, \ + nr_r, \ k, \ alpha_r, \ a_r, \ @@ -166,14 +173,12 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ data, \ cntx \ ); \ -\ - dim_t i, j; \ \ /* Accumulate the final result in ct back to c. */ \ if ( PASTEMAC(ch,eq1)( *beta ) ) \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \ *(c + i*rs_c + j*cs_c ) ); \ @@ -181,8 +186,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ } \ else if ( PASTEMAC(ch,eq0)( *beta ) ) \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \ *(c + i*rs_c + j*cs_c ) ); \ @@ -190,8 +195,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ } \ else \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \ *beta, \ @@ -207,17 +212,19 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ c_use = ( ctype_r* )c; \ rs_c_use = rs_c; \ cs_c_use = cs_c; \ + m_use = m; \ + n_use = n; \ \ - /* Convert the strides from being in units of complex elements to - be in units of real elements. Note that we don't need to check for - general storage here because that case corresponds to the scenario - where we are using the ct buffer and its rs_ct/cs_ct strides. */ \ - if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) cs_c_use *= 2; \ - else rs_c_use *= 2; \ + /* Convert the strides and corresponding microtile dimension from being + in units of complex elements to be in units of real elements. */ \ + if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; m_use *= 2; } \ + else { rs_c_use *= 2; n_use *= 2; } \ \ /* c = beta * c + alpha_r * a * b; */ \ rgemm_ukr \ ( \ + m_use, \ + n_use, \ k, \ alpha_r, \ a_r, \ diff --git a/frame/3/gemm/bli_gemm_var.h b/frame/3/gemm/bli_gemm_var.h index 34cf95ae6d..888181bad6 100644 --- a/frame/3/gemm/bli_gemm_var.h +++ b/frame/3/gemm/bli_gemm_var.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,6 +34,16 @@ */ +// +// gemm kernel parameter struct. +// + +typedef struct +{ + gemm_ukr_vft ukr; +} gemm_ker_params_t; + + // // Prototype object-based interfaces. // @@ -55,45 +65,7 @@ void PASTEMAC0(opname) \ GENPROT( gemm_blk_var1 ) GENPROT( gemm_blk_var2 ) GENPROT( gemm_blk_var3 ) -GENPROT( gemm_packa ) -GENPROT( gemm_packb ) GENPROT( gemm_ker_var1 ) - GENPROT( gemm_ker_var2 ) -// Headers for induced algorithms: -GENPROT( gemm4mb_ker_var2 ) // 4m1b - - -// -// Prototype BLAS-like interfaces with void pointer operands. -// - -#undef GENTPROT -#define GENTPROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ); - -INSERT_GENTPROT_BASIC0( gemm_ker_var2 ) - -// Headers for induced algorithms: -INSERT_GENTPROT_BASIC0( gemm4mb_ker_var2 ) // 4m1b - diff --git a/frame/3/gemm/ind/bli_gemm4mb_ker_var2.c b/frame/3/gemm/ind/bli_gemm4mb_ker_var2.c deleted file mode 100644 index e4b377b37c..0000000000 --- a/frame/3/gemm/ind/bli_gemm4mb_ker_var2.c +++ /dev/null @@ -1,365 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#define FUNCPTR_T gemm_fp - -typedef void (*FUNCPTR_T)( - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY(ftypes,gemm4mb_ker_var2); - - -void bli_gemm4mb_ker_var2 - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. - f( schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t i, j; \ - dim_t ii; \ - dim_t m_cur; \ - dim_t n_cur; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ - dim_t jr_num_threads = bli_thread_n_way( thread ); \ - dim_t jr_thread_id = bli_thread_work_id( thread ); \ - dim_t ir_num_threads = bli_thread_n_way( caucus ); \ - dim_t ir_thread_id = bli_thread_work_id( caucus ); \ -\ - dim_t jr_inc = jr_num_threads; \ - dim_t ir_inc = ir_num_threads; \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_thread_id; j < n_iter; j += jr_num_threads ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ - \ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* In the 4mb method, we execute the ir loop twice: once for b_r - and once for b_i. */ \ - for ( ii = 0; ii < 2; ++ii ) \ - { \ - ctype* restrict beta_use; \ -\ - if ( ii == 0 ) \ - { \ - bli_auxinfo_set_schema_b( BLIS_PACKED_COL_PANELS_RO, &aux ); \ - beta_use = beta_cast; \ - } \ - else \ - { \ - bli_auxinfo_set_schema_b( BLIS_PACKED_COL_PANELS_IO, &aux ); \ - beta_use = one; \ - } \ -\ - /* Loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_thread_id; i < m_iter; i += ir_num_threads ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter_rr( i, m_iter, ir_thread_id, ir_num_threads ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter_rr( j, n_iter, jr_thread_id, jr_num_threads ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ -/*PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var3 (4m1b): c before", 8, 6, c11, rs_c, cs_c, "%4.1f", "" );*/ \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_use, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ -/*PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var3 (4m1b): c after", 8, 6, c11, rs_c, cs_c, "%4.1f", "" );*/ \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the bottom edge of C and add the result from above. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_use, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -/*printf( "gemm_ker_var3 (4m1b): returning\n" );*/ \ -\ -/*PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var3: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var3: a1", MR, k, a1, 1, MR, "%4.1f", "" );*/ \ -} - -INSERT_GENTFUNC_BASIC0( gemm4mb_ker_var2 ) - diff --git a/frame/3/gemm/ind/bli_gemm_ind_opt.h b/frame/3/gemm/ind/bli_gemm_ind_opt.h index 5c0d5153c7..52ea81a5e8 100644 --- a/frame/3/gemm/ind/bli_gemm_ind_opt.h +++ b/frame/3/gemm/ind/bli_gemm_ind_opt.h @@ -32,9 +32,10 @@ */ -static void bli_gemm_ind_recast_1m_params +BLIS_INLINE void bli_gemm_ind_recast_1m_params ( num_t* dt_exec, + num_t* dt_c, pack_t schema_a, obj_t* c, dim_t* m, @@ -57,6 +58,7 @@ static void bli_gemm_ind_recast_1m_params !bli_is_gen_stored( *rs_c, *cs_c ) ) { *dt_exec = bli_dt_proj_to_real( *dt_exec ); + *dt_c = bli_dt_proj_to_real( *dt_c ); if ( bli_is_1e_packed( schema_a ) ) { diff --git a/frame/3/gemm/ind/old/bli_gemm3m2_ker_var2.c b/frame/3/gemm/ind/old/bli_gemm3m2_ker_var2.c deleted file mode 100644 index 09830753eb..0000000000 --- a/frame/3/gemm/ind/old/bli_gemm3m2_ker_var2.c +++ /dev/null @@ -1,363 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#define FUNCPTR_T gemm_fp - -typedef void (*FUNCPTR_T)( - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY(ftypes,gemm3m2_ker_var2); - - -void bli_gemm3m2_ker_var2 - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. - f( schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - thread ); -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t i, j; \ - dim_t ii; \ - dim_t m_cur; \ - dim_t n_cur; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ - dim_t jr_num_threads = bli_thread_n_way( thread ); \ - dim_t jr_thread_id = bli_thread_work_id( thread ); \ - dim_t ir_num_threads = bli_thread_n_way( caucus ); \ - dim_t ir_thread_id = bli_thread_work_id( caucus ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_thread_id; j < n_iter; j += jr_num_threads ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ - \ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* In the 3m2 method, we execute the ir loop thrice: once for - a_r[ir] * b_r, once for a_i[ir] * b_i, and once for - a_{r+i}[ir] * b_{r+i}. */ \ - for ( ii = 0; ii < 3; ++ii ) \ - { \ - ctype* restrict beta_use; \ -\ - if ( ii == 0 ) \ - { \ - bli_auxinfo_set_schema_a( BLIS_PACKED_ROW_PANELS_RO, &aux ); \ - bli_auxinfo_set_schema_b( BLIS_PACKED_COL_PANELS_RO, &aux ); \ - beta_use = beta_cast; \ - } \ - else if ( ii == 1 ) \ - { \ - bli_auxinfo_set_schema_a( BLIS_PACKED_ROW_PANELS_IO, &aux ); \ - bli_auxinfo_set_schema_b( BLIS_PACKED_COL_PANELS_IO, &aux ); \ - beta_use = one; \ - } \ - else \ - { \ - bli_auxinfo_set_schema_a( BLIS_PACKED_ROW_PANELS_RPI, &aux ); \ - bli_auxinfo_set_schema_b( BLIS_PACKED_COL_PANELS_RPI, &aux ); \ - beta_use = one; \ - } \ -\ - /* Loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_thread_id; i < m_iter; i += ir_num_threads ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_gemm_get_next_a_upanel( caucus, a1, rstep_a ); \ - if ( bli_is_last_iter( i, m_iter, ir_thread_id, ir_num_threads ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_gemm_get_next_b_upanel( thread, b1, cstep_b ); \ - if ( bli_is_last_iter( j, n_iter, jr_thread_id, jr_num_threads ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_use, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the bottom edge of C and add the result from above. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_use, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -\ -/*PASTEMAC(ch,fprintm)( stdout, "gemm3m2_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm3m2_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" );*/ \ -} - -INSERT_GENTFUNC_BASIC0( gemm3m2_ker_var2 ) - diff --git a/frame/3/gemm/other/bli_gemm_ker_var2.c b/frame/3/gemm/other/bli_gemm_ker_var2.c index 6ae8df0c1c..62d2a9e04b 100644 --- a/frame/3/gemm/other/bli_gemm_ker_var2.c +++ b/frame/3/gemm/other/bli_gemm_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -198,7 +198,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/gemm/other/bli_gemm_ker_var2rr.c b/frame/3/gemm/other/bli_gemm_ker_var2rr.c index a213e50fc6..289e4ddf5b 100644 --- a/frame/3/gemm/other/bli_gemm_ker_var2rr.c +++ b/frame/3/gemm/other/bli_gemm_ker_var2rr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -199,7 +199,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/gemm/other/bli_gemm_ker_var2sl.c b/frame/3/gemm/other/bli_gemm_ker_var2sl.c index 0d710bd738..d75838fb4e 100644 --- a/frame/3/gemm/other/bli_gemm_ker_var2sl.c +++ b/frame/3/gemm/other/bli_gemm_ker_var2sl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -199,7 +199,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/gemm/other/bli_gemm_ker_var5.c b/frame/3/gemm/other/bli_gemm_ker_var5.c index 0d0c914d8a..9e13c4edd4 100644 --- a/frame/3/gemm/other/bli_gemm_ker_var5.c +++ b/frame/3/gemm/other/bli_gemm_ker_var5.c @@ -45,7 +45,7 @@ typedef void (*FUNCPTR_T)( void* b, inc_t rs_b, dim_t pd_b, inc_t ps_b, void* beta, void* c, inc_t rs_c, inc_t cs_c, - void* gemm_ukr + void_fp gemm_ukr ); static FUNCPTR_T GENARRAY(ftypes,gemm_ker_var5); @@ -87,7 +87,7 @@ void bli_gemm_ker_var5( obj_t* a, FUNCPTR_T f; func_t* gemm_ukrs; - void* gemm_ukr; + void_fp gemm_ukr; // Detach and multiply the scalars attached to A and B. @@ -135,7 +135,7 @@ void PASTEMAC(ch,varname)( \ void* b, inc_t rs_b, dim_t pd_b, inc_t ps_b, \ void* beta, \ void* c, inc_t rs_c, inc_t cs_c, \ - void* gemm_ukr \ + void_fp gemm_ukr \ ) \ { \ /* Cast the micro-kernel address to its function pointer type. */ \ diff --git a/frame/3/gemm/other/bli_gemm_ker_var5.h b/frame/3/gemm/other/bli_gemm_ker_var5.h index 7e24bb5f99..e88db5cb5a 100644 --- a/frame/3/gemm/other/bli_gemm_ker_var5.h +++ b/frame/3/gemm/other/bli_gemm_ker_var5.h @@ -59,7 +59,7 @@ void PASTEMAC(ch,varname)( \ void* b, inc_t rs_b, dim_t pd_b, inc_t ps_b, \ void* beta, \ void* c, inc_t rs_c, inc_t cs_c, \ - void* gemm_ukr \ + void_fp gemm_ukr \ ); INSERT_GENTPROT_BASIC( gemm_ker_var5 ) diff --git a/frame/3/syrk/bli_syrk.h b/frame/3/gemmt/bli_gemmt.h similarity index 93% rename from frame/3/syrk/bli_syrk.h rename to frame/3/gemmt/bli_gemmt.h index 4936fe431e..32ab3865e7 100644 --- a/frame/3/syrk/bli_syrk.h +++ b/frame/3/gemmt/bli_gemmt.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,5 +32,7 @@ */ -#include "bli_syrk_front.h" +#include "bli_gemmt_front.h" + +#include "bli_gemmt_var.h" diff --git a/frame/3/syrk/bli_syrk_front.c b/frame/3/gemmt/bli_gemmt_front.c similarity index 67% rename from frame/3/syrk/bli_syrk_front.c rename to frame/3/gemmt/bli_gemmt_front.c index 534848e335..2a9d91759b 100644 --- a/frame/3/syrk/bli_syrk_front.c +++ b/frame/3/gemmt/bli_gemmt_front.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,10 +35,11 @@ #include "blis.h" -void bli_syrk_front +void bli_gemmt_front ( obj_t* alpha, obj_t* a, + obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, @@ -48,28 +50,37 @@ void bli_syrk_front bli_init_once(); obj_t a_local; - obj_t at_local; + obj_t b_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_syrk_check( alpha, a, beta, c, cntx ); + // If C has a zero dimension, return early. + if ( bli_obj_has_zero_dim( c ) ) + { + return; + } - // If alpha is zero, scale by beta and return. - if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) + // If alpha is zero, or if A or B has a zero dimension, scale C by beta + // and return early. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || + bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) { bli_scalm( beta, c ); return; } - // Alias A and C in case we need to apply transformations. + // Alias A, B, and C in case we need to apply transformations. bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); bli_obj_alias_to( c, &c_local ); - bli_obj_set_as_root( &c_local ); - // For syrk, the right-hand "B" operand is simply A^T. - bli_obj_alias_to( a, &at_local ); - bli_obj_induce_trans( &at_local ); + // Set the obj_t buffer field to the location currently implied by the row + // and column offsets and then zero the offsets. If any of the original + // obj_t's were views into larger matrices, this step effectively makes + // those obj_t's "forget" their lineage. + bli_obj_reset_origin( &a_local ); + bli_obj_reset_origin( &b_local ); + bli_obj_reset_origin( &c_local ); // An optimization: If C is stored by rows and the micro-kernel prefers // contiguous columns, or if C is stored by columns and the micro-kernel @@ -77,49 +88,37 @@ void bli_syrk_front // micro-kernel to access elements of C in its preferred manner. if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); bli_obj_induce_trans( &c_local ); } + // Set the pack schemas within the objects, as appropriate. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); + // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any // additional modifications necessary for the current operation. bli_rntm_set_ways_for_op ( - BLIS_SYRK, - BLIS_LEFT, // ignored for her[2]k/syr[2]k + BLIS_GEMM, + BLIS_LEFT, // ignored for gemm/hemm/symm/gemmt bli_obj_length( &c_local ), bli_obj_width( &c_local ), bli_obj_width( &a_local ), rntm ); - // A sort of hack for communicating the desired pach schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - if ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &at_local ); - } - else // if ( bli_cntx_method( cntx ) != BLIS_NAT ) - { - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &at_local ); - } - - // Invoke the internal back-end. + // Invoke the internal back-end via the thread handler. bli_l3_thread_decorator ( - bli_gemm_int, - BLIS_HERK, // operation family id + bli_l3_int, + BLIS_GEMMT, // operation family id alpha, &a_local, - &at_local, + &b_local, beta, &c_local, cntx, diff --git a/frame/3/her2k/bli_her2k_front.h b/frame/3/gemmt/bli_gemmt_front.h similarity index 96% rename from frame/3/her2k/bli_her2k_front.h rename to frame/3/gemmt/bli_gemmt_front.h index 0efdb86c2d..c5967f8b8a 100644 --- a/frame/3/her2k/bli_her2k_front.h +++ b/frame/3/gemmt/bli_gemmt_front.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,7 +33,7 @@ */ -void bli_her2k_front +void bli_gemmt_front ( obj_t* alpha, obj_t* a, diff --git a/frame/3/herk/bli_herk_l_ker_var2.c b/frame/3/gemmt/bli_gemmt_l_ker_var2.c similarity index 87% rename from frame/3/herk/bli_herk_l_ker_var2.c rename to frame/3/gemmt/bli_gemmt_l_ker_var2.c index d077b8f89f..fea4efec0a 100644 --- a/frame/3/herk/bli_herk_l_ker_var2.c +++ b/frame/3/gemmt/bli_gemmt_l_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,7 @@ #include "blis.h" -#define FUNCPTR_T herk_fp +#define FUNCPTR_T gemmt_fp typedef void (*FUNCPTR_T) ( @@ -57,10 +57,10 @@ typedef void (*FUNCPTR_T) thrinfo_t* thread ); -static FUNCPTR_T GENARRAY(ftypes,herk_l_ker_var2); +static FUNCPTR_T GENARRAY(ftypes,gemmt_l_ker_var2); -void bli_herk_l_ker_var2 +void bli_gemmt_l_ker_var2 ( obj_t* a, obj_t* b, @@ -183,7 +183,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ @@ -362,11 +362,11 @@ void PASTEMAC(ch,varname) \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ { \ a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) \ b2 = b_cast; \ } \ @@ -384,43 +384,20 @@ void PASTEMAC(ch,varname) \ And if we're strictly above the diagonal, we do nothing and continue. */ \ { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ @@ -467,11 +444,11 @@ void PASTEMAC(ch,varname) \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ { \ a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) \ b2 = b_cast; \ } \ @@ -493,6 +470,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ + MR, \ + NR, \ k, \ alpha_cast, \ a1, \ @@ -512,47 +491,24 @@ void PASTEMAC(ch,varname) \ } \ else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ } -INSERT_GENTFUNC_BASIC0( herk_l_ker_var2 ) +INSERT_GENTFUNC_BASIC0( gemmt_l_ker_var2 ) diff --git a/frame/3/herk/bli_herk_u_ker_var2.c b/frame/3/gemmt/bli_gemmt_u_ker_var2.c similarity index 88% rename from frame/3/herk/bli_herk_u_ker_var2.c rename to frame/3/gemmt/bli_gemmt_u_ker_var2.c index b20a96df7e..4b849bbc6d 100644 --- a/frame/3/herk/bli_herk_u_ker_var2.c +++ b/frame/3/gemmt/bli_gemmt_u_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,7 @@ #include "blis.h" -#define FUNCPTR_T herk_fp +#define FUNCPTR_T gemmt_fp typedef void (*FUNCPTR_T) ( @@ -57,10 +57,10 @@ typedef void (*FUNCPTR_T) thrinfo_t* thread ); -static FUNCPTR_T GENARRAY(ftypes,herk_u_ker_var2); +static FUNCPTR_T GENARRAY(ftypes,gemmt_u_ker_var2); -void bli_herk_u_ker_var2 +void bli_gemmt_u_ker_var2 ( obj_t* a, obj_t* b, @@ -183,7 +183,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ @@ -362,11 +362,11 @@ void PASTEMAC(ch,varname) \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ { \ a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) \ b2 = b_cast; \ } \ @@ -388,6 +388,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ + MR, \ + NR, \ k, \ alpha_cast, \ a1, \ @@ -407,43 +409,20 @@ void PASTEMAC(ch,varname) \ } \ else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ @@ -493,11 +472,11 @@ void PASTEMAC(ch,varname) \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ { \ a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) \ b2 = b_cast; \ } \ @@ -515,47 +494,24 @@ void PASTEMAC(ch,varname) \ And if we're strictly below the diagonal, we do nothing and continue. */ \ { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ } -INSERT_GENTFUNC_BASIC0( herk_u_ker_var2 ) +INSERT_GENTFUNC_BASIC0( gemmt_u_ker_var2 ) diff --git a/frame/3/herk/bli_herk_var.h b/frame/3/gemmt/bli_gemmt_var.h similarity index 88% rename from frame/3/herk/bli_herk_var.h rename to frame/3/gemmt/bli_gemmt_var.h index 3c565e1b0f..60c68c9f59 100644 --- a/frame/3/herk/bli_herk_var.h +++ b/frame/3/gemmt/bli_gemmt_var.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -52,16 +52,10 @@ void PASTEMAC0(opname) \ thrinfo_t* thread \ ); -//GENPROT( herk_blk_var1 ) -//GENPROT( herk_blk_var2 ) -//GENPROT( herk_blk_var3 ) +GENPROT( gemmt_x_ker_var2 ) -GENPROT( herk_x_ker_var2 ) - -GENPROT( herk_l_ker_var2 ) -GENPROT( herk_u_ker_var2 ) -//GENPROT( herk_packa ) -//GENPROT( herk_packb ) +GENPROT( gemmt_l_ker_var2 ) +GENPROT( gemmt_u_ker_var2 ) // @@ -91,6 +85,6 @@ void PASTEMAC(ch,varname) \ thrinfo_t* thread \ ); -INSERT_GENTPROT_BASIC0( herk_l_ker_var2 ) -INSERT_GENTPROT_BASIC0( herk_u_ker_var2 ) +INSERT_GENTPROT_BASIC0( gemmt_l_ker_var2 ) +INSERT_GENTPROT_BASIC0( gemmt_u_ker_var2 ) diff --git a/frame/3/herk/bli_herk_x_ker_var2.c b/frame/3/gemmt/bli_gemmt_x_ker_var2.c similarity index 91% rename from frame/3/herk/bli_herk_x_ker_var2.c rename to frame/3/gemmt/bli_gemmt_x_ker_var2.c index c0ce23255c..3a1d681c3b 100644 --- a/frame/3/herk/bli_herk_x_ker_var2.c +++ b/frame/3/gemmt/bli_gemmt_x_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,12 +35,12 @@ #include "blis.h" -static gemm_var_oft vars[2] = +static l3_var_oft vars[2] = { - bli_herk_l_ker_var2, bli_herk_u_ker_var2, + bli_gemmt_l_ker_var2, bli_gemmt_u_ker_var2, }; -void bli_herk_x_ker_var2 +void bli_gemmt_x_ker_var2 ( obj_t* a, obj_t* ah, @@ -51,8 +51,8 @@ void bli_herk_x_ker_var2 thrinfo_t* thread ) { - bool_t uplo; - gemm_var_oft f; + dim_t uplo; + l3_var_oft f; // Set a bool based on the uplo field of C's root object. if ( bli_obj_root_is_lower( c ) ) uplo = 0; diff --git a/frame/3/herk/other/bli_herk_l_ker_var2.c b/frame/3/gemmt/other/bli_gemmt_l_ker_var2.c similarity index 96% rename from frame/3/herk/other/bli_herk_l_ker_var2.c rename to frame/3/gemmt/other/bli_gemmt_l_ker_var2.c index 904da9f5e2..0bf4b1a0fb 100644 --- a/frame/3/herk/other/bli_herk_l_ker_var2.c +++ b/frame/3/gemmt/other/bli_gemmt_l_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,7 @@ #include "blis.h" -#define FUNCPTR_T herk_fp +#define FUNCPTR_T gemmt_fp typedef void (*FUNCPTR_T) ( @@ -57,10 +57,10 @@ typedef void (*FUNCPTR_T) thrinfo_t* thread ); -static FUNCPTR_T GENARRAY(ftypes,herk_l_ker_var2); +static FUNCPTR_T GENARRAY(ftypes,gemmt_l_ker_var2); -void bli_herk_l_ker_var2 +void bli_gemmt_l_ker_var2 ( obj_t* a, obj_t* b, @@ -183,7 +183,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ @@ -318,11 +318,11 @@ void PASTEMAC(ch,varname) \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( caucus, a1, rstep_a ); \ + a2 = bli_gemmt_get_next_a_upanel( caucus, a1, rstep_a ); \ if ( bli_is_last_iter( i, m_iter, ir_thread_id, ir_num_threads ) ) \ { \ a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( thread, b1, cstep_b ); \ + b2 = bli_gemmt_get_next_b_upanel( thread, b1, cstep_b ); \ if ( bli_is_last_iter( j, n_iter, jr_thread_id, jr_num_threads ) ) \ b2 = b_cast; \ } \ @@ -405,5 +405,5 @@ void PASTEMAC(ch,varname) \ } \ } -INSERT_GENTFUNC_BASIC0( herk_l_ker_var2 ) +INSERT_GENTFUNC_BASIC0( gemmt_l_ker_var2 ) diff --git a/frame/3/herk/other/bli_herk_u_ker_var2.c b/frame/3/gemmt/other/bli_gemmt_u_ker_var2.c similarity index 96% rename from frame/3/herk/other/bli_herk_u_ker_var2.c rename to frame/3/gemmt/other/bli_gemmt_u_ker_var2.c index 0bdc0b0a4c..1655bea555 100644 --- a/frame/3/herk/other/bli_herk_u_ker_var2.c +++ b/frame/3/gemmt/other/bli_gemmt_u_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,7 @@ #include "blis.h" -#define FUNCPTR_T herk_fp +#define FUNCPTR_T gemmt_fp typedef void (*FUNCPTR_T) ( @@ -57,10 +57,10 @@ typedef void (*FUNCPTR_T) thrinfo_t* thread ); -static FUNCPTR_T GENARRAY(ftypes,herk_u_ker_var2); +static FUNCPTR_T GENARRAY(ftypes,gemmt_u_ker_var2); -void bli_herk_u_ker_var2 +void bli_gemmt_u_ker_var2 ( obj_t* a, obj_t* b, @@ -183,7 +183,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ @@ -318,11 +318,11 @@ void PASTEMAC(ch,varname) \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( caucus, a1, rstep_a ); \ + a2 = bli_gemmt_get_next_a_upanel( caucus, a1, rstep_a ); \ if ( bli_is_last_iter( i, m_iter, ir_thread_id, ir_num_threads ) ) \ { \ a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( thread, b1, cstep_b ); \ + b2 = bli_gemmt_get_next_b_upanel( thread, b1, cstep_b ); \ if ( bli_is_last_iter( j, n_iter, jr_thread_id, jr_num_threads ) ) \ b2 = b_cast; \ } \ @@ -405,5 +405,5 @@ void PASTEMAC(ch,varname) \ } \ } -INSERT_GENTFUNC_BASIC0( herk_u_ker_var2 ) +INSERT_GENTFUNC_BASIC0( gemmt_u_ker_var2 ) diff --git a/frame/3/hemm/bli_hemm_front.c b/frame/3/hemm/bli_hemm_front.c index 5949a2e6c5..9835de9c15 100644 --- a/frame/3/hemm/bli_hemm_front.c +++ b/frame/3/hemm/bli_hemm_front.c @@ -53,10 +53,6 @@ void bli_hemm_front obj_t b_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_hemm_check( side, alpha, a, b, beta, c, cntx ); - // If alpha is zero, scale by beta and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) { @@ -69,10 +65,58 @@ void bli_hemm_front bli_obj_alias_to( b, &b_local ); bli_obj_alias_to( c, &c_local ); + // Set the obj_t buffer field to the location currently implied by the row + // and column offsets and then zero the offsets. If any of the original + // obj_t's were views into larger matrices, this step effectively makes + // those obj_t's "forget" their lineage. + bli_obj_reset_origin( &a_local ); + bli_obj_reset_origin( &b_local ); + bli_obj_reset_origin( &c_local ); + +#ifdef BLIS_DISABLE_HEMM_RIGHT + // NOTE: This case casts right-side hemm in terms of left side. This is + // necessary when the current subconfiguration uses a gemm microkernel + // that assumes that the packing kernel will have already duplicated + // (broadcast) element of B in the packed copy of B. Supporting + // duplication within the logic that packs micropanels from Hermitian/ + // matrices would be ugly, and so we simply don't support it. As a + // consequence, those subconfigurations need a way to force the Hermitian + // matrix to be on the left (and thus the general matrix to the on the + // right). So our solution is that in those cases, the subconfigurations + // simply #define BLIS_DISABLE_HEMM_RIGHT. + + // NOTE: This case casts right-side hemm in terms of left side. This can + // lead to the microkernel being executed on an output matrix with the + // microkernel's general stride IO case (unless the microkernel supports + // both both row and column IO cases as well). + + // If A is being multiplied from the right, transpose all operands + // so that we can perform the computation as if A were being multiplied + // from the left. + if ( bli_is_right( side ) ) + { + bli_toggle_side( &side ); + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + +#else + // NOTE: This case computes right-side hemm/symm natively by packing + // elements of the Hermitian/symmetric matrix A to micropanels of the + // right-hand packed matrix operand "B", and elements of the general + // matrix B to micropanels of the left-hand packed matrix operand "A". + // This code path always gives us the opportunity to transpose the + // entire operation so that the effective storage format of the output + // matrix matches the microkernel's output preference. Thus, from a + // performance perspective, this case is preferred. + // An optimization: If C is stored by rows and the micro-kernel prefers // contiguous columns, or if C is stored by columns and the micro-kernel // prefers contiguous rows, transpose the entire operation to allow the // micro-kernel to access elements of C in its preferred manner. + //if ( !bli_obj_is_1x1( &c_local ) ) // NOTE: This conditional should NOT + // be enabled. See issue #342 comments. if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) { bli_toggle_side( &side ); @@ -81,12 +125,17 @@ void bli_hemm_front bli_obj_induce_trans( &c_local ); } - // Swap A and B if multiplying A from the right so that "B" contains - // the Hermitian matrix. + // If the Hermitian/symmetric matrix A is being multiplied from the right, + // swap A and B so that the Hermitian/symmetric matrix will actually be on + // the right. if ( bli_is_right( side ) ) { bli_obj_swap( &a_local, &b_local ); } +#endif + + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any @@ -101,29 +150,10 @@ void bli_hemm_front rntm ); - // A sort of hack for communicating the desired pach schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - if ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &b_local ); - } - else // if ( bli_cntx_method( cntx ) != BLIS_NAT ) - { - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - } - // Invoke the internal back-end. bli_l3_thread_decorator ( - bli_gemm_int, + bli_l3_int, BLIS_GEMM, // operation family id alpha, &a_local, diff --git a/frame/3/her2k/bli_her2k_front.c b/frame/3/her2k/bli_her2k_front.c deleted file mode 100644 index a99aa05c8d..0000000000 --- a/frame/3/her2k/bli_her2k_front.c +++ /dev/null @@ -1,184 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -void bli_her2k_front - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl - ) -{ - bli_init_once(); - - obj_t alpha_conj; - obj_t c_local; - obj_t a_local; - obj_t bh_local; - obj_t b_local; - obj_t ah_local; - - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_her2k_check( alpha, a, b, beta, c, cntx ); - - // If alpha is zero, scale by beta, zero the imaginary components of - // the diagonal elements, and return. - if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) - { - bli_scalm( beta, c ); - bli_setid( &BLIS_ZERO, c ); - return; - } - - // Alias A, B, and C in case we need to apply transformations. - bli_obj_alias_to( a, &a_local ); - bli_obj_alias_to( b, &b_local ); - bli_obj_alias_to( c, &c_local ); - bli_obj_set_as_root( &c_local ); - - // For her2k, the first and second right-hand "B" operands are simply B' - // and A'. - bli_obj_alias_to( b, &bh_local ); - bli_obj_induce_trans( &bh_local ); - bli_obj_toggle_conj( &bh_local ); - bli_obj_alias_to( a, &ah_local ); - bli_obj_induce_trans( &ah_local ); - bli_obj_toggle_conj( &ah_local ); - - // Initialize a conjugated copy of alpha. - bli_obj_scalar_init_detached_copy_of( bli_obj_dt( a ), - BLIS_CONJUGATE, - alpha, - &alpha_conj ); - - // An optimization: If C is stored by rows and the micro-kernel prefers - // contiguous columns, or if C is stored by columns and the micro-kernel - // prefers contiguous rows, transpose the entire operation to allow the - // micro-kernel to access elements of C in its preferred manner. - if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) - { - bli_obj_swap( &a_local, &bh_local ); - bli_obj_swap( &b_local, &ah_local ); - - bli_obj_induce_trans( &a_local ); - bli_obj_induce_trans( &bh_local ); - bli_obj_induce_trans( &b_local ); - bli_obj_induce_trans( &ah_local ); - - bli_obj_induce_trans( &c_local ); - } - - // Parse and interpret the contents of the rntm_t object to properly - // set the ways of parallelism for each loop, and then make any - // additional modifications necessary for the current operation. - bli_rntm_set_ways_for_op - ( - BLIS_HER2K, - BLIS_LEFT, // ignored for her[2]k/syr[2]k - bli_obj_length( &c_local ), - bli_obj_width( &c_local ), - bli_obj_width( &a_local ), - rntm - ); - - // A sort of hack for communicating the desired pach schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - if ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &bh_local ); - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &b_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &ah_local ); - } - else // if ( bli_cntx_method( cntx ) != BLIS_NAT ) - { - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &bh_local ); - bli_obj_set_pack_schema( schema_a, &b_local ); - bli_obj_set_pack_schema( schema_b, &ah_local ); - } - - // Invoke herk twice, using beta only the first time. - - // Invoke the internal back-end. - bli_l3_thread_decorator - ( - bli_gemm_int, - BLIS_HERK, // operation family id - alpha, - &a_local, - &bh_local, - beta, - &c_local, - cntx, - rntm, - cntl - ); - - bli_l3_thread_decorator - ( - bli_gemm_int, - BLIS_HERK, // operation family id - &alpha_conj, - &b_local, - &ah_local, - &BLIS_ONE, - &c_local, - cntx, - rntm, - cntl - ); - - // The Hermitian rank-2k product was computed as A*B'+B*A', even for - // the diagonal elements. Mathematically, the imaginary components of - // diagonal elements of a Hermitian rank-2k product should always be - // zero. However, in practice, they sometimes accumulate meaningless - // non-zero values. To prevent this, we explicitly set those values - // to zero before returning. - bli_setid( &BLIS_ZERO, &c_local ); -} - diff --git a/frame/3/herk/bli_herk_front.c b/frame/3/herk/bli_herk_front.c deleted file mode 100644 index be0118f18e..0000000000 --- a/frame/3/herk/bli_herk_front.c +++ /dev/null @@ -1,144 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -void bli_herk_front - ( - obj_t* alpha, - obj_t* a, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl - ) -{ - bli_init_once(); - - obj_t a_local; - obj_t ah_local; - obj_t c_local; - - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_herk_check( alpha, a, beta, c, cntx ); - - // If alpha is zero, scale by beta, zero the imaginary components of - // the diagonal elements, and return. - if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) - { - bli_scalm( beta, c ); - bli_setid( &BLIS_ZERO, c ); - return; - } - - // Alias A and C in case we need to apply transformations. - bli_obj_alias_to( a, &a_local ); - bli_obj_alias_to( c, &c_local ); - bli_obj_set_as_root( &c_local ); - - // For herk, the right-hand "B" operand is simply A'. - bli_obj_alias_to( a, &ah_local ); - bli_obj_induce_trans( &ah_local ); - bli_obj_toggle_conj( &ah_local ); - - // An optimization: If C is stored by rows and the micro-kernel prefers - // contiguous columns, or if C is stored by columns and the micro-kernel - // prefers contiguous rows, transpose the entire operation to allow the - // micro-kernel to access elements of C in its preferred manner. - if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) - { - bli_obj_toggle_conj( &a_local ); - bli_obj_toggle_conj( &ah_local ); - - bli_obj_induce_trans( &c_local ); - } - - // Parse and interpret the contents of the rntm_t object to properly - // set the ways of parallelism for each loop, and then make any - // additional modifications necessary for the current operation. - bli_rntm_set_ways_for_op - ( - BLIS_HERK, - BLIS_LEFT, // ignored for her[2]k/syr[2]k - bli_obj_length( &c_local ), - bli_obj_width( &c_local ), - bli_obj_width( &a_local ), - rntm - ); - - // A sort of hack for communicating the desired pach schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - if ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &ah_local ); - } - else // if ( bli_cntx_method( cntx ) != BLIS_NAT ) - { - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &ah_local ); - } - - // Invoke the internal back-end. - bli_l3_thread_decorator - ( - bli_gemm_int, - BLIS_HERK, // operation family id - alpha, - &a_local, - &ah_local, - beta, - &c_local, - cntx, - rntm, - cntl - ); - - // The Hermitian rank-k product was computed as A*A', even for the - // diagonal elements. Mathematically, the imaginary components of - // diagonal elements of a Hermitian rank-k product should always be - // zero. However, in practice, they sometimes accumulate meaningless - // non-zero values. To prevent this, we explicitly set those values - // to zero before returning. - bli_setid( &BLIS_ZERO, &c_local ); -} - diff --git a/frame/3/herk/other/bli_herk_l_ker_var2.1looprr.c b/frame/3/herk/other/bli_herk_l_ker_var2.1looprr.c deleted file mode 100644 index 38675b11b8..0000000000 --- a/frame/3/herk/other/bli_herk_l_ker_var2.1looprr.c +++ /dev/null @@ -1,420 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#define FUNCPTR_T herk_fp - -typedef void (*FUNCPTR_T) - ( - doff_t diagoffc, - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY(ftypes,herk_l_ker_var2); - - -void bli_herk_l_ker_var2 - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - - doff_t diagoffc = bli_obj_diag_offset( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. - f( diagoffc, - schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - doff_t diagoffc, \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - doff_t diagoffc_ij; \ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t m_cur; \ - dim_t n_cur; \ - dim_t i, j, ip; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Safeguard: If the current panel of C is entirely above the diagonal, - it is not stored. So we do nothing. */ \ - if ( bli_is_strictly_above_diag_n( diagoffc, m, n ) ) return; \ -\ - /* If there is a zero region above where the diagonal of C intersects - the left edge of the panel, adjust the pointer to C and A and treat - this case as if the diagonal offset were zero. */ \ - if ( diagoffc < 0 ) \ - { \ - ip = -diagoffc / MR; \ - i = ip * MR; \ - m = m - i; \ - diagoffc = -diagoffc % MR; \ - c_cast = c_cast + (i )*rs_c; \ - a_cast = a_cast + (ip )*ps_a; \ - } \ -\ - /* If there is a zero region to the right of where the diagonal - of C intersects the bottom of the panel, shrink it to prevent - "no-op" iterations from executing. */ \ - if ( diagoffc + m < n ) \ - { \ - n = diagoffc + m; \ - } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) - loop around the microkernel. Here we query the thrinfo_t node for the - 1st (ir) loop around the microkernel. */ \ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ -\ - /* Query the number of threads and thread ids for each loop. */ \ - dim_t jr_nt = bli_thread_n_way( thread ); \ - dim_t jr_tid = bli_thread_work_id( thread ); \ - dim_t ir_nt = bli_thread_n_way( caucus ); \ - dim_t ir_tid = bli_thread_work_id( caucus ); \ -\ - dim_t jr_start, jr_end; \ - dim_t ir_start, ir_end; \ - dim_t jr_inc, ir_inc; \ -\ - /* Use interleaved (round robin) assignment of micropanels to threads in - the 2nd and 1st loops. */ \ - bli_thread_range_jrir_rr( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ - bli_thread_range_jrir_rr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Interior loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - /* Compute the diagonal offset for the submatrix at (i,j). */ \ - diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* If the diagonal intersects the current MR x NR submatrix, we - compute it the temporary buffer and then add in the elements - on or below the diagonal. - Otherwise, if the submatrix is strictly below the diagonal, - we compute and store as we normally would. - And if we're strictly above the diagonal, we do nothing and - continue. */ \ - if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale C and add the result to only the stored part. */ \ - PASTEMAC(ch,xpbys_mxn_l)( diagoffc_ij, \ - m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -} - -INSERT_GENTFUNC_BASIC0( herk_l_ker_var2 ) - diff --git a/frame/3/herk/other/bli_herk_l_ker_var2rr.c b/frame/3/herk/other/bli_herk_l_ker_var2rr.c deleted file mode 100644 index a313f04b21..0000000000 --- a/frame/3/herk/other/bli_herk_l_ker_var2rr.c +++ /dev/null @@ -1,555 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#define FUNCPTR_T herk_fp - -typedef void (*FUNCPTR_T) - ( - doff_t diagoffc, - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY(ftypes,herk_l_ker_var2rr); - -// -// -- Macrokernel functions for round-robin partitioning ----------------------- -// - -void bli_herk_l_ker_var2rr - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - - doff_t diagoffc = bli_obj_diag_offset( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. - f( diagoffc, - schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - doff_t diagoffc, \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - doff_t diagoffc_ij; \ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t m_cur; \ - dim_t n_cur; \ - dim_t i, j, ip; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Safeguard: If the current panel of C is entirely above the diagonal, - it is not stored. So we do nothing. */ \ - if ( bli_is_strictly_above_diag_n( diagoffc, m, n ) ) return; \ -\ - /* If there is a zero region above where the diagonal of C intersects - the left edge of the panel, adjust the pointer to C and A and treat - this case as if the diagonal offset were zero. */ \ - if ( diagoffc < 0 ) \ - { \ - ip = -diagoffc / MR; \ - i = ip * MR; \ - m = m - i; \ - diagoffc = -diagoffc % MR; \ - c_cast = c_cast + (i )*rs_c; \ - a_cast = a_cast + (ip )*ps_a; \ - } \ -\ - /* If there is a zero region to the right of where the diagonal - of C intersects the bottom of the panel, shrink it to prevent - "no-op" iterations from executing. */ \ - if ( diagoffc + m < n ) \ - { \ - n = diagoffc + m; \ - } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) - loop around the microkernel. Here we query the thrinfo_t node for the - 1st (ir) loop around the microkernel. */ \ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ -\ - /* Query the number of threads and thread ids for each loop. */ \ - dim_t jr_nt = bli_thread_n_way( thread ); \ - dim_t jr_tid = bli_thread_work_id( thread ); \ - dim_t ir_nt = bli_thread_n_way( caucus ); \ - dim_t ir_tid = bli_thread_work_id( caucus ); \ -\ - dim_t jr_start, jr_end; \ - dim_t ir_start, ir_end; \ - dim_t jr_inc, ir_inc; \ -\ - /* Note that we partition the 2nd loop into two regions: the rectangular - part of C, and the triangular portion. */ \ - dim_t n_iter_rct; \ - dim_t n_iter_tri; \ -\ - if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) \ - { \ - /* If the entire panel of C does not intersect the diagonal, there is - no triangular region, and therefore we can skip the second set of - loops. */ \ - n_iter_rct = n_iter; \ - n_iter_tri = 0; \ - } \ - else \ - { \ - /* If the panel of C does intersect the diagonal, compute the number of - iterations in the rectangular region by dividing NR into the diagonal - offset. Any remainder from this integer division is discarded, which - is what we want. That is, we want the rectangular region to contain - as many columns of whole microtiles as possible without including any - microtiles that intersect the diagonal. The number of iterations in - the triangular (or trapezoidal) region is computed as the remaining - number of iterations in the n dimension. */ \ - n_iter_rct = diagoffc / NR; \ - n_iter_tri = n_iter - n_iter_rct; \ - } \ -\ - /* Use round-robin assignment of micropanels to threads in the 2nd and 1st - loops for the initial rectangular region of C (if it exists). */ \ - bli_thread_range_jrir_rr( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ - bli_thread_range_jrir_rr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Interior loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - /* No need to compute the diagonal offset for the rectangular - region. */ \ - /*diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR;*/ \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter_rr( i, m_iter, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* If the diagonal intersects the current MR x NR submatrix, we - compute it the temporary buffer and then add in the elements - on or below the diagonal. - Otherwise, if the submatrix is strictly below the diagonal, - we compute and store as we normally would. - And if we're strictly above the diagonal, we do nothing and - continue. */ \ - { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -\ - /* If there is no triangular region, then we're done. */ \ - if ( n_iter_tri == 0 ) return; \ -\ - /* Use round-robin assignment of micropanels to threads in the 2nd and - 1st loops for the remaining triangular region of C. */ \ - bli_thread_range_jrir_rr( thread, n_iter_tri, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ -\ - /* Advance the start and end iteration offsets for the triangular region - by the number of iterations used for the rectangular region. */ \ - jr_start += n_iter_rct; \ - jr_end += n_iter_rct; \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Interior loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - /* Compute the diagonal offset for the submatrix at (i,j). */ \ - diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter_rr( i, m_iter, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* If the diagonal intersects the current MR x NR submatrix, we - compute it the temporary buffer and then add in the elements - on or below the diagonal. - Otherwise, if the submatrix is strictly below the diagonal, - we compute and store as we normally would. - And if we're strictly above the diagonal, we do nothing and - continue. */ \ - if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale C and add the result to only the stored part. */ \ - PASTEMAC(ch,xpbys_mxn_l)( diagoffc_ij, \ - m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -} - -INSERT_GENTFUNC_BASIC0( herk_l_ker_var2rr ) - diff --git a/frame/3/herk/other/bli_herk_l_ker_var2sl.c b/frame/3/herk/other/bli_herk_l_ker_var2sl.c deleted file mode 100644 index f913cced29..0000000000 --- a/frame/3/herk/other/bli_herk_l_ker_var2sl.c +++ /dev/null @@ -1,556 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#define FUNCPTR_T herk_fp - -typedef void (*FUNCPTR_T) - ( - doff_t diagoffc, - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY(ftypes,herk_l_ker_var2sl); - -// -// -- Macrokernel functions for slab partitioning ------------------------------ -// - -void bli_herk_l_ker_var2sl - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - - doff_t diagoffc = bli_obj_diag_offset( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. - f( diagoffc, - schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - doff_t diagoffc, \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - doff_t diagoffc_ij; \ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t m_cur; \ - dim_t n_cur; \ - dim_t i, j, ip; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Safeguard: If the current panel of C is entirely above the diagonal, - it is not stored. So we do nothing. */ \ - if ( bli_is_strictly_above_diag_n( diagoffc, m, n ) ) return; \ -\ - /* If there is a zero region above where the diagonal of C intersects - the left edge of the panel, adjust the pointer to C and A and treat - this case as if the diagonal offset were zero. */ \ - if ( diagoffc < 0 ) \ - { \ - ip = -diagoffc / MR; \ - i = ip * MR; \ - m = m - i; \ - diagoffc = -diagoffc % MR; \ - c_cast = c_cast + (i )*rs_c; \ - a_cast = a_cast + (ip )*ps_a; \ - } \ -\ - /* If there is a zero region to the right of where the diagonal - of C intersects the bottom of the panel, shrink it to prevent - "no-op" iterations from executing. */ \ - if ( diagoffc + m < n ) \ - { \ - n = diagoffc + m; \ - } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) - loop around the microkernel. Here we query the thrinfo_t node for the - 1st (ir) loop around the microkernel. */ \ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ -\ - /* Query the number of threads and thread ids for each loop. */ \ - dim_t jr_nt = bli_thread_n_way( thread ); \ - dim_t jr_tid = bli_thread_work_id( thread ); \ - dim_t ir_nt = bli_thread_n_way( caucus ); \ - dim_t ir_tid = bli_thread_work_id( caucus ); \ -\ - dim_t jr_start, jr_end; \ - dim_t ir_start, ir_end; \ - dim_t jr_inc, ir_inc; \ -\ - /* Note that we partition the 2nd loop into two regions: the rectangular - part of C, and the triangular portion. */ \ - dim_t n_iter_rct; \ - dim_t n_iter_tri; \ -\ - if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) \ - { \ - /* If the entire panel of C does not intersect the diagonal, there is - no triangular region, and therefore we can skip the second set of - loops. */ \ - n_iter_rct = n_iter; \ - n_iter_tri = 0; \ - } \ - else \ - { \ - /* If the panel of C does intersect the diagonal, compute the number of - iterations in the rectangular region by dividing NR into the diagonal - offset. Any remainder from this integer division is discarded, which - is what we want. That is, we want the rectangular region to contain - as many columns of whole microtiles as possible without including any - microtiles that intersect the diagonal. The number of iterations in - the triangular (or trapezoidal) region is computed as the remaining - number of iterations in the n dimension. */ \ - n_iter_rct = diagoffc / NR; \ - n_iter_tri = n_iter - n_iter_rct; \ - } \ -\ - /* Use slab assignment of micropanels to threads in the 2nd and 1st - loops for the initial rectangular region of C (if it exists). */ \ - bli_thread_range_jrir_sl( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ - bli_thread_range_jrir_sl( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Interior loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - /* No need to compute the diagonal offset for the rectangular - region. */ \ - /*diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR;*/ \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter_sl( i, m_iter, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter_sl( j, n_iter, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* If the diagonal intersects the current MR x NR submatrix, we - compute it the temporary buffer and then add in the elements - on or below the diagonal. - Otherwise, if the submatrix is strictly below the diagonal, - we compute and store as we normally would. - And if we're strictly above the diagonal, we do nothing and - continue. */ \ - { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -\ - /* If there is no triangular region, then we're done. */ \ - if ( n_iter_tri == 0 ) return; \ -\ - /* Use round-robin assignment of micropanels to threads in the 2nd - loop and slab partitioning in the 1st loop for the remaining - triangular region of C. */ \ - bli_thread_range_jrir_rr( thread, n_iter_tri, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ -\ - /* Advance the start and end iteration offsets for the triangular region - by the number of iterations used for the rectangular region. */ \ - jr_start += n_iter_rct; \ - jr_end += n_iter_rct; \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Interior loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - /* Compute the diagonal offset for the submatrix at (i,j). */ \ - diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter_rr( i, m_iter, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* If the diagonal intersects the current MR x NR submatrix, we - compute it the temporary buffer and then add in the elements - on or below the diagonal. - Otherwise, if the submatrix is strictly below the diagonal, - we compute and store as we normally would. - And if we're strictly above the diagonal, we do nothing and - continue. */ \ - if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale C and add the result to only the stored part. */ \ - PASTEMAC(ch,xpbys_mxn_l)( diagoffc_ij, \ - m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -} - -INSERT_GENTFUNC_BASIC0( herk_l_ker_var2sl ) - diff --git a/frame/3/herk/other/bli_herk_u_ker_var2.1looprr.c b/frame/3/herk/other/bli_herk_u_ker_var2.1looprr.c deleted file mode 100644 index cd4a4e7ade..0000000000 --- a/frame/3/herk/other/bli_herk_u_ker_var2.1looprr.c +++ /dev/null @@ -1,420 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#define FUNCPTR_T herk_fp - -typedef void (*FUNCPTR_T) - ( - doff_t diagoffc, - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY(ftypes,herk_u_ker_var2); - - -void bli_herk_u_ker_var2 - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - - doff_t diagoffc = bli_obj_diag_offset( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. - f( diagoffc, - schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - doff_t diagoffc, \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - doff_t diagoffc_ij; \ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t m_cur; \ - dim_t n_cur; \ - dim_t i, j, jp; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Safeguard: If the current panel of C is entirely below the diagonal, - it is not stored. So we do nothing. */ \ - if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) return; \ -\ - /* If there is a zero region to the left of where the diagonal of C - intersects the top edge of the panel, adjust the pointer to C and B - and treat this case as if the diagonal offset were zero. */ \ - if ( diagoffc > 0 ) \ - { \ - jp = diagoffc / NR; \ - j = jp * NR; \ - n = n - j; \ - diagoffc = diagoffc % NR; \ - c_cast = c_cast + (j )*cs_c; \ - b_cast = b_cast + (jp )*ps_b; \ - } \ -\ - /* If there is a zero region below where the diagonal of C intersects - the right edge of the panel, shrink it to prevent "no-op" iterations - from executing. */ \ - if ( -diagoffc + n < m ) \ - { \ - m = -diagoffc + n; \ - } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) - loop around the microkernel. Here we query the thrinfo_t node for the - 1st (ir) loop around the microkernel. */ \ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ -\ - /* Query the number of threads and thread ids for each loop. */ \ - dim_t jr_nt = bli_thread_n_way( thread ); \ - dim_t jr_tid = bli_thread_work_id( thread ); \ - dim_t ir_nt = bli_thread_n_way( caucus ); \ - dim_t ir_tid = bli_thread_work_id( caucus ); \ -\ - dim_t jr_start, jr_end; \ - dim_t ir_start, ir_end; \ - dim_t jr_inc, ir_inc; \ -\ - /* Use interleaved (round robin) assignment of micropanels to threads in - the 2nd and 1st loops. */ \ - bli_thread_range_jrir_rr( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ - bli_thread_range_jrir_rr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Interior loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - /* Compute the diagonal offset for the submatrix at (i,j). */ \ - diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* If the diagonal intersects the current MR x NR submatrix, we - compute it the temporary buffer and then add in the elements - on or below the diagonal. - Otherwise, if the submatrix is strictly above the diagonal, - we compute and store as we normally would. - And if we're strictly below the diagonal, we do nothing and - continue. */ \ - if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale C and add the result to only the stored part. */ \ - PASTEMAC(ch,xpbys_mxn_u)( diagoffc_ij, \ - m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -} - -INSERT_GENTFUNC_BASIC0( herk_u_ker_var2 ) - diff --git a/frame/3/herk/other/bli_herk_u_ker_var2rr.c b/frame/3/herk/other/bli_herk_u_ker_var2rr.c deleted file mode 100644 index 4ffa8085c7..0000000000 --- a/frame/3/herk/other/bli_herk_u_ker_var2rr.c +++ /dev/null @@ -1,557 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#define FUNCPTR_T herk_fp - -typedef void (*FUNCPTR_T) - ( - doff_t diagoffc, - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY(ftypes,herk_u_ker_var2rr); - -// -// -- Macrokernel functions for round-robin partitioning ----------------------- -// - -void bli_herk_u_ker_var2rr - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - - doff_t diagoffc = bli_obj_diag_offset( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. - f( diagoffc, - schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - doff_t diagoffc, \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - doff_t diagoffc_ij; \ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t m_cur; \ - dim_t n_cur; \ - dim_t i, j, jp; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Safeguard: If the current panel of C is entirely below the diagonal, - it is not stored. So we do nothing. */ \ - if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) return; \ -\ - /* If there is a zero region to the left of where the diagonal of C - intersects the top edge of the panel, adjust the pointer to C and B - and treat this case as if the diagonal offset were zero. - NOTE: It's possible that after this pruning that the diagonal offset - is still positive (though it is guaranteed to be less than NR). */ \ - if ( diagoffc > 0 ) \ - { \ - jp = diagoffc / NR; \ - j = jp * NR; \ - n = n - j; \ - diagoffc = diagoffc % NR; \ - c_cast = c_cast + (j )*cs_c; \ - b_cast = b_cast + (jp )*ps_b; \ - } \ -\ - /* If there is a zero region below where the diagonal of C intersects - the right edge of the panel, shrink it to prevent "no-op" iterations - from executing. */ \ - if ( -diagoffc + n < m ) \ - { \ - m = -diagoffc + n; \ - } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) - loop around the microkernel. Here we query the thrinfo_t node for the - 1st (ir) loop around the microkernel. */ \ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ -\ - /* Query the number of threads and thread ids for each loop. */ \ - dim_t jr_nt = bli_thread_n_way( thread ); \ - dim_t jr_tid = bli_thread_work_id( thread ); \ - dim_t ir_nt = bli_thread_n_way( caucus ); \ - dim_t ir_tid = bli_thread_work_id( caucus ); \ -\ - dim_t jr_start, jr_end; \ - dim_t ir_start, ir_end; \ - dim_t jr_inc, ir_inc; \ -\ - /* Note that we partition the 2nd loop into two regions: the triangular - part of C, and the rectangular portion. */ \ - dim_t n_iter_tri; \ - dim_t n_iter_rct; \ -\ - if ( bli_is_strictly_above_diag_n( diagoffc, m, n ) ) \ - { \ - /* If the entire panel of C does not intersect the diagonal, there is - no triangular region, and therefore we can skip the first set of - loops. */ \ - n_iter_tri = 0; \ - n_iter_rct = n_iter; \ - } \ - else \ - { \ - /* If the panel of C does intersect the diagonal, compute the number of - iterations in the triangular (or trapezoidal) region by dividing NR - into the number of rows in C. A non-zero remainder means we need to - add one additional iteration. That is, we want the triangular region - to contain as few columns of whole microtiles as possible while still - including all microtiles that intersect the diagonal. The number of - iterations in the rectangular region is computed as the remaining - number of iterations in the n dimension. */ \ - n_iter_tri = ( m + diagoffc ) / NR + ( ( m + diagoffc ) % NR ? 1 : 0 ); \ - n_iter_rct = n_iter - n_iter_tri; \ - } \ -\ - /* Use round-robin assignment of micropanels to threads in the 2nd and 1st - loops for the initial triangular region of C (if it exists). */ \ - bli_thread_range_jrir_rr( thread, n_iter_tri, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ - bli_thread_range_jrir_rr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Interior loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - /* Compute the diagonal offset for the submatrix at (i,j). */ \ - diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter_rr( i, m_iter, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* If the diagonal intersects the current MR x NR submatrix, we - compute it the temporary buffer and then add in the elements - on or below the diagonal. - Otherwise, if the submatrix is strictly above the diagonal, - we compute and store as we normally would. - And if we're strictly below the diagonal, we do nothing and - continue. */ \ - if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale C and add the result to only the stored part. */ \ - PASTEMAC(ch,xpbys_mxn_u)( diagoffc_ij, \ - m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -\ - /* If there is no rectangular region, then we're done. */ \ - if ( n_iter_rct == 0 ) return; \ -\ - /* Use round-robin assignment of micropanels to threads in the 2nd and 1st - loops for the remaining triangular region of C. */ \ - bli_thread_range_jrir_rr( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ -\ - /* Advance the start and end iteration offsets for the rectangular region - by the number of iterations used for the triangular region. */ \ - jr_start += n_iter_tri; \ - jr_end += n_iter_tri; \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Interior loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - /* No need to compute the diagonal offset for the rectangular - region. */ \ - /*diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR;*/ \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter_rr( i, m_iter, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* If the diagonal intersects the current MR x NR submatrix, we - compute it the temporary buffer and then add in the elements - on or below the diagonal. - Otherwise, if the submatrix is strictly above the diagonal, - we compute and store as we normally would. - And if we're strictly below the diagonal, we do nothing and - continue. */ \ - { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -} - -INSERT_GENTFUNC_BASIC0( herk_u_ker_var2rr ) - diff --git a/frame/3/herk/other/bli_herk_u_ker_var2sl.c b/frame/3/herk/other/bli_herk_u_ker_var2sl.c deleted file mode 100644 index 7af7ee56de..0000000000 --- a/frame/3/herk/other/bli_herk_u_ker_var2sl.c +++ /dev/null @@ -1,558 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#define FUNCPTR_T herk_fp - -typedef void (*FUNCPTR_T) - ( - doff_t diagoffc, - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY(ftypes,herk_u_ker_var2sl); - -// -// -- Macrokernel functions for slab partitioning ------------------------------ -// - -void bli_herk_u_ker_var2sl - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - - doff_t diagoffc = bli_obj_diag_offset( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. - f( diagoffc, - schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - doff_t diagoffc, \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - doff_t diagoffc_ij; \ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t m_cur; \ - dim_t n_cur; \ - dim_t i, j, jp; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Safeguard: If the current panel of C is entirely below the diagonal, - it is not stored. So we do nothing. */ \ - if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) return; \ -\ - /* If there is a zero region to the left of where the diagonal of C - intersects the top edge of the panel, adjust the pointer to C and B - and treat this case as if the diagonal offset were zero. - NOTE: It's possible that after this pruning that the diagonal offset - is still positive (though it is guaranteed to be less than NR). */ \ - if ( diagoffc > 0 ) \ - { \ - jp = diagoffc / NR; \ - j = jp * NR; \ - n = n - j; \ - diagoffc = diagoffc % NR; \ - c_cast = c_cast + (j )*cs_c; \ - b_cast = b_cast + (jp )*ps_b; \ - } \ -\ - /* If there is a zero region below where the diagonal of C intersects - the right edge of the panel, shrink it to prevent "no-op" iterations - from executing. */ \ - if ( -diagoffc + n < m ) \ - { \ - m = -diagoffc + n; \ - } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) - loop around the microkernel. Here we query the thrinfo_t node for the - 1st (ir) loop around the microkernel. */ \ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ -\ - /* Query the number of threads and thread ids for each loop. */ \ - dim_t jr_nt = bli_thread_n_way( thread ); \ - dim_t jr_tid = bli_thread_work_id( thread ); \ - dim_t ir_nt = bli_thread_n_way( caucus ); \ - dim_t ir_tid = bli_thread_work_id( caucus ); \ -\ - dim_t jr_start, jr_end; \ - dim_t ir_start, ir_end; \ - dim_t jr_inc, ir_inc; \ -\ - /* Note that we partition the 2nd loop into two regions: the triangular - part of C, and the rectangular portion. */ \ - dim_t n_iter_tri; \ - dim_t n_iter_rct; \ -\ - if ( bli_is_strictly_above_diag_n( diagoffc, m, n ) ) \ - { \ - /* If the entire panel of C does not intersect the diagonal, there is - no triangular region, and therefore we can skip the first set of - loops. */ \ - n_iter_tri = 0; \ - n_iter_rct = n_iter; \ - } \ - else \ - { \ - /* If the panel of C does intersect the diagonal, compute the number of - iterations in the triangular (or trapezoidal) region by dividing NR - into the number of rows in C. A non-zero remainder means we need to - add one additional iteration. That is, we want the triangular region - to contain as few columns of whole microtiles as possible while still - including all microtiles that intersect the diagonal. The number of - iterations in the rectangular region is computed as the remaining - number of iterations in the n dimension. */ \ - n_iter_tri = ( m + diagoffc ) / NR + ( ( m + diagoffc ) % NR ? 1 : 0 ); \ - n_iter_rct = n_iter - n_iter_tri; \ - } \ -\ - /* Use round-robin assignment of micropanels to threads in the 2nd loop - and slab partitioning in the 1st loop for the initial triangular region - of C (if it exists). */ \ - bli_thread_range_jrir_rr( thread, n_iter_tri, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ - bli_thread_range_jrir_sl( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Interior loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - /* Compute the diagonal offset for the submatrix at (i,j). */ \ - diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter_sl( i, m_iter, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* If the diagonal intersects the current MR x NR submatrix, we - compute it the temporary buffer and then add in the elements - on or below the diagonal. - Otherwise, if the submatrix is strictly above the diagonal, - we compute and store as we normally would. - And if we're strictly below the diagonal, we do nothing and - continue. */ \ - if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale C and add the result to only the stored part. */ \ - PASTEMAC(ch,xpbys_mxn_u)( diagoffc_ij, \ - m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -\ - /* If there is no rectangular region, then we're done. */ \ - if ( n_iter_rct == 0 ) return; \ -\ - /* Use slab assignment of micropanels to threads in the 2nd and 1st loops - loop for the remaining triangular region of C. */ \ - bli_thread_range_jrir_sl( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ -\ - /* Advance the start and end iteration offsets for the rectangular region - by the number of iterations used for the triangular region. */ \ - jr_start += n_iter_tri; \ - jr_end += n_iter_tri; \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Interior loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - /* No need to compute the diagonal offset for the rectangular - region. */ \ - /*diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR;*/ \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_herk_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter_sl( i, m_iter, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_herk_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter_sl( j, n_iter, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* If the diagonal intersects the current MR x NR submatrix, we - compute it the temporary buffer and then add in the elements - on or below the diagonal. - Otherwise, if the submatrix is strictly above the diagonal, - we compute and store as we normally would. - And if we're strictly below the diagonal, we do nothing and - continue. */ \ - { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -} - -INSERT_GENTFUNC_BASIC0( herk_u_ker_var2sl ) - diff --git a/frame/3/bli_l3_ft_ex.h b/frame/3/old/bli_l3_ft_ex.h similarity index 100% rename from frame/3/bli_l3_ft_ex.h rename to frame/3/old/bli_l3_ft_ex.h diff --git a/frame/3/old/bli_l3_sup_edge.h b/frame/3/old/bli_l3_sup_edge.h new file mode 100644 index 0000000000..06f3bb18b1 --- /dev/null +++ b/frame/3/old/bli_l3_sup_edge.h @@ -0,0 +1,141 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +static +void bli_dgemmsup_ker_edge_dispatcher + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx, + const dim_t num_mr, + const dim_t num_nr, + dim_t* restrict mrs, + dim_t* restrict nrs, + dgemmsup_ker_ft* kmap + ) +{ + #if 1 + + // outer loop = mr; inner loop = nr + + dim_t n_left = n0; + double* restrict cj = c; + double* restrict bj = b; + + for ( dim_t j = 0; n_left != 0; ++j ) + { + const dim_t nr_cur = nrs[ j ]; + + if ( nr_cur <= n_left ) + { + dim_t m_left = m0; + double* restrict cij = cj; + double* restrict ai = a; + + for ( dim_t i = 0; m_left != 0; ++i ) + { + const dim_t mr_cur = mrs[ i ]; + + if ( mr_cur <= m_left ) + { + dgemmsup_ker_ft ker_fp = kmap[ i*num_nr + j*1 ]; + + ker_fp + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + } + + cj += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + } + + #else + + // outer loop = nr; inner loop = mr + + dim_t m_left = m0; + double* restrict ci = c; + double* restrict ai = a; + + for ( dim_t i = 0; m_left != 0; ++i ) + { + const dim_t mr_cur = mrs[ i ]; + + if ( mr_cur <= m_left ) + { + dim_t n_left = n0; + double* restrict cij = ci; + double* restrict bj = b; + + for ( dim_t j = 0; n_left != 0; ++j ) + { + const dim_t nr_cur = nrs[ j ]; + + if ( nr_cur <= n_left ) + { + dgemmsup_ker_ft ker_fp = kmap[ i*num_nr + j*1 ]; + + ker_fp + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + + } + } + + ci += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + } + #endif +} + diff --git a/frame/3/old/bli_l3_sup_var1n2m.c b/frame/3/old/bli_l3_sup_var1n2m.c new file mode 100644 index 0000000000..a1cfbbb241 --- /dev/null +++ b/frame/3/old/bli_l3_sup_var1n2m.c @@ -0,0 +1,821 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmsup_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + stor3_t eff_id, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + cntl_t* restrict cntl, + thrinfo_t* restrict thread + ); + +// +// -- var1n -------------------------------------------------------------------- +// + +static FUNCPTR_T GENARRAY(ftypes_var1n,gemmsup_ref_var1n); + +void bli_gemmsup_ref_var1n + ( + trans_t trans, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + stor3_t eff_id, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ +#if 0 + obj_t at, bt; + + bli_obj_alias_to( a, &at ); + bli_obj_alias_to( b, &bt ); + + // Induce transpositions on A and/or B if either object is marked for + // transposition. We can induce "fast" transpositions since they objects + // are guaranteed to not have structure or be packed. + if ( bli_obj_has_trans( &at ) ) { bli_obj_induce_fast_trans( &at ); } + if ( bli_obj_has_trans( &bt ) ) { bli_obj_induce_fast_trans( &bt ); } + + const num_t dt_exec = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + + const dim_t k = bli_obj_width( &at ); + + void* restrict buf_a = bli_obj_buffer_at_off( &at ); + const inc_t rs_a = bli_obj_row_stride( &at ); + const inc_t cs_a = bli_obj_col_stride( &at ); + + void* restrict buf_b = bli_obj_buffer_at_off( &bt ); + const inc_t rs_b = bli_obj_row_stride( &bt ); + const inc_t cs_b = bli_obj_col_stride( &bt ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt_exec, beta ); + +#else + + const num_t dt_exec = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + dim_t k; + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + inc_t rs_a; + inc_t cs_a; + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + inc_t rs_b; + inc_t cs_b; + + if ( bli_obj_has_notrans( a ) ) + { + k = bli_obj_width( a ); + + rs_a = bli_obj_row_stride( a ); + cs_a = bli_obj_col_stride( a ); + } + else // if ( bli_obj_has_trans( a ) ) + { + // Assign the variables with an implicit transposition. + k = bli_obj_length( a ); + + rs_a = bli_obj_col_stride( a ); + cs_a = bli_obj_row_stride( a ); + } + + if ( bli_obj_has_notrans( b ) ) + { + rs_b = bli_obj_row_stride( b ); + cs_b = bli_obj_col_stride( b ); + } + else // if ( bli_obj_has_trans( b ) ) + { + // Assign the variables with an implicit transposition. + rs_b = bli_obj_col_stride( b ); + cs_b = bli_obj_row_stride( b ); + } + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt_exec, beta ); + +#endif + + // Index into the type combination array to extract the correct + // function pointer. + FUNCPTR_T f = ftypes_var1n[dt_exec]; + + if ( bli_is_notrans( trans ) ) + { + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + eff_id, + cntx, + rntm, + cntl, + thread + ); + } + else + { + // Invoke the function (transposing the operation). + f + ( + conjb, // swap the conj values. + conja, + n, // swap the m and n dimensions. + m, + k, + buf_alpha, + buf_b, cs_b, rs_b, // swap the positions of A and B. + buf_a, cs_a, rs_a, // swap the strides of A and B. + buf_beta, + buf_c, cs_c, rs_c, // swap the strides of C. + bli_stor3_trans( eff_id ), // transpose the stor3_t id. + cntx, + rntm, + cntl, + thread + ); + } +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t stor_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + cntl_t* restrict cntl, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* If m or n is zero, return immediately. */ \ + if ( bli_zero_dim2( m, n ) ) return; \ +\ + /* If k < 1 or alpha is zero, scale by beta and return. */ \ + if ( k < 1 || PASTEMAC(ch,eq0)( *(( ctype* )alpha) ) ) \ + { \ + PASTEMAC(ch,scalm) \ + ( \ + BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m, n, \ + beta, \ + c, rs_c, cs_c \ + ); \ + return; \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* This transposition of the stor3_t id value is inherent to variant 1. + The reason: we assume that variant 2 is the "main" variant. The + consequence of this is that we assume that the millikernels that + iterate over m are registered to the kernel group associated with + the kernel preference. So, regardless of whether the mkernels are + row- or column-preferential, millikernels that iterate over n are + always placed in the slots for the opposite kernel group. */ \ + stor_id = bli_stor3_trans( stor_id ); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + dim_t KC; \ + if ( FALSE ) KC = KC0; \ + else if ( stor_id == BLIS_RRC || \ + stor_id == BLIS_CRC ) KC = KC0; \ + else if ( m <= MR && n <= NR ) KC = KC0; \ + else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; \ + else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; \ + else if ( m <= 4*MR && n <= 4*NR ) KC = KC0 / 4; \ + else KC = (( KC0 / 5 ) / 4 ) * 4; \ +\ + /* Nudge NC up to a multiple of MR and MC up to a multiple of NR. */ \ + const dim_t NC = bli_align_dim_to_mult( NC0, MR ); \ + const dim_t MC = bli_align_dim_to_mult( MC0, NR ); \ +\ + /* Query the maximum blocksize for MR, which implies a maximum blocksize + extension for the final iteration. */ \ + const dim_t MRM = bli_cntx_get_l3_sup_blksz_max_dt( dt, BLIS_MR, cntx ); \ + const dim_t MRE = MRM - MR; \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = rs_c * NC; \ + const inc_t jcstep_a = rs_a * NC; \ +\ + const inc_t pcstep_a = cs_a * KC; \ + const inc_t pcstep_b = rs_b * KC; \ +\ + const inc_t icstep_c = cs_c * MC; \ + const inc_t icstep_b = cs_b * MC; \ +\ + const inc_t jrstep_c = rs_c * MR; \ + const inc_t jrstep_a = rs_a * MR; \ +\ + /* + const inc_t irstep_c = cs_c * NR; \ + const inc_t irstep_b = cs_b * NR; \ + */ \ +\ + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemmsup_ker_ft) \ + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + ctype* restrict one = PASTEMAC(ch,1); \ +\ + auxinfo_t aux; \ +\ + /* Compute number of primary and leftover components of the outer + dimensions. + NOTE: Functionally speaking, we compute jc_iter as: + jc_iter = m / NC; if ( jc_left ) ++jc_iter; + However, this is implemented as: + jc_iter = ( m + NC - 1 ) / NC; + This avoids a branch at the cost of two additional integer instructions. + The pc_iter, mc_iter, nr_iter, and mr_iter variables are computed in + similar manner. */ \ + const dim_t jc_iter = ( m + NC - 1 ) / NC; \ + const dim_t jc_left = m % NC; \ +\ + const dim_t pc_iter = ( k + KC - 1 ) / KC; \ + const dim_t pc_left = k % KC; \ +\ + const dim_t ic_iter = ( n + MC - 1 ) / MC; \ + const dim_t ic_left = n % MC; \ +\ + const dim_t jc_inc = 1; \ + const dim_t pc_inc = 1; \ + const dim_t ic_inc = 1; \ + const dim_t jr_inc = 1; \ + /* + const dim_t ir_inc = 1; \ + */ \ +\ + /* Loop over the m dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = 0; jj < jc_iter; jj += jc_inc ) \ + { \ + const dim_t nc_cur = ( bli_is_not_edge_f( jj, jc_iter, jc_left ) ? NC : jc_left ); \ +\ + ctype* restrict a_jc = a_00 + jj * jcstep_a; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + dim_t jr_iter = ( nc_cur + MR - 1 ) / MR; \ + dim_t jr_left = nc_cur % MR; \ +\ + /* An optimization: allow the last jr iteration to contain up to MRE + rows of C and A. (If MRE > MR, the mkernel has agreed to handle + these cases.) Note that this prevents us from declaring jr_iter and + jr_left as const. */ \ + if ( 1 ) \ + if ( MRE != 0 && 1 < jr_iter && jr_left != 0 && jr_left <= MRE ) \ + { \ + jr_iter--; jr_left += MR; \ + } \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = 0; pp < pc_iter; pp += pc_inc ) \ + { \ + const dim_t kc_cur = ( bli_is_not_edge_f( pp, pc_iter, pc_left ) ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_jc + pp * pcstep_a; \ + ctype* restrict b_pc = b_00 + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? beta_cast : one ); \ +\ + /* Loop over the n dimension (MC rows at a time). */ \ + for ( dim_t ii = 0; ii < ic_iter; ii += ic_inc ) \ + { \ + const dim_t mc_cur = ( bli_is_not_edge_f( ii, ic_iter, ic_left ) ? MC : ic_left ); \ +\ + ctype* restrict b_ic = b_pc + ii * icstep_b; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + /* + const dim_t ir_iter = ( mc_cur + NR - 1 ) / NR; \ + const dim_t ir_left = mc_cur % NR; \ + */ \ +\ + /* Loop over the m dimension (NR columns at a time). */ \ + for ( dim_t j = 0; j < jr_iter; j += jr_inc ) \ + { \ + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? MR : jr_left ); \ +\ + ctype* restrict a_jr = a_pc + j * jrstep_a; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Loop over the n dimension (MR rows at a time). */ \ + { \ + /* Invoke the gemmsup millikernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + nr_cur, /* Notice: nr_cur <= MR. */ \ + mc_cur, /* Recall: mc_cur partitions the n dimension! */ \ + kc_cur, \ + alpha_cast, \ + a_jr, rs_a, cs_a, \ + b_ic, rs_b, cs_b, \ + beta_use, \ + c_jr, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ + } \ + } \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); \ +*/ \ +} + +INSERT_GENTFUNC_BASIC0( gemmsup_ref_var1n ) + + +// +// -- var2m -------------------------------------------------------------------- +// + +static FUNCPTR_T GENARRAY(ftypes_var2m,gemmsup_ref_var2m); + +void bli_gemmsup_ref_var2m + ( + trans_t trans, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + stor3_t eff_id, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ +#if 0 + obj_t at, bt; + + bli_obj_alias_to( a, &at ); + bli_obj_alias_to( b, &bt ); + + // Induce transpositions on A and/or B if either object is marked for + // transposition. We can induce "fast" transpositions since they objects + // are guaranteed to not have structure or be packed. + if ( bli_obj_has_trans( &at ) ) { bli_obj_induce_fast_trans( &at ); } + if ( bli_obj_has_trans( &bt ) ) { bli_obj_induce_fast_trans( &bt ); } + + const num_t dt_exec = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + + const dim_t k = bli_obj_width( &at ); + + void* restrict buf_a = bli_obj_buffer_at_off( &at ); + const inc_t rs_a = bli_obj_row_stride( &at ); + const inc_t cs_a = bli_obj_col_stride( &at ); + + void* restrict buf_b = bli_obj_buffer_at_off( &bt ); + const inc_t rs_b = bli_obj_row_stride( &bt ); + const inc_t cs_b = bli_obj_col_stride( &bt ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt_exec, beta ); + +#else + const num_t dt_exec = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + dim_t k; + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + inc_t rs_a; + inc_t cs_a; + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + inc_t rs_b; + inc_t cs_b; + + if ( bli_obj_has_notrans( a ) ) + { + k = bli_obj_width( a ); + + rs_a = bli_obj_row_stride( a ); + cs_a = bli_obj_col_stride( a ); + } + else // if ( bli_obj_has_trans( a ) ) + { + // Assign the variables with an implicit transposition. + k = bli_obj_length( a ); + + rs_a = bli_obj_col_stride( a ); + cs_a = bli_obj_row_stride( a ); + } + + if ( bli_obj_has_notrans( b ) ) + { + rs_b = bli_obj_row_stride( b ); + cs_b = bli_obj_col_stride( b ); + } + else // if ( bli_obj_has_trans( b ) ) + { + // Assign the variables with an implicit transposition. + rs_b = bli_obj_col_stride( b ); + cs_b = bli_obj_row_stride( b ); + } + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt_exec, beta ); + +#endif + + // Index into the type combination array to extract the correct + // function pointer. + FUNCPTR_T f = ftypes_var2m[dt_exec]; + + if ( bli_is_notrans( trans ) ) + { + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + eff_id, + cntx, + rntm, + cntl, + thread + ); + } + else + { + // Invoke the function (transposing the operation). + f + ( + conjb, // swap the conj values. + conja, + n, // swap the m and n dimensions. + m, + k, + buf_alpha, + buf_b, cs_b, rs_b, // swap the positions of A and B. + buf_a, cs_a, rs_a, // swap the strides of A and B. + buf_beta, + buf_c, cs_c, rs_c, // swap the strides of C. + bli_stor3_trans( eff_id ), // transpose the stor3_t id. + cntx, + rntm, + cntl, + thread + ); + } +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t stor_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + cntl_t* restrict cntl, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* If m or n is zero, return immediately. */ \ + if ( bli_zero_dim2( m, n ) ) return; \ +\ + /* If k < 1 or alpha is zero, scale by beta and return. */ \ + if ( k < 1 || PASTEMAC(ch,eq0)( *(( ctype* )alpha) ) ) \ + { \ + PASTEMAC(ch,scalm) \ + ( \ + BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m, n, \ + beta, \ + c, rs_c, cs_c \ + ); \ + return; \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + dim_t KC; \ + if ( stor_id == BLIS_RRR || \ + stor_id == BLIS_CCC ) KC = KC0; \ + else if ( stor_id == BLIS_RRC || \ + stor_id == BLIS_CRC ) KC = KC0; \ + else if ( m <= MR && n <= NR ) KC = KC0; \ + else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; \ + else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; \ + else if ( m <= 4*MR && n <= 4*NR ) KC = KC0 / 4; \ + else KC = (( KC0 / 5 ) / 4 ) * 4; \ +\ + /* Query the maximum blocksize for NR, which implies a maximum blocksize + extension for the final iteration. */ \ + const dim_t NRM = bli_cntx_get_l3_sup_blksz_max_dt( dt, BLIS_NR, cntx ); \ + const dim_t NRE = NRM - NR; \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c * NC; \ + const inc_t jcstep_b = cs_b * NC; \ +\ + const inc_t pcstep_a = cs_a * KC; \ + const inc_t pcstep_b = rs_b * KC; \ +\ + const inc_t icstep_c = rs_c * MC; \ + const inc_t icstep_a = rs_a * MC; \ +\ + const inc_t jrstep_c = cs_c * NR; \ + const inc_t jrstep_b = cs_b * NR; \ +\ + /* + const inc_t irstep_c = rs_c * MR; \ + const inc_t irstep_a = rs_a * MR; \ + */ \ +\ + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemmsup_ker_ft) \ + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + ctype* restrict one = PASTEMAC(ch,1); \ +\ + auxinfo_t aux; \ +\ + /* Compute number of primary and leftover components of the outer + dimensions. + NOTE: Functionally speaking, we compute jc_iter as: + jc_iter = n / NC; if ( jc_left ) ++jc_iter; + However, this is implemented as: + jc_iter = ( n + NC - 1 ) / NC; + This avoids a branch at the cost of two additional integer instructions. + The pc_iter, mc_iter, nr_iter, and mr_iter variables are computed in + similar manner. */ \ + const dim_t jc_iter = ( n + NC - 1 ) / NC; \ + const dim_t jc_left = n % NC; \ +\ + const dim_t pc_iter = ( k + KC - 1 ) / KC; \ + const dim_t pc_left = k % KC; \ +\ + const dim_t ic_iter = ( m + MC - 1 ) / MC; \ + const dim_t ic_left = m % MC; \ +\ + const dim_t jc_inc = 1; \ + const dim_t pc_inc = 1; \ + const dim_t ic_inc = 1; \ + const dim_t jr_inc = 1; \ + /* + const dim_t ir_inc = 1; \ + */ \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = 0; jj < jc_iter; jj += jc_inc ) \ + { \ + const dim_t nc_cur = ( bli_is_not_edge_f( jj, jc_iter, jc_left ) ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* An optimization: allow the last jr iteration to contain up to NRE + columns of C and B. (If NRE > NR, the mkernel has agreed to handle + these cases.) Note that this prevents us from declaring jr_iter and + jr_left as const. */ \ + if ( 1 ) \ + if ( NRE != 0 && 1 < jr_iter && jr_left != 0 && jr_left <= NRE ) \ + { \ + jr_iter--; jr_left += NR; \ + } \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = 0; pp < pc_iter; pp += pc_inc ) \ + { \ + const dim_t kc_cur = ( bli_is_not_edge_f( pp, pc_iter, pc_left ) ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? beta_cast : one ); \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = 0; ii < ic_iter; ii += ic_inc ) \ + { \ + const dim_t mc_cur = ( bli_is_not_edge_f( ii, ic_iter, ic_left ) ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + /* + const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + const dim_t ir_left = mc_cur % MR; \ + */ \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = 0; j < jr_iter; j += jr_inc ) \ + { \ + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc + j * jrstep_b; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + { \ + /* Invoke the gemmsup millikernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mc_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ic, rs_a, cs_a, \ + b_jr, rs_b, cs_b, \ + beta_use, \ + c_jr, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ + } \ + } \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); \ +*/ \ +} + +INSERT_GENTFUNC_BASIC0( gemmsup_ref_var2m ) + diff --git a/frame/3/symm/bli_symm_front.c b/frame/3/symm/bli_symm_front.c index 820c26fd16..be94c44c1b 100644 --- a/frame/3/symm/bli_symm_front.c +++ b/frame/3/symm/bli_symm_front.c @@ -53,10 +53,6 @@ void bli_symm_front obj_t b_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_symm_check( side, alpha, a, b, beta, c, cntx ); - // If alpha is zero, scale by beta and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) { @@ -69,10 +65,58 @@ void bli_symm_front bli_obj_alias_to( b, &b_local ); bli_obj_alias_to( c, &c_local ); + // Set the obj_t buffer field to the location currently implied by the row + // and column offsets and then zero the offsets. If any of the original + // obj_t's were views into larger matrices, this step effectively makes + // those obj_t's "forget" their lineage. + bli_obj_reset_origin( &a_local ); + bli_obj_reset_origin( &b_local ); + bli_obj_reset_origin( &c_local ); + +#ifdef BLIS_DISABLE_SYMM_RIGHT + // NOTE: This case casts right-side symm in terms of left side. This is + // necessary when the current subconfiguration uses a gemm microkernel + // that assumes that the packing kernel will have already duplicated + // (broadcast) element of B in the packed copy of B. Supporting + // duplication within the logic that packs micropanels from symmetric + // matrices would be ugly, and so we simply don't support it. As a + // consequence, those subconfigurations need a way to force the symmetric + // matrix to be on the left (and thus the general matrix to the on the + // right). So our solution is that in those cases, the subconfigurations + // simply #define BLIS_DISABLE_SYMM_RIGHT. + + // NOTE: This case casts right-side symm in terms of left side. This can + // lead to the microkernel being executed on an output matrix with the + // microkernel's general stride IO case (unless the microkernel supports + // both both row and column IO cases as well). + + // If A is being multiplied from the right, transpose all operands + // so that we can perform the computation as if A were being multiplied + // from the left. + if ( bli_is_right( side ) ) + { + bli_toggle_side( &side ); + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + +#else + // NOTE: This case computes right-side hemm/symm natively by packing + // elements of the Hermitian/symmetric matrix A to micropanels of the + // right-hand packed matrix operand "B", and elements of the general + // matrix B to micropanels of the left-hand packed matrix operand "A". + // This code path always gives us the opportunity to transpose the + // entire operation so that the effective storage format of the output + // matrix matches the microkernel's output preference. Thus, from a + // performance perspective, this case is preferred. + // An optimization: If C is stored by rows and the micro-kernel prefers // contiguous columns, or if C is stored by columns and the micro-kernel // prefers contiguous rows, transpose the entire operation to allow the // micro-kernel to access elements of C in its preferred manner. + //if ( !bli_obj_is_1x1( &c_local ) ) // NOTE: This conditional should NOT + // be enabled. See issue #342 comments. if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) { bli_toggle_side( &side ); @@ -80,12 +124,17 @@ void bli_symm_front bli_obj_induce_trans( &c_local ); } - // Swap A and B if multiplying A from the right so that "B" contains - // the symmetric matrix. + // If the Hermitian/symmetric matrix A is being multiplied from the right, + // swap A and B so that the Hermitian/symmetric matrix will actually be on + // the right. if ( bli_is_right( side ) ) { bli_obj_swap( &a_local, &b_local ); } +#endif + + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any @@ -100,29 +149,10 @@ void bli_symm_front rntm ); - // A sort of hack for communicating the desired pach schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - if ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &b_local ); - } - else // if ( bli_cntx_method( cntx ) != BLIS_NAT ) - { - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - } - // Invoke the internal back-end. bli_l3_thread_decorator ( - bli_gemm_int, + bli_l3_int, BLIS_GEMM, // operation family id alpha, &a_local, diff --git a/frame/3/syr2k/bli_syr2k_front.c b/frame/3/syr2k/bli_syr2k_front.c deleted file mode 100644 index 3ccd28c5c2..0000000000 --- a/frame/3/syr2k/bli_syr2k_front.c +++ /dev/null @@ -1,157 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -void bli_syr2k_front - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl - ) -{ - bli_init_once(); - - obj_t c_local; - obj_t a_local; - obj_t bt_local; - obj_t b_local; - obj_t at_local; - - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_syr2k_check( alpha, a, b, beta, c, cntx ); - - // If alpha is zero, scale by beta and return. - if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) - { - bli_scalm( beta, c ); - return; - } - - // Alias A, B, and C in case we need to apply transformations. - bli_obj_alias_to( a, &a_local ); - bli_obj_alias_to( b, &b_local ); - bli_obj_alias_to( c, &c_local ); - bli_obj_set_as_root( &c_local ); - - // For syr2k, the first and second right-hand "B" operands are simply B' - // and A'. - bli_obj_alias_to( b, &bt_local ); - bli_obj_induce_trans( &bt_local ); - bli_obj_alias_to( a, &at_local ); - bli_obj_induce_trans( &at_local ); - - // An optimization: If C is stored by rows and the micro-kernel prefers - // contiguous columns, or if C is stored by columns and the micro-kernel - // prefers contiguous rows, transpose the entire operation to allow the - // micro-kernel to access elements of C in its preferred manner. - if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) - { - bli_obj_induce_trans( &c_local ); - } - - // Parse and interpret the contents of the rntm_t object to properly - // set the ways of parallelism for each loop, and then make any - // additional modifications necessary for the current operation. - bli_rntm_set_ways_for_op - ( - BLIS_SYR2K, - BLIS_LEFT, // ignored for her[2]k/syr[2]k - bli_obj_length( &c_local ), - bli_obj_width( &c_local ), - bli_obj_width( &a_local ), - rntm - ); - - // A sort of hack for communicating the desired pach schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - if ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &bt_local ); - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &b_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &at_local ); - } - else // if ( bli_cntx_method( cntx ) != BLIS_NAT ) - { - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &bt_local ); - bli_obj_set_pack_schema( schema_a, &b_local ); - bli_obj_set_pack_schema( schema_b, &at_local ); - } - - // Invoke herk twice, using beta only the first time. - - // Invoke the internal back-end. - bli_l3_thread_decorator - ( - bli_gemm_int, - BLIS_HERK, // operation family id - alpha, - &a_local, - &bt_local, - beta, - &c_local, - cntx, - rntm, - cntl - ); - - bli_l3_thread_decorator - ( - bli_gemm_int, - BLIS_HERK, // operation family id - alpha, - &b_local, - &at_local, - &BLIS_ONE, - &c_local, - cntx, - rntm, - cntl - ); -} - diff --git a/frame/3/trmm/bli_trmm_front.c b/frame/3/trmm/bli_trmm_front.c index aee9d1d6f1..1de28958eb 100644 --- a/frame/3/trmm/bli_trmm_front.c +++ b/frame/3/trmm/bli_trmm_front.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -52,10 +52,6 @@ void bli_trmm_front obj_t b_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_trmm_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); - // If alpha is zero, scale by beta and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) { @@ -68,6 +64,14 @@ void bli_trmm_front bli_obj_alias_to( b, &b_local ); bli_obj_alias_to( b, &c_local ); + // Set the obj_t buffer field to the location currently implied by the row + // and column offsets and then zero the offsets. If any of the original + // obj_t's were views into larger matrices, this step effectively makes + // those obj_t's "forget" their lineage. + bli_obj_reset_origin( &a_local ); + bli_obj_reset_origin( &b_local ); + bli_obj_reset_origin( &c_local ); + // We do not explicitly implement the cases where A is transposed. // However, we can still handle them. Specifically, if A is marked as // needing a transposition, we simply induce a transposition. This @@ -85,11 +89,25 @@ void bli_trmm_front bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &a_local ); } -#if 0 - // NOTE: This case casts right-side trmm in terms of left side. This - // reduces the number of macrokernels exercised to two (trmm_ll and - // trmm_lu) but can lead to the microkernel being executed with an - // output matrix that is stored counter to its output preference. +#ifdef BLIS_DISABLE_TRMM_RIGHT + // NOTE: This case casts right-side trmm in terms of left side. This is + // necessary when the current subconfiguration uses a gemm microkernel + // that assumes that the packing kernel will have already duplicated + // (broadcast) element of B in the packed copy of B. Supporting + // duplication within the logic that packs micropanels from triangular + // matrices would be ugly, and so we simply don't support it. As a + // consequence, those subconfigurations need a way to force the triangular + // matrix to be on the left (and thus the general matrix to the on the + // right). So our solution is that in those cases, the subconfigurations + // simply #define BLIS_DISABLE_TRMM_RIGHT. + + // NOTE: This case casts right-side trmm in terms of left side. This can + // lead to the microkernel being executed on an output matrix with the + // microkernel's general stride IO case (unless the microkernel supports + // both both row and column IO cases as well). + + // NOTE: Casting right-side trmm in terms of left side reduces the number + // of macrokernels exercised to two (trmm_ll and trmm_lu). // If A is being multiplied from the right, transpose all operands // so that we can perform the computation as if A were being multiplied @@ -115,7 +133,8 @@ void bli_trmm_front // micro-kernel to access elements of C in its preferred manner. // NOTE: We disable the optimization for 1x1 matrices since the concept // of row- vs. column storage breaks down. - if ( !bli_obj_is_1x1( &c_local ) ) + //if ( !bli_obj_is_1x1( &c_local ) ) // NOTE: This conditional should NOT + // be enabled. See issue #342 comments. if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) { bli_toggle_side( &side ); @@ -133,12 +152,8 @@ void bli_trmm_front #endif - // Set each alias as the root object. - // NOTE: We MUST wait until we are done potentially swapping the objects - // before setting the root fields! - bli_obj_set_as_root( &a_local ); - bli_obj_set_as_root( &b_local ); - bli_obj_set_as_root( &c_local ); + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any @@ -153,29 +168,10 @@ void bli_trmm_front rntm ); - // A sort of hack for communicating the desired pach schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - if ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &b_local ); - } - else // if ( bli_cntx_method( cntx ) != BLIS_NAT ) - { - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - } - // Invoke the internal back-end. bli_l3_thread_decorator ( - bli_gemm_int, + bli_l3_int, BLIS_TRMM, // operation family id alpha, &a_local, diff --git a/frame/3/trmm/bli_trmm_ll_ker_var2.c b/frame/3/trmm/bli_trmm_ll_ker_var2.c index 98e62926c7..646287f931 100644 --- a/frame/3/trmm/bli_trmm_ll_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ll_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \ function pointer type. */ \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ ctype* restrict c_cast = c; \ @@ -203,9 +191,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_a_num; \ - inc_t ss_a_den; \ inc_t ps_a_cur; \ inc_t is_a_cur; \ auxinfo_t aux; \ @@ -243,30 +228,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = k; \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_a ) || \ - bli_is_3mi_packed( schema_a ) || \ - bli_is_rih_packed( schema_a ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. And if we are packing real-only, imag-only, or - summed-only, we need to scale the computed panel sizes by 1/2 - to compensate for the fact that the pointer arithmetic occurs - in terms of complex elements rather than real elements. */ \ - if ( bli_is_3mi_packed( schema_a ) ) { ss_a_num = 3; ss_a_den = 2; } \ - else if ( bli_is_rih_packed( schema_a ) ) { ss_a_num = 1; ss_a_den = 2; } \ - else { ss_a_num = 1; ss_a_den = 1; } \ \ /* If there is a zero region above where the diagonal of A intersects the left edge of the block, adjust the pointer to C and treat this case as @@ -281,10 +242,6 @@ void PASTEMAC(ch,varname) \ diagoffa = 0; \ c_cast = c_cast + (i )*rs_c; \ } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -317,9 +274,6 @@ void PASTEMAC(ch,varname) \ \ /* Save the imaginary stride of B to the auxinfo_t object. */ \ bli_auxinfo_set_is_b( istep_b, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) loop around the microkernel. Here we query the thrinfo_t node for the @@ -337,8 +291,8 @@ void PASTEMAC(ch,varname) \ dim_t jr_inc; \ \ /* Determine the thread range and increment for the 2nd loop. - NOTE: The definition of bli_thread_range_jrir() will depend on whether - slab or round-robin partitioning was requested at configure-time. \ + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. \ NOTE: Parallelism in the 1st loop is disabled for now. */ \ bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ /*bli_thread_range_jrir_rr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );*/ \ @@ -387,12 +341,12 @@ void PASTEMAC(ch,varname) \ intersecting micro-panel. */ \ is_a_cur = k_a1011 * PACKMR; \ is_a_cur += ( bli_is_odd( is_a_cur ) ? 1 : 0 ); \ - ps_a_cur = ( is_a_cur * ss_a_num ) / ss_a_den; \ + ps_a_cur = is_a_cur; \ \ /* NOTE: ir loop parallelism disabled for now. */ \ /*if ( bli_trmm_my_iter( i, ir_thread ) ) {*/ \ \ - b1_i = b1 + ( off_a1011 * PACKNR ) / off_scl; \ + b1_i = b1 + off_a1011 * PACKNR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1; \ @@ -409,51 +363,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( is_a_cur, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_a1011, \ - alpha_cast, \ - a1, \ - b1_i, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Copy edge elements of C to the temporary buffer. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - c11, rs_c, cs_c, \ - ct, rs_ct, cs_ct ); \ -\ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_a1011, \ - alpha_cast, \ - a1, \ - b1_i, \ - beta_cast, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_a1011, \ + alpha_cast, \ + a1, \ + b1_i, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ /*}*/ \ \ a1 += ps_a_cur; \ @@ -480,46 +403,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( istep_a, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - one, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + one, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ /*}*/ \ \ a1 += rstep_a; \ diff --git a/frame/3/trmm/bli_trmm_lu_ker_var2.c b/frame/3/trmm/bli_trmm_lu_ker_var2.c index 6246041418..9ef2a475de 100644 --- a/frame/3/trmm/bli_trmm_lu_ker_var2.c +++ b/frame/3/trmm/bli_trmm_lu_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \ function pointer type. */ \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ ctype* restrict c_cast = c; \ @@ -203,9 +191,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_a_num; \ - inc_t ss_a_den; \ inc_t ps_a_cur; \ inc_t is_a_cur; \ auxinfo_t aux; \ @@ -243,30 +228,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = k; \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_a ) || \ - bli_is_3mi_packed( schema_a ) || \ - bli_is_rih_packed( schema_a ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. And if we are packing real-only, imag-only, or - summed-only, we need to scale the computed panel sizes by 1/2 - to compensate for the fact that the pointer arithmetic occurs - in terms of complex elements rather than real elements. */ \ - if ( bli_is_3mi_packed( schema_a ) ) { ss_a_num = 3; ss_a_den = 2; } \ - else if ( bli_is_rih_packed( schema_a ) ) { ss_a_num = 1; ss_a_den = 2; } \ - else { ss_a_num = 1; ss_a_den = 1; } \ \ /* If there is a zero region to the left of where the diagonal of A intersects the top edge of the block, adjust the pointer to B and @@ -278,7 +239,7 @@ void PASTEMAC(ch,varname) \ i = diagoffa; \ k = k - i; \ diagoffa = 0; \ - b_cast = b_cast + ( i * PACKNR ) / off_scl; \ + b_cast = b_cast + i * PACKNR; \ } \ \ /* If there is a zero region below where the diagonal of A intersects the @@ -288,10 +249,6 @@ void PASTEMAC(ch,varname) \ { \ m = -diagoffa + k; \ } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -324,9 +281,6 @@ void PASTEMAC(ch,varname) \ \ /* Save the imaginary stride of B to the auxinfo_t object. */ \ bli_auxinfo_set_is_b( istep_b, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) loop around the microkernel. Here we query the thrinfo_t node for the @@ -394,12 +348,12 @@ void PASTEMAC(ch,varname) \ intersecting micro-panel. */ \ is_a_cur = k_a1112 * PACKMR; \ is_a_cur += ( bli_is_odd( is_a_cur ) ? 1 : 0 ); \ - ps_a_cur = ( is_a_cur * ss_a_num ) / ss_a_den; \ + ps_a_cur = is_a_cur; \ \ /* NOTE: ir loop parallelism disabled for now. */ \ /*if ( bli_trmm_my_iter( i, ir_thread ) ) {*/ \ \ - b1_i = b1 + ( off_a1112 * PACKNR ) / off_scl; \ + b1_i = b1 + off_a1112 * PACKNR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1; \ @@ -416,51 +370,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( is_a_cur, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_a1112, \ - alpha_cast, \ - a1, \ - b1_i, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Copy edge elements of C to the temporary buffer. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - c11, rs_c, cs_c, \ - ct, rs_ct, cs_ct ); \ -\ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_a1112, \ - alpha_cast, \ - a1, \ - b1_i, \ - beta_cast, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_a1112, \ + alpha_cast, \ + a1, \ + b1_i, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ /*}*/ \ \ a1 += ps_a_cur; \ @@ -487,46 +410,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( istep_a, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - one, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + one, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ /*}*/ \ \ a1 += rstep_a; \ diff --git a/frame/3/trmm/bli_trmm_rl_ker_var2.c b/frame/3/trmm/bli_trmm_rl_ker_var2.c index 117cf63c55..f6b20af2e5 100644 --- a/frame/3/trmm/bli_trmm_rl_ker_var2.c +++ b/frame/3/trmm/bli_trmm_rl_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \ function pointer type. */ \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ ctype* restrict c_cast = c; \ @@ -203,9 +191,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_b_num; \ - inc_t ss_b_den; \ inc_t ps_b_cur; \ inc_t is_b_cur; \ auxinfo_t aux; \ @@ -243,30 +228,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = k; \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_b ) || \ - bli_is_3mi_packed( schema_b ) || \ - bli_is_rih_packed( schema_b ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. And if we are packing real-only, imag-only, or - summed-only, we need to scale the computed panel sizes by 1/2 - to compensate for the fact that the pointer arithmetic occurs - in terms of complex elements rather than real elements. */ \ - if ( bli_is_3mi_packed( schema_b ) ) { ss_b_num = 3; ss_b_den = 2; } \ - else if ( bli_is_rih_packed( schema_b ) ) { ss_b_num = 1; ss_b_den = 2; } \ - else { ss_b_num = 1; ss_b_den = 1; } \ \ /* If there is a zero region above where the diagonal of B intersects the left edge of the panel, adjust the pointer to A and treat this @@ -278,7 +239,7 @@ void PASTEMAC(ch,varname) \ j = -diagoffb; \ k = k - j; \ diagoffb = 0; \ - a_cast = a_cast + ( j * PACKMR ) / off_scl; \ + a_cast = a_cast + j * PACKMR; \ } \ \ /* If there is a zero region to the right of where the diagonal @@ -288,10 +249,6 @@ void PASTEMAC(ch,varname) \ { \ n = diagoffb + k; \ } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -324,9 +281,6 @@ void PASTEMAC(ch,varname) \ \ /* Save the imaginary stride of A to the auxinfo_t object. */ \ bli_auxinfo_set_is_a( istep_a, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ \ @@ -365,9 +319,9 @@ void PASTEMAC(ch,varname) \ \ /* Determine the thread range and increment for the 2nd and 1st loops for the initial rectangular region of B (if it exists). - NOTE: The definition of bli_thread_range_jrir() will depend on whether - slab or round-robin partitioning was requested at configure-time. \ - NOTE: Parallelism in the 1st loop is disabled for now. */ \ + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. \ + NOTE: Parallelism in the 1st loop is disabled for now. */ \ bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ \ @@ -387,10 +341,6 @@ void PASTEMAC(ch,varname) \ b2 = b1; \ \ { \ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_b( istep_b, &aux ); \ -\ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = ir_start; i < ir_end; i += ir_inc ) \ { \ @@ -416,42 +366,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - one, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + one, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ @@ -504,13 +432,9 @@ void PASTEMAC(ch,varname) \ intersecting micro-panel. */ \ is_b_cur = k_b1121 * PACKNR; \ is_b_cur += ( bli_is_odd( is_b_cur ) ? 1 : 0 ); \ - ps_b_cur = ( is_b_cur * ss_b_num ) / ss_b_den; \ + ps_b_cur = is_b_cur; \ \ if ( bli_trmm_my_iter_rr( j, thread ) ) { \ -\ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_b( is_b_cur, &aux ); \ \ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = 0; i < m_iter; ++i ) \ @@ -522,7 +446,7 @@ void PASTEMAC(ch,varname) \ \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ - a1_i = a1 + ( off_b1121 * PACKMR ) / off_scl; \ + a1_i = a1 + off_b1121 * PACKMR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1; \ @@ -539,47 +463,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_b1121, \ - alpha_cast, \ - a1_i, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Copy edge elements of C to the temporary buffer. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - c11, rs_c, cs_c, \ - ct, rs_ct, cs_ct ); \ -\ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_b1121, \ - alpha_cast, \ - a1_i, \ - b1, \ - beta_cast, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_b1121, \ + alpha_cast, \ + a1_i, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ \ a1 += rstep_a; \ diff --git a/frame/3/trmm/bli_trmm_ru_ker_var2.c b/frame/3/trmm/bli_trmm_ru_ker_var2.c index ea59959c79..f71fb3c4d8 100644 --- a/frame/3/trmm/bli_trmm_ru_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ru_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \ function pointer type. */ \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ ctype* restrict c_cast = c; \ @@ -203,9 +191,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_b_num; \ - inc_t ss_b_den; \ inc_t ps_b_cur; \ inc_t is_b_cur; \ auxinfo_t aux; \ @@ -243,30 +228,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = k; \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_b ) || \ - bli_is_3mi_packed( schema_b ) || \ - bli_is_rih_packed( schema_b ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. And if we are packing real-only, imag-only, or - summed-only, we need to scale the computed panel sizes by 1/2 - to compensate for the fact that the pointer arithmetic occurs - in terms of complex elements rather than real elements. */ \ - if ( bli_is_3mi_packed( schema_b ) ) { ss_b_num = 3; ss_b_den = 2; } \ - else if ( bli_is_rih_packed( schema_b ) ) { ss_b_num = 1; ss_b_den = 2; } \ - else { ss_b_num = 1; ss_b_den = 1; } \ \ /* If there is a zero region to the left of where the diagonal of B intersects the top edge of the panel, adjust the pointer to C and @@ -289,10 +250,6 @@ void PASTEMAC(ch,varname) \ { \ k = -diagoffb + n; \ } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -325,9 +282,6 @@ void PASTEMAC(ch,varname) \ \ /* Save the imaginary stride of A to the auxinfo_t object. */ \ bli_auxinfo_set_is_a( istep_a, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) loop around the microkernel. Here we query the thrinfo_t node for the @@ -409,13 +363,9 @@ void PASTEMAC(ch,varname) \ intersecting micro-panel. */ \ is_b_cur = k_b0111 * PACKNR; \ is_b_cur += ( bli_is_odd( is_b_cur ) ? 1 : 0 ); \ - ps_b_cur = ( is_b_cur * ss_b_num ) / ss_b_den; \ + ps_b_cur = is_b_cur; \ \ if ( bli_trmm_my_iter_rr( j, thread ) ) { \ -\ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_b( is_b_cur, &aux ); \ \ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = 0; i < m_iter; ++i ) \ @@ -427,7 +377,7 @@ void PASTEMAC(ch,varname) \ \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ - a1_i = a1 + ( off_b0111 * PACKMR ) / off_scl; \ + a1_i = a1 + off_b0111 * PACKMR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1; \ @@ -444,47 +394,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_b0111, \ - alpha_cast, \ - a1_i, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Copy edge elements of C to the temporary buffer. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - c11, rs_c, cs_c, \ - ct, rs_ct, cs_ct ); \ -\ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_b0111, \ - alpha_cast, \ - a1_i, \ - b1, \ - beta_cast, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_b0111, \ + alpha_cast, \ + a1_i, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ \ a1 += rstep_a; \ @@ -510,9 +433,9 @@ void PASTEMAC(ch,varname) \ bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ \ /* Advance the start and end iteration offsets for the rectangular region - by the number of iterations used for the triangular region. */ \ - jr_start += n_iter_tri; \ - jr_end += n_iter_tri; \ + by the number of iterations used for the triangular region. */ \ + jr_start += n_iter_tri; \ + jr_end += n_iter_tri; \ jb0 = n_iter_tri; \ \ /* Save the resulting value of b1 from the previous loop since it represents @@ -530,7 +453,7 @@ void PASTEMAC(ch,varname) \ the starting address of the rectangular region (which is already n_iter_tri logical iterations through B). */ \ b1 = b_cast + (j-jb0) * cstep_b; \ - c1 = c_cast + j * cstep_c; \ + c1 = c_cast + j * cstep_c; \ \ n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ \ @@ -542,10 +465,6 @@ void PASTEMAC(ch,varname) \ This allows the current macro-kernel to work for both trmm and trmm3. */ \ { \ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_b( istep_b, &aux ); \ -\ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = ir_start; i < ir_end; i += ir_inc ) \ { \ @@ -571,42 +490,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - one, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + one, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ diff --git a/frame/3/trmm/bli_trmm_var.h b/frame/3/trmm/bli_trmm_var.h index 09694ca5ca..262b0490fd 100644 --- a/frame/3/trmm/bli_trmm_var.h +++ b/frame/3/trmm/bli_trmm_var.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/3/trmm/bli_trmm_xx_ker_var2.c b/frame/3/trmm/bli_trmm_xx_ker_var2.c index 343aaa078b..898cfe2423 100644 --- a/frame/3/trmm/bli_trmm_xx_ker_var2.c +++ b/frame/3/trmm/bli_trmm_xx_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,7 @@ #include "blis.h" -static gemm_var_oft vars[2][2] = +static l3_var_oft vars[2][2] = { { bli_trmm_ll_ker_var2, bli_trmm_lu_ker_var2 }, { bli_trmm_rl_ker_var2, bli_trmm_ru_ker_var2 } @@ -52,9 +52,9 @@ void bli_trmm_xx_ker_var2 thrinfo_t* thread ) { - bool_t side; - bool_t uplo; - gemm_var_oft f; + dim_t side; + dim_t uplo; + l3_var_oft f; // Set two bools: one based on the implied side parameter (the structure // of the root object) and one based on the uplo field of the triangular diff --git a/frame/3/trmm/other/bli_trmm_ll_ker_var2.c b/frame/3/trmm/other/bli_trmm_ll_ker_var2.c index 3747a0dcf4..9ab64e470d 100644 --- a/frame/3/trmm/other/bli_trmm_ll_ker_var2.c +++ b/frame/3/trmm/other/bli_trmm_ll_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -175,7 +175,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trmm/other/bli_trmm_ll_ker_var2rr.c b/frame/3/trmm/other/bli_trmm_ll_ker_var2rr.c index ea979d7c3b..6fef4e0c96 100644 --- a/frame/3/trmm/other/bli_trmm_ll_ker_var2rr.c +++ b/frame/3/trmm/other/bli_trmm_ll_ker_var2rr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -178,7 +178,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trmm/other/bli_trmm_ll_ker_var2sl.c b/frame/3/trmm/other/bli_trmm_ll_ker_var2sl.c index e612b340cd..e0d9cc75f7 100644 --- a/frame/3/trmm/other/bli_trmm_ll_ker_var2sl.c +++ b/frame/3/trmm/other/bli_trmm_ll_ker_var2sl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -178,7 +178,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trmm/other/bli_trmm_lu_ker_var2.c b/frame/3/trmm/other/bli_trmm_lu_ker_var2.c index 9a4e36b656..0abcfd77ae 100644 --- a/frame/3/trmm/other/bli_trmm_lu_ker_var2.c +++ b/frame/3/trmm/other/bli_trmm_lu_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -175,7 +175,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trmm/other/bli_trmm_lu_ker_var2rr.c b/frame/3/trmm/other/bli_trmm_lu_ker_var2rr.c index 551bc097d0..8c505f88a7 100644 --- a/frame/3/trmm/other/bli_trmm_lu_ker_var2rr.c +++ b/frame/3/trmm/other/bli_trmm_lu_ker_var2rr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -178,7 +178,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trmm/other/bli_trmm_lu_ker_var2sl.c b/frame/3/trmm/other/bli_trmm_lu_ker_var2sl.c index 132c732d6f..3bb0deaa30 100644 --- a/frame/3/trmm/other/bli_trmm_lu_ker_var2sl.c +++ b/frame/3/trmm/other/bli_trmm_lu_ker_var2sl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -178,7 +178,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trmm/other/bli_trmm_rl_ker_var2.c b/frame/3/trmm/other/bli_trmm_rl_ker_var2.c index b29df08508..672caaa052 100644 --- a/frame/3/trmm/other/bli_trmm_rl_ker_var2.c +++ b/frame/3/trmm/other/bli_trmm_rl_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -175,7 +175,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trmm/other/bli_trmm_rl_ker_var2rr.c b/frame/3/trmm/other/bli_trmm_rl_ker_var2rr.c index 14b2359182..9d9e3809cd 100644 --- a/frame/3/trmm/other/bli_trmm_rl_ker_var2rr.c +++ b/frame/3/trmm/other/bli_trmm_rl_ker_var2rr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -178,7 +178,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trmm/other/bli_trmm_rl_ker_var2sl.c b/frame/3/trmm/other/bli_trmm_rl_ker_var2sl.c index cf4a6e0865..8bac0ec4aa 100644 --- a/frame/3/trmm/other/bli_trmm_rl_ker_var2sl.c +++ b/frame/3/trmm/other/bli_trmm_rl_ker_var2sl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -178,7 +178,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trmm/other/bli_trmm_ru_ker_var2.c b/frame/3/trmm/other/bli_trmm_ru_ker_var2.c index 602f4cc3b2..fc2991b132 100644 --- a/frame/3/trmm/other/bli_trmm_ru_ker_var2.c +++ b/frame/3/trmm/other/bli_trmm_ru_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -175,7 +175,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trmm/other/bli_trmm_ru_ker_var2rr.c b/frame/3/trmm/other/bli_trmm_ru_ker_var2rr.c index 03eaa6ea69..00a0dc3f0c 100644 --- a/frame/3/trmm/other/bli_trmm_ru_ker_var2rr.c +++ b/frame/3/trmm/other/bli_trmm_ru_ker_var2rr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -178,7 +178,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trmm/other/bli_trmm_ru_ker_var2sl.c b/frame/3/trmm/other/bli_trmm_ru_ker_var2sl.c index 2411a24a4b..889fa49fa7 100644 --- a/frame/3/trmm/other/bli_trmm_ru_ker_var2sl.c +++ b/frame/3/trmm/other/bli_trmm_ru_ker_var2sl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -178,7 +178,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trmm3/bli_trmm3_front.c b/frame/3/trmm3/bli_trmm3_front.c index 39067ac0bc..3b97539603 100644 --- a/frame/3/trmm3/bli_trmm3_front.c +++ b/frame/3/trmm3/bli_trmm3_front.c @@ -53,10 +53,6 @@ void bli_trmm3_front obj_t b_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_trmm_check( side, alpha, a, b, beta, c, cntx ); - // If alpha is zero, scale by beta and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) { @@ -69,6 +65,14 @@ void bli_trmm3_front bli_obj_alias_to( b, &b_local ); bli_obj_alias_to( c, &c_local ); + // Set the obj_t buffer field to the location currently implied by the row + // and column offsets and then zero the offsets. If any of the original + // obj_t's were views into larger matrices, this step effectively makes + // those obj_t's "forget" their lineage. + bli_obj_reset_origin( &a_local ); + bli_obj_reset_origin( &b_local ); + bli_obj_reset_origin( &c_local ); + // We do not explicitly implement the cases where A is transposed. // However, we can still handle them. Specifically, if A is marked as // needing a transposition, we simply induce a transposition. This @@ -86,7 +90,25 @@ void bli_trmm3_front bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &a_local ); } -#if 0 +#ifdef BLIS_DISABLE_TRMM3_RIGHT + // NOTE: This case casts right-side trmm3 in terms of left side. This is + // necessary when the current subconfiguration uses a gemm microkernel + // that assumes that the packing kernel will have already duplicated + // (broadcast) element of B in the packed copy of B. Supporting + // duplication within the logic that packs micropanels from triangular + // matrices would be ugly, and so we simply don't support it. As a + // consequence, those subconfigurations need a way to force the triangular + // matrix to be on the left (and thus the general matrix to the on the + // right). So our solution is that in those cases, the subconfigurations + // simply #define BLIS_DISABLE_TRMM3_RIGHT. + + // NOTE: This case casts right-side trmm3 in terms of left side. This can + // lead to the microkernel being executed on an output matrix with the + // microkernel's general stride IO case (unless the microkernel supports + // both both row and column IO cases as well). + + // NOTE: Casting right-side trmm3 in terms of left side reduces the number + // of macrokernels exercised to two (trmm_ll and trmm_lu). // If A is being multiplied from the right, transpose all operands // so that we can perform the computation as if A were being multiplied @@ -122,12 +144,8 @@ void bli_trmm3_front #endif - // Set each alias as the root object. - // NOTE: We MUST wait until we are done potentially swapping the objects - // before setting the root fields! - bli_obj_set_as_root( &a_local ); - bli_obj_set_as_root( &b_local ); - bli_obj_set_as_root( &c_local ); + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any @@ -142,29 +160,10 @@ void bli_trmm3_front rntm ); - // A sort of hack for communicating the desired pach schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - if ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &b_local ); - } - else // if ( bli_cntx_method( cntx ) != BLIS_NAT ) - { - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - } - // Invoke the internal back-end. bli_l3_thread_decorator ( - bli_gemm_int, + bli_l3_int, BLIS_TRMM, // operation family id alpha, &a_local, diff --git a/frame/3/trsm/bli_trsm.h b/frame/3/trsm/bli_trsm.h index 00b604de6e..964422d017 100644 --- a/frame/3/trsm/bli_trsm.h +++ b/frame/3/trsm/bli_trsm.h @@ -34,7 +34,5 @@ #include "bli_trsm_cntl.h" #include "bli_trsm_front.h" -#include "bli_trsm_int.h" - #include "bli_trsm_var.h" diff --git a/frame/3/trsm/bli_trsm_blk_var1.c b/frame/3/trsm/bli_trsm_blk_var1.c index 1bab54d5ff..30bf6921cd 100644 --- a/frame/3/trsm/bli_trsm_blk_var1.c +++ b/frame/3/trsm/bli_trsm_blk_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -58,7 +58,7 @@ void bli_trsm_blk_var1 bli_l3_prune_unref_mparts_m( a, b, c, cntl ); // Isolate the diagonal block A11 and its corresponding row panel C1. - const dim_t kc = bli_obj_width( a ); + const dim_t kc = bli_obj_width_after_trans( a ); obj_t a11, c1; bli_acquire_mpart_mdim( direct, BLIS_SUBPART1, 0, kc, a, &a11 ); @@ -96,7 +96,7 @@ void bli_trsm_blk_var1 #endif // Perform trsm subproblem. - bli_trsm_int + bli_l3_int ( &BLIS_ONE, &a11_1, @@ -117,7 +117,7 @@ void bli_trsm_blk_var1 // We must execute a barrier here because the upcoming rank-k update // requires the packed matrix B to be fully updated by the trsm // subproblem. - bli_thread_obarrier( thread ); + bli_thread_barrier( thread ); // Isolate the remaining part of the column panel matrix A, which we do by // acquiring the subpartition ahead of A11 (that is, A21 or A01, depending @@ -169,7 +169,7 @@ void bli_trsm_blk_var1 // Perform gemm subproblem. (Note that we use the same backend // function as before, since we're calling the same macrokernel.) - bli_trsm_int + bli_l3_int ( &BLIS_ONE, &a11, diff --git a/frame/3/trsm/bli_trsm_blk_var2.c b/frame/3/trsm/bli_trsm_blk_var2.c index c8330b8013..5691c964ad 100644 --- a/frame/3/trsm/bli_trsm_blk_var2.c +++ b/frame/3/trsm/bli_trsm_blk_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -60,7 +60,7 @@ void bli_trsm_blk_var2 bli_thread_range_ndim ( direct, thread, a, b, c, cntl, cntx, - &my_start, &my_end + &my_start, &my_end ); // Partition along the n dimension. @@ -77,7 +77,7 @@ void bli_trsm_blk_var2 i, b_alg, c, &c1 ); // Perform trsm subproblem. - bli_trsm_int + bli_l3_int ( &BLIS_ONE, a, diff --git a/frame/3/trsm/bli_trsm_blk_var3.c b/frame/3/trsm/bli_trsm_blk_var3.c index ee7c2f9acf..43fc25f16d 100644 --- a/frame/3/trsm/bli_trsm_blk_var3.c +++ b/frame/3/trsm/bli_trsm_blk_var3.c @@ -71,7 +71,7 @@ void bli_trsm_blk_var3 i, b_alg, b, &b1 ); // Perform trsm subproblem. - bli_trsm_int + bli_l3_int ( &BLIS_ONE, &a1, @@ -85,7 +85,7 @@ void bli_trsm_blk_var3 ); //bli_thread_ibarrier( thread ); - bli_thread_obarrier( bli_thrinfo_sub_node( thread ) ); + bli_thread_barrier( bli_thrinfo_sub_node( thread ) ); // This variant executes multiple rank-k updates. Therefore, if the // internal alpha scalars on A/B and C are non-zero, we must ensure diff --git a/frame/3/trsm/bli_trsm_cntl.c b/frame/3/trsm/bli_trsm_cntl.c index 9b59cae61d..0a3be87f74 100644 --- a/frame/3/trsm/bli_trsm_cntl.c +++ b/frame/3/trsm/bli_trsm_cntl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,32 +40,30 @@ cntl_t* bli_trsm_cntl_create rntm_t* rntm, side_t side, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { if ( bli_is_left( side ) ) - return bli_trsm_l_cntl_create( rntm, schema_a, schema_b ); + return bli_trsm_l_cntl_create( rntm, schema_a, schema_b, ker ); else - return bli_trsm_r_cntl_create( rntm, schema_a, schema_b ); + return bli_trsm_r_cntl_create( rntm, schema_a, schema_b, ker ); } cntl_t* bli_trsm_l_cntl_create ( rntm_t* rntm, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { - void* macro_kernel_p; - void* packa_fp; - void* packb_fp; + void_fp macro_kernel_p; - // Use the function pointer to the macrokernels that use slab - // assignment of micropanels to threads in the jr and ir loops. + // Set the default macrokernel. If a non-NULL kernel function pointer is + // passed in, we use that instead. macro_kernel_p = bli_trsm_xx_ker_var2; - - packa_fp = bli_packm_blk_var1; - packb_fp = bli_packm_blk_var1; + if ( ker ) macro_kernel_p = ker; const opid_t family = BLIS_TRSM; @@ -95,11 +93,10 @@ cntl_t* bli_trsm_l_cntl_create cntl_t* gemm_cntl_packa = bli_packm_cntl_create_node ( rntm, - bli_trsm_packa, // trsm operation's packm function for A. - packa_fp, + bli_l3_packa, // trsm operation's packm function for A. BLIS_MR, BLIS_MR, - TRUE, // do NOT invert diagonal + FALSE, // do NOT invert diagonal TRUE, // reverse iteration if upper? FALSE, // reverse iteration if lower? schema_a, // normally BLIS_PACKED_ROW_PANELS @@ -133,11 +130,14 @@ cntl_t* bli_trsm_l_cntl_create cntl_t* trsm_cntl_packa = bli_packm_cntl_create_node ( rntm, - bli_trsm_packa, // trsm operation's packm function for A. - packa_fp, + bli_l3_packa, // trsm operation's packm function for A. BLIS_MR, BLIS_MR, - TRUE, // do NOT invert diagonal +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + TRUE, // invert diagonal +#else + FALSE, // do NOT invert diagonal +#endif TRUE, // reverse iteration if upper? FALSE, // reverse iteration if lower? schema_a, // normally BLIS_PACKED_ROW_PANELS @@ -167,10 +167,9 @@ cntl_t* bli_trsm_l_cntl_create cntl_t* trsm_cntl_packb = bli_packm_cntl_create_node ( rntm, - bli_trsm_packb, - packb_fp, - BLIS_MR, + bli_l3_packb, BLIS_NR, + BLIS_MR, FALSE, // do NOT invert diagonal FALSE, // reverse iteration if upper? FALSE, // reverse iteration if lower? @@ -204,16 +203,17 @@ cntl_t* bli_trsm_l_cntl_create cntl_t* bli_trsm_r_cntl_create ( - rntm_t* rntm, + rntm_t* rntm, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { // NOTE: trsm macrokernels are presently disabled for right-side execution. - void* macro_kernel_p = bli_trsm_xx_ker_var2; - - void* packa_fp = bli_packm_blk_var1; - void* packb_fp = bli_packm_blk_var1; + // Set the default macrokernel. If a non-NULL kernel function pointer is + // passed in, we use that instead. + void_fp macro_kernel_p = bli_trsm_xx_ker_var2; + if ( ker ) macro_kernel_p = ker; const opid_t family = BLIS_TRSM; @@ -240,8 +240,7 @@ cntl_t* bli_trsm_r_cntl_create cntl_t* trsm_cntl_packa = bli_packm_cntl_create_node ( rntm, - bli_trsm_packa, - packa_fp, + bli_l3_packa, BLIS_NR, BLIS_MR, FALSE, // do NOT invert diagonal @@ -266,8 +265,7 @@ cntl_t* bli_trsm_r_cntl_create cntl_t* trsm_cntl_packb = bli_packm_cntl_create_node ( rntm, - bli_trsm_packb, - packb_fp, + bli_l3_packb, BLIS_MR, BLIS_MR, TRUE, // do NOT invert diagonal @@ -318,7 +316,7 @@ cntl_t* bli_trsm_cntl_create_node rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, cntl_t* sub_node ) { diff --git a/frame/3/trsm/bli_trsm_cntl.h b/frame/3/trsm/bli_trsm_cntl.h index 17b8d3c18f..86f4a29b2a 100644 --- a/frame/3/trsm/bli_trsm_cntl.h +++ b/frame/3/trsm/bli_trsm_cntl.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,21 +38,24 @@ cntl_t* bli_trsm_cntl_create rntm_t* rntm, side_t side, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); cntl_t* bli_trsm_l_cntl_create ( rntm_t* rntm, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); cntl_t* bli_trsm_r_cntl_create ( rntm_t* rntm, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); void bli_trsm_cntl_free @@ -69,7 +72,7 @@ cntl_t* bli_trsm_cntl_create_node rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, cntl_t* sub_node ); diff --git a/frame/3/trsm/bli_trsm_front.c b/frame/3/trsm/bli_trsm_front.c index 5093d1a4a5..7f3d17aeff 100644 --- a/frame/3/trsm/bli_trsm_front.c +++ b/frame/3/trsm/bli_trsm_front.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -51,9 +52,12 @@ void bli_trsm_front obj_t b_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_trsm_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); +#if 0 +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + gint_t status = bli_trsm_small( side, alpha, a, b, cntx, cntl ); + if ( status == BLIS_SUCCESS ) return; +#endif +#endif // If alpha is zero, scale by beta and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) @@ -67,6 +71,14 @@ void bli_trsm_front bli_obj_alias_to( b, &b_local ); bli_obj_alias_to( b, &c_local ); + // Set the obj_t buffer field to the location currently implied by the row + // and column offsets and then zero the offsets. If any of the original + // obj_t's were views into larger matrices, this step effectively makes + // those obj_t's "forget" their lineage. + bli_obj_reset_origin( &a_local ); + bli_obj_reset_origin( &b_local ); + bli_obj_reset_origin( &c_local ); + // We do not explicitly implement the cases where A is transposed. // However, we can still handle them. Specifically, if A is marked as // needing a transposition, we simply induce a transposition. This @@ -114,12 +126,8 @@ void bli_trsm_front #endif - // Set each alias as the root object. - // NOTE: We MUST wait until we are done potentially swapping the objects - // before setting the root fields! - bli_obj_set_as_root( &a_local ); - bli_obj_set_as_root( &b_local ); - bli_obj_set_as_root( &c_local ); + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any @@ -134,29 +142,10 @@ void bli_trsm_front rntm ); - // A sort of hack for communicating the desired pach schemas for A and B - // to bli_trsm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - if ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &b_local ); - } - else // if ( bli_cntx_method( cntx ) != BLIS_NAT ) - { - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - } - // Invoke the internal back-end. bli_l3_thread_decorator ( - bli_trsm_int, + bli_l3_int, BLIS_TRSM, // operation family id alpha, &a_local, diff --git a/frame/3/trsm/bli_trsm_front.h b/frame/3/trsm/bli_trsm_front.h index 1a08b7c75e..379935536a 100644 --- a/frame/3/trsm/bli_trsm_front.h +++ b/frame/3/trsm/bli_trsm_front.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,3 +43,16 @@ void bli_trsm_front rntm_t* rntm, cntl_t* cntl ); + +#ifdef BLIS_ENABLE_SMALL_MATRIX +err_t bli_trsm_small + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); +#endif + diff --git a/frame/3/trsm/bli_trsm_ll_ker_var2.c b/frame/3/trsm/bli_trsm_ll_ker_var2.c index 37823d7bfd..f50f739e73 100644 --- a/frame/3/trsm/bli_trsm_ll_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ll_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -176,14 +176,15 @@ void PASTEMAC(ch,varname) \ temporary buffer are set so that they match the storage of the original C matrix. For example, if C is column-stored, ct will be column-stored as well. */ \ +/* ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +*/ \ \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict minus_one = PASTEMAC(ch,m1); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ @@ -209,9 +210,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_a_num; \ - inc_t ss_a_den; \ inc_t ps_a_cur; \ inc_t is_a_cur; \ auxinfo_t aux; \ @@ -249,29 +247,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = ( k % MR != 0 ? k + MR - ( k % MR ) : k ); \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_a ) || \ - bli_is_3mi_packed( schema_a ) || \ - bli_is_rih_packed( schema_a ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. Note that real-only, imag-only, and summed-only - packing formats are not applicable here since trsm is a two- - operand operation only (unlike trmm, which is capable of three- - operand). */ \ - if ( bli_is_3mi_packed( schema_a ) ) { ss_a_num = 3; ss_a_den = 2; } \ - else { ss_a_num = 1; ss_a_den = 1; } \ \ /* If there is a zero region above where the diagonal of A intersects the left edge of the block, adjust the pointer to C and treat this case as @@ -303,10 +278,6 @@ void PASTEMAC(ch,varname) \ know that the underlying buffer was already allocated to have an m dimension that is a multiple of PACKMR, with the region between the last row and the next multiple of MR zero-padded accordingly. */ \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -339,9 +310,6 @@ void PASTEMAC(ch,varname) \ \ /* Save the imaginary stride of B to the auxinfo_t object. */ \ bli_auxinfo_set_is_b( istep_b, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* We don't bother querying the thrinfo_t node for the 1st loop because we can't parallelize that loop in trsm due to the inter-iteration @@ -411,18 +379,18 @@ void PASTEMAC(ch,varname) \ intersecting micro-panel. */ \ is_a_cur = k_a1011 * PACKMR; \ is_a_cur += ( bli_is_odd( is_a_cur ) ? 1 : 0 ); \ - ps_a_cur = ( is_a_cur * ss_a_num ) / ss_a_den; \ + ps_a_cur = is_a_cur; \ \ /* Compute the addresses of the panel A10 and the triangular block A11. */ \ a10 = a1; \ - /* a11 = a1 + ( k_a10 * PACKMR ) / off_scl; */ \ - a11 = bli_ptr_inc_by_frac( a1, sizeof( ctype ), k_a10 * PACKMR, off_scl ); \ + a11 = a1 + k_a10 * PACKMR; \ + /*a11 = bli_ptr_inc_by_frac( a1, sizeof( ctype ), k_a10 * PACKMR, 1 );*/ \ \ /* Compute the addresses of the panel B01 and the block B11. */ \ - b01 = b1 + ( off_a10 * PACKNR ) / off_scl; \ - b11 = b1 + ( off_a11 * PACKNR ) / off_scl; \ + b01 = b1 + off_a10 * PACKNR; \ + b11 = b1 + off_a11 * PACKNR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + ps_a_cur; \ @@ -439,48 +407,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( is_a_cur, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the fused gemm/trsm micro-kernel. */ \ - gemmtrsm_ukr \ - ( \ - k_a10, \ - alpha1_cast, \ - a10, \ - a11, \ - b01, \ - b11, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the fused gemm/trsm micro-kernel. */ \ - gemmtrsm_ukr \ - ( \ - k_a10, \ - alpha1_cast, \ - a10, \ - a11, \ - b01, \ - b11, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the bottom edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + gemmtrsm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_a10, \ + alpha1_cast, \ + a10, \ + a11, \ + b01, \ + b11, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ \ a1 += ps_a_cur; \ } \ @@ -503,47 +443,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( istep_a, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - a1, \ - b1, \ - alpha2_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - alpha2_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + minus_one, \ + a1, \ + b1, \ + alpha2_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ \ a1 += rstep_a; \ } \ @@ -553,44 +466,11 @@ void PASTEMAC(ch,varname) \ } \ \ /* -if ( bli_is_4mi_packed( schema_a ) ){ \ -PASTEMAC(d,fprintm)( stdout, "trsm4m1_ll_ker_var2: b_r before", k, n, \ - ( double* )b, rs_b, 1, "%4.1f", "" ); \ -PASTEMAC(d,fprintm)( stdout, "trsm4m1_ll_ker_var2: b_i before", k, n, \ - ( double* )b+72, rs_b, 1, "%4.1f", "" ); \ -}else{ \ -PASTEMAC(d,fprintm)( stdout, "trsmnat_ll_ker_var2: b_r before", k, n, \ - ( double* )b, 2*rs_b, 2, "%4.1f", "" ); \ -PASTEMAC(d,fprintm)( stdout, "trsmnat_ll_ker_var2: b_i before", k, n, \ - ( double* )b+1, 2*rs_b, 2, "%4.1f", "" ); \ -} \ -*/ \ -\ -/* PASTEMAC(d,fprintm)( stdout, "trsm_ll_ker_var2: a11p_r computed", MR, MR, \ ( double* )a11, 1, PACKMR, "%4.1f", "" ); \ */ \ \ /* -if ( bli_is_4mi_packed( schema_a ) ){ \ -PASTEMAC(d,fprintm)( stdout, "trsm4m1_ll_ker_var2: b_r after", k, n, \ - ( double* )b, rs_b, 1, "%4.1f", "" ); \ -PASTEMAC(d,fprintm)( stdout, "trsm4m1_ll_ker_var2: b_i after", k, n, \ - ( double* )b+72, rs_b, 1, "%4.1f", "" ); \ -}else{ \ -PASTEMAC(d,fprintm)( stdout, "trsmnat_ll_ker_var2: b_r after", k, n, \ - ( double* )b, 2*rs_b, 2, "%4.1f", "" ); \ -PASTEMAC(d,fprintm)( stdout, "trsmnat_ll_ker_var2: b_i after", k, n, \ - ( double* )b+1, 2*rs_b, 2, "%4.1f", "" ); \ -} \ - -PASTEMAC(d,fprintm)( stdout, "trsm_ll_ker_var2: b_r", m, n, \ - ( double* )c, 1, cs_c, "%4.1f", "" ); \ -PASTEMAC(d,fprintm)( stdout, "trsm_ll_ker_var2: b_i", m, n, \ - ( double* )c + 8*9, 1, cs_c, "%4.1f", "" ); \ -*/ \ -\ -/* PASTEMAC(ch,fprintm)( stdout, "trsm_ll_ker_var2: a1 (diag)", MR, k_a1011, a1, 1, MR, "%5.2f", "" ); \ PASTEMAC(ch,fprintm)( stdout, "trsm_ll_ker_var2: a11 (diag)", MR, MR, a11, 1, MR, "%5.2f", "" ); \ PASTEMAC(ch,fprintm)( stdout, "trsm_ll_ker_var2: b1 (diag)", k_a1011, NR, bp_i, NR, 1, "%5.2f", "" ); \ diff --git a/frame/3/trsm/bli_trsm_lu_ker_var2.c b/frame/3/trsm/bli_trsm_lu_ker_var2.c index 853bccf919..4f35141435 100644 --- a/frame/3/trsm/bli_trsm_lu_ker_var2.c +++ b/frame/3/trsm/bli_trsm_lu_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -176,14 +176,15 @@ void PASTEMAC(ch,varname) \ temporary buffer are set so that they match the storage of the original C matrix. For example, if C is column-stored, ct will be column-stored as well. */ \ +/* ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +*/ \ \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict minus_one = PASTEMAC(ch,m1); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ @@ -210,9 +211,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_a_num; \ - inc_t ss_a_den; \ inc_t ps_a_cur; \ inc_t is_a_cur; \ auxinfo_t aux; \ @@ -250,29 +248,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = ( k % MR != 0 ? k + MR - ( k % MR ) : k ); \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_a ) || \ - bli_is_3mi_packed( schema_a ) || \ - bli_is_rih_packed( schema_a ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. Note that real-only, imag-only, and summed-only - packing formats are not applicable here since trsm is a two- - operand operation only (unlike trmm, which is capable of three- - operand). */ \ - if ( bli_is_3mi_packed( schema_a ) ) { ss_a_num = 3; ss_a_den = 2; } \ - else { ss_a_num = 1; ss_a_den = 1; } \ \ /* If there is a zero region to the left of where the diagonal of A intersects the top edge of the block, adjust the pointer to B and @@ -284,7 +259,7 @@ void PASTEMAC(ch,varname) \ i = diagoffa; \ k = k - i; \ diagoffa = 0; \ - b_cast = b_cast + ( i * PACKNR ) / off_scl; \ + b_cast = b_cast + i * PACKNR; \ } \ \ /* If there is a zero region below where the diagonal of A intersects the @@ -311,10 +286,6 @@ void PASTEMAC(ch,varname) \ know that the underlying buffer was already allocated to have an m dimension that is a multiple of PACKMR, with the region between the last row and the next multiple of MR zero-padded accordingly. */ \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -347,9 +318,6 @@ void PASTEMAC(ch,varname) \ \ /* Save the imaginary stride of B to the auxinfo_t object. */ \ bli_auxinfo_set_is_b( istep_b, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* We don't bother querying the thrinfo_t node for the 1st loop because we can't parallelize that loop in trsm due to the inter-iteration @@ -421,18 +389,18 @@ void PASTEMAC(ch,varname) \ intersecting micro-panel. */ \ is_a_cur = k_a1112 * PACKMR; \ is_a_cur += ( bli_is_odd( is_a_cur ) ? 1 : 0 ); \ - ps_a_cur = ( is_a_cur * ss_a_num ) / ss_a_den; \ + ps_a_cur = is_a_cur; \ \ /* Compute the addresses of the triangular block A11 and the panel A12. */ \ a11 = a1; \ - /* a12 = a1 + ( k_a11 * PACKMR ) / off_scl; */ \ - a12 = bli_ptr_inc_by_frac( a1, sizeof( ctype ), k_a11 * PACKMR, off_scl ); \ + a12 = a1 + k_a11 * PACKMR; \ + /*a12 = bli_ptr_inc_by_frac( a1, sizeof( ctype ), k_a11 * PACKMR, 1 );*/ \ \ /* Compute the addresses of the panel B01 and the block B11. */ \ - b11 = b1 + ( off_a11 * PACKNR ) / off_scl; \ - b21 = b1 + ( off_a12 * PACKNR ) / off_scl; \ + b11 = b1 + off_a11 * PACKNR; \ + b21 = b1 + off_a12 * PACKNR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + ps_a_cur; \ @@ -449,48 +417,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( is_a_cur, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the fused gemm/trsm micro-kernel. */ \ - gemmtrsm_ukr \ - ( \ - k_a12, \ - alpha1_cast, \ - a12, \ - a11, \ - b21, \ - b11, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the fused gemm/trsm micro-kernel. */ \ - gemmtrsm_ukr \ - ( \ - k_a12, \ - alpha1_cast, \ - a12, \ - a11, \ - b21, \ - b11, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the bottom edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + gemmtrsm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_a12, \ + alpha1_cast, \ + a12, \ + a11, \ + b21, \ + b11, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ \ a1 += ps_a_cur; \ } \ @@ -513,47 +453,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( istep_a, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - a1, \ - b1, \ - alpha2_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - alpha2_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + minus_one, \ + a1, \ + b1, \ + alpha2_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ \ a1 += rstep_a; \ } \ diff --git a/frame/3/trsm/bli_trsm_rl_ker_var2.c b/frame/3/trsm/bli_trsm_rl_ker_var2.c index 87e1a0b287..b4937134fb 100644 --- a/frame/3/trsm/bli_trsm_rl_ker_var2.c +++ b/frame/3/trsm/bli_trsm_rl_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -181,14 +181,15 @@ void PASTEMAC(ch,varname) \ temporary buffer are set so that they match the storage of the original C matrix. For example, if C is column-stored, ct will be column-stored as well. */ \ +/* ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +*/ \ \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict minus_one = PASTEMAC(ch,m1); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ @@ -215,9 +216,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_b_num; \ - inc_t ss_b_den; \ inc_t ps_b_cur; \ inc_t is_b_cur; \ auxinfo_t aux; \ @@ -263,29 +261,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = ( k % NR != 0 ? k + NR - ( k % NR ) : k ); \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_b ) || \ - bli_is_3mi_packed( schema_b ) || \ - bli_is_rih_packed( schema_b ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. Note that real-only, imag-only, and summed-only - packing formats are not applicable here since trsm is a two- - operand operation only (unlike trmm, which is capable of three- - operand). */ \ - if ( bli_is_3mi_packed( schema_b ) ) { ss_b_num = 3; ss_b_den = 2; } \ - else { ss_b_num = 1; ss_b_den = 1; } \ \ /* If there is a zero region above where the diagonal of B intersects the left edge of the panel, adjust the pointer to A and treat this @@ -297,7 +272,7 @@ void PASTEMAC(ch,varname) \ j = -diagoffb; \ k = k - j; \ diagoffb = 0; \ - a_cast = a_cast + ( j * PACKMR ) / off_scl; \ + a_cast = a_cast + j * PACKMR; \ } \ \ /* If there is a zero region to the right of where the diagonal @@ -329,10 +304,6 @@ void PASTEMAC(ch,varname) \ know that the underlying buffer was already allocated to have an n dimension that is a multiple of PACKNR, with the region between the last column and the next multiple of NR zero-padded accordingly. */ \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -369,9 +340,6 @@ void PASTEMAC(ch,varname) \ NOTE: We swap the values for A and B since the triangular "A" matrix is actually contained within B. */ \ bli_auxinfo_set_is_b( istep_a, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ b1 = b_cast; \ c1 = c_cast; \ @@ -413,20 +381,14 @@ void PASTEMAC(ch,varname) \ \ /* Compute the addresses of the triangular block B11 and the panel B21. */ \ - b11 = b1; \ - /* b21 = b1 + ( k_b11 * PACKNR ) / off_scl; */ \ - b21 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b11 * PACKNR, off_scl ); \ + b11 = b1; \ + b21 = b1 + k_b11 * PACKNR; \ + /*b21 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b11 * PACKNR, 1 );*/ \ \ /* Compute the panel stride for the current micro-panel. */ \ is_b_cur = k_b1121 * PACKNR; \ is_b_cur += ( bli_is_odd( is_b_cur ) ? 1 : 0 ); \ - ps_b_cur = ( is_b_cur * ss_b_num ) / ss_b_den; \ -\ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. - NOTE: We swap the values for A and B since the triangular - "A" matrix is actually contained within B. */ \ - bli_auxinfo_set_is_a( is_b_cur, &aux ); \ + ps_b_cur = is_b_cur; \ \ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = 0; i < m_iter; ++i ) \ @@ -440,8 +402,8 @@ void PASTEMAC(ch,varname) \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ /* Compute the addresses of the A11 block and A12 panel. */ \ - a11 = a1 + ( off_b11 * PACKMR ) / off_scl; \ - a12 = a1 + ( off_b21 * PACKMR ) / off_scl; \ + a11 = a1 + off_b11 * PACKMR; \ + a12 = a1 + off_b21 * PACKMR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1; \ @@ -460,44 +422,21 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( b2, &aux ); \ bli_auxinfo_set_next_b( a2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the fused gemm/trsm micro-kernel. */ \ - gemmtrsm_ukr \ - ( \ - k_b21, \ - alpha1_cast, \ - b21, \ - b11, \ - a12, \ - a11, \ - c11, cs_c, rs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the fused gemm/trsm micro-kernel. */ \ - gemmtrsm_ukr \ - ( \ - k_b21, \ - alpha1_cast, \ - b21, \ - b11, \ - a12, \ - a11, \ - ct, cs_ct, rs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the bottom edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + gemmtrsm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_b21, \ + alpha1_cast, \ + b21, \ + b11, \ + a12, \ + a11, \ + c11, cs_c, rs_c, \ + &aux, \ + cntx \ + ); \ +\ } \ \ a1 += rstep_a; \ @@ -508,12 +447,6 @@ void PASTEMAC(ch,varname) \ } \ else if ( bli_is_strictly_below_diag_n( diagoffb_j, k, NR ) ) \ { \ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. - NOTE: We swap the values for A and B since the triangular - "A" matrix is actually contained within B. */ \ - bli_auxinfo_set_is_a( istep_b, &aux ); \ -\ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = 0; i < m_iter; ++i ) \ { \ @@ -540,43 +473,21 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( b2, &aux ); \ bli_auxinfo_set_next_b( a2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - b1, \ - a1, \ - alpha2_cast, \ - c11, cs_c, rs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - b1, \ - a1, \ - zero, \ - ct, cs_ct, rs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - alpha2_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + minus_one, \ + b1, \ + a1, \ + alpha2_cast, \ + c11, cs_c, rs_c, \ + &aux, \ + cntx \ + ); \ +\ } \ \ a1 += rstep_a; \ diff --git a/frame/3/trsm/bli_trsm_ru_ker_var2.c b/frame/3/trsm/bli_trsm_ru_ker_var2.c index 71a72ea240..09942d311a 100644 --- a/frame/3/trsm/bli_trsm_ru_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ru_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -181,14 +181,15 @@ void PASTEMAC(ch,varname) \ temporary buffer are set so that they match the storage of the original C matrix. For example, if C is column-stored, ct will be column-stored as well. */ \ +/* ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +*/ \ \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict minus_one = PASTEMAC(ch,m1); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ @@ -214,9 +215,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_b_num; \ - inc_t ss_b_den; \ inc_t ps_b_cur; \ inc_t is_b_cur; \ auxinfo_t aux; \ @@ -262,29 +260,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = ( k % NR != 0 ? k + NR - ( k % NR ) : k ); \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_b ) || \ - bli_is_3mi_packed( schema_b ) || \ - bli_is_rih_packed( schema_b ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. Note that real-only, imag-only, and summed-only - packing formats are not applicable here since trsm is a two- - operand operation only (unlike trmm, which is capable of three- - operand). */ \ - if ( bli_is_3mi_packed( schema_b ) ) { ss_b_num = 3; ss_b_den = 2; } \ - else { ss_b_num = 1; ss_b_den = 1; } \ \ /* If there is a zero region to the left of where the diagonal of B intersects the top edge of the panel, adjust the pointer to C and @@ -324,10 +299,6 @@ void PASTEMAC(ch,varname) \ know that the underlying buffer was already allocated to have an n dimension that is a multiple of PACKNR, with the region between the last column and the next multiple of NR zero-padded accordingly. */ \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -364,9 +335,6 @@ void PASTEMAC(ch,varname) \ NOTE: We swap the values for A and B since the triangular "A" matrix is actually contained within B. */ \ bli_auxinfo_set_is_b( istep_a, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ b1 = b_cast; \ c1 = c_cast; \ @@ -406,20 +374,14 @@ void PASTEMAC(ch,varname) \ \ /* Compute the addresses of the panel B10 and the triangular block B11. */ \ - b01 = b1; \ - /* b11 = b1 + ( k_b01 * PACKNR ) / off_scl; */ \ - b11 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b01 * PACKNR, off_scl ); \ + b01 = b1; \ + b11 = b1 + k_b01 * PACKNR; \ + /*b11 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b01 * PACKNR, 1 );*/ \ \ /* Compute the panel stride for the current micro-panel. */ \ is_b_cur = k_b0111 * PACKNR; \ is_b_cur += ( bli_is_odd( is_b_cur ) ? 1 : 0 ); \ - ps_b_cur = ( is_b_cur * ss_b_num ) / ss_b_den; \ -\ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. - NOTE: We swap the values for A and B since the triangular - "A" matrix is actually contained within B. */ \ - bli_auxinfo_set_is_a( is_b_cur, &aux ); \ + ps_b_cur = is_b_cur; \ \ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = 0; i < m_iter; ++i ) \ @@ -433,8 +395,8 @@ void PASTEMAC(ch,varname) \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ /* Compute the addresses of the A10 panel and A11 block. */ \ - a10 = a1 + ( off_b01 * PACKMR ) / off_scl; \ - a11 = a1 + ( off_b11 * PACKMR ) / off_scl; \ + a10 = a1 + off_b01 * PACKMR; \ + a11 = a1 + off_b11 * PACKMR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1; \ @@ -453,44 +415,21 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( b2, &aux ); \ bli_auxinfo_set_next_b( a2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the fused gemm/trsm micro-kernel. */ \ - gemmtrsm_ukr \ - ( \ - k_b01, \ - alpha1_cast, \ - b01, \ - b11, \ - a10, \ - a11, \ - c11, cs_c, rs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the fused gemm/trsm micro-kernel. */ \ - gemmtrsm_ukr \ - ( \ - k_b01, \ - alpha1_cast, \ - b01, \ - b11, \ - a10, \ - a11, \ - ct, cs_ct, rs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the bottom edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + gemmtrsm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_b01, \ + alpha1_cast, \ + b01, \ + b11, \ + a10, \ + a11, \ + c11, cs_c, rs_c, \ + &aux, \ + cntx \ + ); \ +\ } \ \ a1 += rstep_a; \ @@ -501,12 +440,6 @@ void PASTEMAC(ch,varname) \ } \ else if ( bli_is_strictly_above_diag_n( diagoffb_j, k, NR ) ) \ { \ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. - NOTE: We swap the values for A and B since the triangular - "A" matrix is actually contained within B. */ \ - bli_auxinfo_set_is_a( istep_b, &aux ); \ -\ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = 0; i < m_iter; ++i ) \ { \ @@ -533,43 +466,21 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( b2, &aux ); \ bli_auxinfo_set_next_b( a2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - b1, \ - a1, \ - alpha2_cast, \ - c11, cs_c, rs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - b1, \ - a1, \ - zero, \ - ct, cs_ct, rs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - alpha2_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + minus_one, \ + b1, \ + a1, \ + alpha2_cast, \ + c11, cs_c, rs_c, \ + &aux, \ + cntx \ + ); \ +\ } \ \ a1 += rstep_a; \ diff --git a/frame/3/trsm/bli_trsm_var.h b/frame/3/trsm/bli_trsm_var.h index 0f5f42de87..8322a8b5b6 100644 --- a/frame/3/trsm/bli_trsm_var.h +++ b/frame/3/trsm/bli_trsm_var.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -55,8 +55,6 @@ void PASTEMAC0(opname) \ GENPROT( trsm_blk_var1 ) GENPROT( trsm_blk_var2 ) GENPROT( trsm_blk_var3 ) -GENPROT( trsm_packa ) -GENPROT( trsm_packb ) GENPROT( trsm_xx_ker_var2 ) diff --git a/frame/3/trsm/bli_trsm_xx_ker_var2.c b/frame/3/trsm/bli_trsm_xx_ker_var2.c index dfdcf2ebad..c30a5828a3 100644 --- a/frame/3/trsm/bli_trsm_xx_ker_var2.c +++ b/frame/3/trsm/bli_trsm_xx_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,7 @@ #include "blis.h" -static trsm_var_oft vars[2][2] = +static l3_var_oft vars[2][2] = { { bli_trsm_ll_ker_var2, bli_trsm_lu_ker_var2 }, { bli_trsm_rl_ker_var2, bli_trsm_ru_ker_var2 } @@ -52,9 +52,9 @@ void bli_trsm_xx_ker_var2 thrinfo_t* thread ) { - bool_t side; - bool_t uplo; - trsm_var_oft f; + dim_t side; + dim_t uplo; + l3_var_oft f; // Set two bools: one based on the implied side parameter (the structure // of the root object) and one based on the uplo field of the triangular diff --git a/frame/3/trsm/other/bli_trsm_ll_ker_var2.c b/frame/3/trsm/other/bli_trsm_ll_ker_var2.c index 1c4b0b5c75..dc57eac5f2 100644 --- a/frame/3/trsm/other/bli_trsm_ll_ker_var2.c +++ b/frame/3/trsm/other/bli_trsm_ll_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -179,7 +179,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trsm/other/bli_trsm_ll_ker_var2rr.c b/frame/3/trsm/other/bli_trsm_ll_ker_var2rr.c index 3891bffc02..38768242ec 100644 --- a/frame/3/trsm/other/bli_trsm_ll_ker_var2rr.c +++ b/frame/3/trsm/other/bli_trsm_ll_ker_var2rr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -182,7 +182,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trsm/other/bli_trsm_ll_ker_var2sl.c b/frame/3/trsm/other/bli_trsm_ll_ker_var2sl.c index 1bc2f6e42d..78ffe17585 100644 --- a/frame/3/trsm/other/bli_trsm_ll_ker_var2sl.c +++ b/frame/3/trsm/other/bli_trsm_ll_ker_var2sl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -182,7 +182,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trsm/other/bli_trsm_lu_ker_var2.c b/frame/3/trsm/other/bli_trsm_lu_ker_var2.c index 673e1eaa37..7c4cea9763 100644 --- a/frame/3/trsm/other/bli_trsm_lu_ker_var2.c +++ b/frame/3/trsm/other/bli_trsm_lu_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -179,7 +179,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trsm/other/bli_trsm_lu_ker_var2rr.c b/frame/3/trsm/other/bli_trsm_lu_ker_var2rr.c index 72761ee541..8d050c62b0 100644 --- a/frame/3/trsm/other/bli_trsm_lu_ker_var2rr.c +++ b/frame/3/trsm/other/bli_trsm_lu_ker_var2rr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -182,7 +182,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trsm/other/bli_trsm_lu_ker_var2sl.c b/frame/3/trsm/other/bli_trsm_lu_ker_var2sl.c index 491ae8198c..b49a1144ee 100644 --- a/frame/3/trsm/other/bli_trsm_lu_ker_var2sl.c +++ b/frame/3/trsm/other/bli_trsm_lu_ker_var2sl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -182,7 +182,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trsm/other/bli_trsm_rl_ker_var2.c b/frame/3/trsm/other/bli_trsm_rl_ker_var2.c index 3293289a19..a11936389c 100644 --- a/frame/3/trsm/other/bli_trsm_rl_ker_var2.c +++ b/frame/3/trsm/other/bli_trsm_rl_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -184,7 +184,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/3/trsm/other/bli_trsm_ru_ker_var2.c b/frame/3/trsm/other/bli_trsm_ru_ker_var2.c index 9726fd4672..7ad1e42714 100644 --- a/frame/3/trsm/other/bli_trsm_ru_ker_var2.c +++ b/frame/3/trsm/other/bli_trsm_ru_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -184,7 +184,7 @@ void PASTEMAC(ch,varname) \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype ) ] \ __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ diff --git a/frame/base/bli_apool.c b/frame/base/bli_apool.c index 5dd98206e5..e2d8123511 100644 --- a/frame/base/bli_apool.c +++ b/frame/base/bli_apool.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,12 +39,19 @@ void bli_apool_init apool_t* restrict apool ) { + err_t r_val; + + // NOTE: The apool_t is only used in one place; it is the type used to + // define the sba. We've switched to static initialization of the mutex + // field to remove one more thing that could possibly go wrong during + // library initialization. + // Query the mutex from the apool_t. - bli_pthread_mutex_t* restrict mutex = bli_apool_mutex( apool ); + //bli_pthread_mutex_t* restrict mutex = bli_apool_mutex( apool ); // Initialize the mutex. //*mutex = BLIS_PTHREAD_MUTEX_INITIALIZER; - bli_pthread_mutex_init( mutex, NULL ); + //bli_pthread_mutex_init( mutex, NULL ); // We choose to start with: // - an empty pool @@ -87,7 +94,7 @@ void bli_apool_init // Allocate the block_ptrs array. array_t** restrict block_ptrs = - bli_malloc_intl( block_ptrs_len * sizeof( array_t* ) ); + bli_malloc_intl( block_ptrs_len * sizeof( array_t* ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_apool_init(): allocating %d array_t.\n", ( int )num_blocks ); @@ -136,6 +143,8 @@ void bli_apool_alloc_block array_t** restrict array_p ) { + err_t r_val; + // Since the apool_t is defined as a pool of array_t, we can hard-code // the block_size parameter. const siz_t block_size = sizeof( array_t ); @@ -149,7 +158,7 @@ void bli_apool_alloc_block // be recovered when it's time to free the block. array_t* restrict array = - bli_malloc_intl( block_size ); + bli_malloc_intl( block_size, &r_val ); // Initialize an array_t struct within the newly allocated memory region. bli_array_init( num_elem, sizeof( pool_t* ), array ); @@ -212,11 +221,14 @@ void bli_apool_finalize apool_t* restrict apool ) { + // NOTE: Since the apool_t's mutex is now initialized statically, we no + // longer need to explicitly destroy it. + // Query the mutex from the apool_t. - bli_pthread_mutex_t* restrict mutex = bli_apool_mutex( apool ); + //bli_pthread_mutex_t* restrict mutex = bli_apool_mutex( apool ); // Destroy the mutex. - bli_pthread_mutex_destroy( mutex ); + //bli_pthread_mutex_destroy( mutex ); // Query the underlying pool_t and mutex from the apool_t. pool_t* restrict pool = bli_apool_pool( apool ); @@ -368,6 +380,8 @@ pool_t* bli_apool_array_elem array_t* restrict array ) { + err_t r_val; + // Query the array element corresponding to index. // NOTE: If we knew that the array_t contained elements of size // sizeof( void* ) or sizeof( whatever ), we could return the *value* @@ -389,6 +403,7 @@ pool_t* bli_apool_array_elem const siz_t num_blocks = 1; const siz_t block_ptrs_len = 25; const siz_t align_size = 16; + const siz_t offset_size = 0; malloc_ft malloc_fp = BLIS_MALLOC_POOL; free_ft free_fp = BLIS_FREE_POOL; @@ -416,7 +431,7 @@ pool_t* bli_apool_array_elem #endif // Allocate the pool_t. - pool = bli_malloc_intl( sizeof( pool_t ) ); + pool = bli_malloc_intl( sizeof( pool_t ), &r_val ); // Initialize the pool_t. bli_pool_init @@ -425,6 +440,7 @@ pool_t* bli_apool_array_elem block_ptrs_len, block_size, align_size, + offset_size, malloc_fp, free_fp, pool @@ -451,6 +467,8 @@ void bli_apool_grow apool_t* restrict apool ) { + err_t r_val; + // If the requested increase is zero, return early. if ( num_blocks_add == 0 ) return; @@ -491,7 +509,7 @@ void bli_apool_grow // Allocate a new block_ptrs array. array_t** restrict block_ptrs_new = - bli_malloc_intl( block_ptrs_len_new * sizeof( array_t* ) ); + bli_malloc_intl( block_ptrs_len_new * sizeof( array_t* ), &r_val ); // Query the top_index of the pool. const siz_t top_index = bli_pool_top_index( pool ); diff --git a/frame/base/bli_apool.h b/frame/base/bli_apool.h index 1f7889023c..e6e91958af 100644 --- a/frame/base/bli_apool.h +++ b/frame/base/bli_apool.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -51,22 +51,22 @@ typedef struct // apool entry query -static pool_t* bli_apool_pool( apool_t* apool ) +BLIS_INLINE pool_t* bli_apool_pool( apool_t* apool ) { return &(apool->pool); } -static bli_pthread_mutex_t* bli_apool_mutex( apool_t* apool ) +BLIS_INLINE bli_pthread_mutex_t* bli_apool_mutex( apool_t* apool ) { return &(apool->mutex); } -static siz_t bli_apool_def_array_len( apool_t* pool ) +BLIS_INLINE siz_t bli_apool_def_array_len( apool_t* pool ) { return pool->def_array_len; } -static bool_t bli_apool_is_exhausted( apool_t* apool ) +BLIS_INLINE bool bli_apool_is_exhausted( apool_t* apool ) { pool_t* restrict pool = bli_apool_pool( apool ); @@ -75,19 +75,19 @@ static bool_t bli_apool_is_exhausted( apool_t* apool ) // apool action -static void bli_apool_lock( apool_t* apool ) +BLIS_INLINE void bli_apool_lock( apool_t* apool ) { bli_pthread_mutex_lock( bli_apool_mutex( apool ) ); } -static void bli_apool_unlock( apool_t* apool ) +BLIS_INLINE void bli_apool_unlock( apool_t* apool ) { bli_pthread_mutex_unlock( bli_apool_mutex( apool ) ); } // apool entry modification -static void bli_apool_set_def_array_len( siz_t def_array_len, apool_t* pool ) \ +BLIS_INLINE void bli_apool_set_def_array_len( siz_t def_array_len, apool_t* pool ) \ { pool->def_array_len = def_array_len; } diff --git a/frame/base/bli_arch.c b/frame/base/bli_arch.c index 524340c5fe..004c906082 100644 --- a/frame/base/bli_arch.c +++ b/frame/base/bli_arch.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018-2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,14 +33,34 @@ */ -#ifndef BLIS_CONFIGURETIME_CPUID - #include "blis.h" -#else +#ifdef BLIS_CONFIGURETIME_CPUID + + // NOTE: If you need to make any changes to this cpp branch, it's probably + // the case that you also need to modify bli_arch.c, bli_cpuid.c, and + // bli_env.c. Don't forget to update these other files as needed! + + // The BLIS_ENABLE_SYSTEM macro must be defined so that the correct cpp + // branch in bli_system.h is processed. (This macro is normally defined in + // bli_config.h.) + #define BLIS_ENABLE_SYSTEM + + // Use C-style static inline functions for any static inline functions that + // happen to be defined by the headers below. (This macro is normally defined + // in bli_config_macro_defs.h.) + #define BLIS_INLINE static + + // Since we're not building a shared library, we can forgo the use of the + // BLIS_EXPORT_BLIS annotations by #defining them to be nothing. (This macro + // is normally defined in bli_config_macro_defs.h.) #define BLIS_EXPORT_BLIS + #include "bli_system.h" #include "bli_type_defs.h" #include "bli_arch.h" #include "bli_cpuid.h" + #include "bli_env.h" +#else + #include "blis.h" #endif // ----------------------------------------------------------------------------- @@ -74,84 +94,157 @@ void bli_arch_set_id_once( void ) void bli_arch_set_id( void ) { - // Architecture families. -#if defined BLIS_FAMILY_INTEL64 || \ - defined BLIS_FAMILY_AMD64 || \ - defined BLIS_FAMILY_X86_64 || \ - defined BLIS_FAMILY_ARM64 || \ - defined BLIS_FAMILY_ARM32 - id = bli_cpuid_query_id(); -#endif + // Check the environment variable BLIS_ARCH_DEBUG to see if the user + // requested that we echo the result of the subconfiguration selection. + bool do_logging = bli_env_get_var( "BLIS_ARCH_DEBUG", 0 ); + bli_arch_set_logging( do_logging ); - // Intel microarchitectures. -#ifdef BLIS_FAMILY_SKX - id = BLIS_ARCH_SKX; -#endif -#ifdef BLIS_FAMILY_KNL - id = BLIS_ARCH_KNL; -#endif -#ifdef BLIS_FAMILY_KNC - id = BLIS_ARCH_KNC; -#endif -#ifdef BLIS_FAMILY_HASWELL - id = BLIS_ARCH_HASWELL; -#endif -#ifdef BLIS_FAMILY_SANDYBRIDGE - id = BLIS_ARCH_SANDYBRIDGE; -#endif -#ifdef BLIS_FAMILY_PENRYN - id = BLIS_ARCH_PENRYN; -#endif + // Check the environment variable BLIS_ARCH_TYPE to see if the user + // requested that we use a specific subconfiguration. + dim_t req_id = bli_env_get_var( "BLIS_ARCH_TYPE", -1 ); - // AMD microarchitectures. -#ifdef BLIS_FAMILY_ZEN - id = BLIS_ARCH_ZEN; -#endif -#ifdef BLIS_FAMILY_EXCAVATOR - id = BLIS_ARCH_EXCAVATOR; -#endif -#ifdef BLIS_FAMILY_STEAMROLLER - id = BLIS_ARCH_STEAMROLLER; -#endif -#ifdef BLIS_FAMILY_PILEDRIVER - id = BLIS_ARCH_PILEDRIVER; -#endif -#ifdef BLIS_FAMILY_BULLDOZER - id = BLIS_ARCH_BULLDOZER; -#endif +#ifndef BLIS_CONFIGURETIME_CPUID + if ( req_id != -1 ) + { + // BLIS_ARCH_TYPE was set. Cautiously check whether its value is usable. - // ARM microarchitectures. -#ifdef BLIS_FAMILY_THUNDERX2 - id = BLIS_ARCH_THUNDERX2; -#endif -#ifdef BLIS_FAMILY_CORTEXA57 - id = BLIS_ARCH_CORTEXA57; -#endif -#ifdef BLIS_FAMILY_CORTEXA53 - id = BLIS_ARCH_CORTEXA53; -#endif -#ifdef BLIS_FAMILY_CORTEXA15 - id = BLIS_ARCH_CORTEXA15; -#endif -#ifdef BLIS_FAMILY_CORTEXA9 - id = BLIS_ARCH_CORTEXA9; -#endif + // If req_id was set to an invalid arch_t value (ie: outside the range + // [0,BLIS_NUM_ARCHS-1]), output an error message and abort. + if ( bli_error_checking_is_enabled() ) + { + err_t e_val = bli_check_valid_arch_id( req_id ); + bli_check_error_code( e_val ); + } - // IBM microarchitectures. -#ifdef BLIS_FAMILY_POWER7 - id = BLIS_ARCH_POWER7; -#endif -#ifdef BLIS_FAMILY_BGQ - id = BLIS_ARCH_BGQ; -#endif -#ifdef BLIS_FAMILY_POWER9 - id = BLIS_ARCH_POWER9; -#endif + // At this point, we know that req_id is in the valid range, but we + // don't yet know if it refers to a context that was actually + // initialized. Query the address of an internal context data structure + // corresponding to req_id. This pointer will be NULL if the associated + // subconfig is not available. + cntx_t** req_cntx = bli_gks_lookup_id( req_id ); - // Generic microarchitecture. -#ifdef BLIS_FAMILY_GENERIC - id = BLIS_ARCH_GENERIC; + // This function checks the context pointer and aborts with a useful + // error message if the pointer is found to be NULL. + if ( bli_error_checking_is_enabled() ) + { + err_t e_val = bli_check_initialized_gks_cntx( req_cntx ); + bli_check_error_code( e_val ); + } + + // Finally, we can be confident that req_id (1) is in range and (2) + // refers to a context that has been initialized. + id = req_id; + } + else #endif + { + // BLIS_ARCH_TYPE was unset. Proceed with normal subconfiguration + // selection behavior. + + // Architecture families. + #if defined BLIS_FAMILY_INTEL64 || \ + defined BLIS_FAMILY_AMD64 || \ + defined BLIS_FAMILY_X86_64 || \ + defined BLIS_FAMILY_ARM64 || \ + defined BLIS_FAMILY_ARM32 || \ + defined BLIS_FAMILY_X86_64_NO_SKX || \ + defined BLIS_FAMILY_X86_64_NO_ZEN2 || \ + defined BLIS_FAMILY_X86_64_NO_ZEN3 + id = bli_cpuid_query_id(); + #endif + + // Intel microarchitectures. + #ifdef BLIS_FAMILY_SKX + id = BLIS_ARCH_SKX; + #endif + #ifdef BLIS_FAMILY_KNL + id = BLIS_ARCH_KNL; + #endif + #ifdef BLIS_FAMILY_KNC + id = BLIS_ARCH_KNC; + #endif + #ifdef BLIS_FAMILY_HASWELL + id = BLIS_ARCH_HASWELL; + #endif + #ifdef BLIS_FAMILY_SANDYBRIDGE + id = BLIS_ARCH_SANDYBRIDGE; + #endif + #ifdef BLIS_FAMILY_PENRYN + id = BLIS_ARCH_PENRYN; + #endif + + // AMD microarchitectures. + #ifdef BLIS_FAMILY_ZEN3 + id = BLIS_ARCH_ZEN3; + #endif + #ifdef BLIS_FAMILY_ZEN2 + id = BLIS_ARCH_ZEN2; + #endif + #ifdef BLIS_FAMILY_ZEN + id = BLIS_ARCH_ZEN; + #endif + #ifdef BLIS_FAMILY_EXCAVATOR + id = BLIS_ARCH_EXCAVATOR; + #endif + #ifdef BLIS_FAMILY_STEAMROLLER + id = BLIS_ARCH_STEAMROLLER; + #endif + #ifdef BLIS_FAMILY_PILEDRIVER + id = BLIS_ARCH_PILEDRIVER; + #endif + #ifdef BLIS_FAMILY_BULLDOZER + id = BLIS_ARCH_BULLDOZER; + #endif + + // ARM microarchitectures. + #ifdef BLIS_FAMILY_ARMSVE + id = BLIS_ARCH_ARMSVE; + #endif + #ifdef BLIS_FAMILY_A64FX + id = BLIS_ARCH_A64FX; + #endif + #ifdef BLIS_FAMILY_FIRESTORM + id = BLIS_ARCH_FIRESTORM; + #endif + #ifdef BLIS_FAMILY_THUNDERX2 + id = BLIS_ARCH_THUNDERX2; + #endif + #ifdef BLIS_FAMILY_CORTEXA57 + id = BLIS_ARCH_CORTEXA57; + #endif + #ifdef BLIS_FAMILY_CORTEXA53 + id = BLIS_ARCH_CORTEXA53; + #endif + #ifdef BLIS_FAMILY_CORTEXA15 + id = BLIS_ARCH_CORTEXA15; + #endif + #ifdef BLIS_FAMILY_CORTEXA9 + id = BLIS_ARCH_CORTEXA9; + #endif + + // IBM microarchitectures. + #ifdef BLIS_FAMILY_POWER10 + id = BLIS_ARCH_POWER10; + #endif + #ifdef BLIS_FAMILY_POWER9 + id = BLIS_ARCH_POWER9; + #endif + #ifdef BLIS_FAMILY_POWER7 + id = BLIS_ARCH_POWER7; + #endif + #ifdef BLIS_FAMILY_BGQ + id = BLIS_ARCH_BGQ; + #endif + + // Generic microarchitecture. + #ifdef BLIS_FAMILY_GENERIC + id = BLIS_ARCH_GENERIC; + #endif + } + + if ( bli_arch_get_logging() ) + fprintf( stderr, "libblis: selecting sub-configuration '%s'.\n", + bli_arch_string( id ) ); //printf( "blis_arch_query_id(): id = %u\n", id ); //exit(1); @@ -172,22 +265,28 @@ static char* config_name[ BLIS_NUM_ARCHS ] = "sandybridge", "penryn", + "zen3", + "zen2", "zen", "excavator", "steamroller", "piledriver", "bulldozer", + "armsve", + "a64fx", + "firestorm", "thunderx2", "cortexa57", "cortexa53", "cortexa15", "cortexa9", + "power10", + "power9", "power7", "bgq", - "power9", - + "generic" }; @@ -196,3 +295,37 @@ char* bli_arch_string( arch_t id ) return config_name[ id ]; } +// ----------------------------------------------------------------------------- + +static bool arch_dolog = 0; + +void bli_arch_set_logging( bool dolog ) +{ + arch_dolog = dolog; +} + +bool bli_arch_get_logging( void ) +{ + return arch_dolog; +} + +void bli_arch_log( char* fmt, ... ) +{ + char prefix[] = "libblis: "; + int n_chars = strlen( prefix ) + strlen( fmt ) + 1; + + if ( bli_arch_get_logging() && fmt ) + { + char* prefix_fmt = malloc( n_chars ); + + snprintf( prefix_fmt, n_chars, "%s%s", prefix, fmt ); + + va_list ap; + va_start( ap, fmt ); + vfprintf( stderr, prefix_fmt, ap ); + va_end( ap ); + + free( prefix_fmt ); + } +} + diff --git a/frame/base/bli_arch.h b/frame/base/bli_arch.h index 13c9c8aa67..0cd55dace3 100644 --- a/frame/base/bli_arch.h +++ b/frame/base/bli_arch.h @@ -35,13 +35,16 @@ #ifndef BLIS_ARCH_H #define BLIS_ARCH_H -arch_t bli_arch_query_id( void ); +BLIS_EXPORT_BLIS arch_t bli_arch_query_id( void ); -void bli_arch_set_id_once( void ); -void bli_arch_set_id( void ); +void bli_arch_set_id_once( void ); +void bli_arch_set_id( void ); -char* bli_arch_string( arch_t id ); +BLIS_EXPORT_BLIS char* bli_arch_string( arch_t id ); +void bli_arch_set_logging( bool dolog ); +bool bli_arch_get_logging( void ); +void bli_arch_log( char*, ... ); #endif diff --git a/frame/base/bli_array.c b/frame/base/bli_array.c index 3f167056e4..3844cd52f7 100644 --- a/frame/base/bli_array.c +++ b/frame/base/bli_array.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -43,6 +43,8 @@ void bli_array_init array_t* restrict array ) { + err_t r_val; + #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_array_init(): allocating array [%d * %d]: ", ( int )num_elem, ( int )elem_size ); @@ -52,7 +54,7 @@ void bli_array_init const size_t array_size = num_elem * elem_size; // Allocate the array buffer. - void* restrict buf = bli_malloc_intl( array_size ); + void* restrict buf = bli_malloc_intl( array_size, &r_val ); // Initialize the array elements to zero. THIS IS IMPORANT because // consumer threads will use the NULL-ness of the array elements to @@ -72,6 +74,8 @@ void bli_array_resize array_t* restrict array ) { + err_t r_val; + // Query the number of elements in the array. const siz_t num_elem_prev = bli_array_num_elem( array ); @@ -98,7 +102,7 @@ void bli_array_resize #endif // Allocate a new array buffer. - char* restrict buf_new = bli_malloc_intl( array_size_new ); + char* restrict buf_new = bli_malloc_intl( array_size_new, &r_val ); // Copy the previous array contents to the new array. memcpy( buf_new, buf_prev, array_size_prev ); diff --git a/frame/base/bli_array.h b/frame/base/bli_array.h index e3070ae67c..4cb00496b2 100644 --- a/frame/base/bli_array.h +++ b/frame/base/bli_array.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -51,34 +51,34 @@ typedef struct // Array entry query -static void* bli_array_buf( array_t* array ) +BLIS_INLINE void* bli_array_buf( array_t* array ) { return array->buf; } -static siz_t bli_array_num_elem( array_t* array ) +BLIS_INLINE siz_t bli_array_num_elem( array_t* array ) { return array->num_elem; } -static siz_t bli_array_elem_size( array_t* array ) +BLIS_INLINE siz_t bli_array_elem_size( array_t* array ) { return array->elem_size; } // Array entry modification -static void bli_array_set_buf( void* buf, array_t* array ) \ +BLIS_INLINE void bli_array_set_buf( void* buf, array_t* array ) \ { array->buf = buf; } -static void bli_array_set_num_elem( siz_t num_elem, array_t* array ) \ +BLIS_INLINE void bli_array_set_num_elem( siz_t num_elem, array_t* array ) \ { array->num_elem = num_elem; } -static void bli_array_set_elem_size( siz_t elem_size, array_t* array ) \ +BLIS_INLINE void bli_array_set_elem_size( siz_t elem_size, array_t* array ) \ { array->elem_size = elem_size; } diff --git a/frame/base/bli_auxinfo.h b/frame/base/bli_auxinfo.h index d598b9f965..d8c6cbb13f 100644 --- a/frame/base/bli_auxinfo.h +++ b/frame/base/bli_auxinfo.h @@ -38,81 +38,103 @@ // auxinfo_t field query -static pack_t bli_auxinfo_schema_a( auxinfo_t* ai ) +BLIS_INLINE pack_t bli_auxinfo_schema_a( auxinfo_t* ai ) { return ai->schema_a; } -static pack_t bli_auxinfo_schema_b( auxinfo_t* ai ) +BLIS_INLINE pack_t bli_auxinfo_schema_b( auxinfo_t* ai ) { return ai->schema_b; } -static void* bli_auxinfo_next_a( auxinfo_t* ai ) +BLIS_INLINE void* bli_auxinfo_next_a( auxinfo_t* ai ) { return ai->a_next; } -static void* bli_auxinfo_next_b( auxinfo_t* ai ) +BLIS_INLINE void* bli_auxinfo_next_b( auxinfo_t* ai ) { return ai->b_next; } -static inc_t bli_auxinfo_is_a( auxinfo_t* ai ) +BLIS_INLINE inc_t bli_auxinfo_is_a( auxinfo_t* ai ) { return ai->is_a; } -static inc_t bli_auxinfo_is_b( auxinfo_t* ai ) +BLIS_INLINE inc_t bli_auxinfo_is_b( auxinfo_t* ai ) { return ai->is_b; } -#if 0 -static inc_t bli_auxinfo_dt_on_output( auxinfo_t* ai ) +BLIS_INLINE inc_t bli_auxinfo_ps_a( auxinfo_t* ai ) { - return ai->dt_on_output; + return ai->ps_a; +} +BLIS_INLINE inc_t bli_auxinfo_ps_b( auxinfo_t* ai ) +{ + return ai->ps_b; +} + +BLIS_INLINE void_fp bli_auxinfo_ukr( auxinfo_t* ai ) +{ + return ai->ukr; +} +BLIS_INLINE void* bli_auxinfo_params( auxinfo_t* ai ) +{ + return ai->params; } -#endif // auxinfo_t field modification -static void bli_auxinfo_set_schema_a( pack_t schema, auxinfo_t* ai ) +BLIS_INLINE void bli_auxinfo_set_schema_a( pack_t schema, auxinfo_t* ai ) { ai->schema_a = schema; } -static void bli_auxinfo_set_schema_b( pack_t schema, auxinfo_t* ai ) +BLIS_INLINE void bli_auxinfo_set_schema_b( pack_t schema, auxinfo_t* ai ) { ai->schema_b = schema; } -static void bli_auxinfo_set_next_a( void* p, auxinfo_t* ai ) +BLIS_INLINE void bli_auxinfo_set_next_a( void* p, auxinfo_t* ai ) { ai->a_next = p; } -static void bli_auxinfo_set_next_b( void* p, auxinfo_t* ai ) +BLIS_INLINE void bli_auxinfo_set_next_b( void* p, auxinfo_t* ai ) { ai->b_next = p; } -static void bli_auxinfo_set_next_ab( void* ap, void* bp, auxinfo_t* ai ) +BLIS_INLINE void bli_auxinfo_set_next_ab( void* ap, void* bp, auxinfo_t* ai ) { ai->a_next = ap; ai->b_next = bp; } -static void bli_auxinfo_set_is_a( inc_t is, auxinfo_t* ai ) +BLIS_INLINE void bli_auxinfo_set_is_a( inc_t is, auxinfo_t* ai ) { ai->is_a = is; } -static void bli_auxinfo_set_is_b( inc_t is, auxinfo_t* ai ) +BLIS_INLINE void bli_auxinfo_set_is_b( inc_t is, auxinfo_t* ai ) { ai->is_b = is; } -#if 0 -static void bli_auxinfo_set_dt_on_output( num_t dt_on_output, auxinfo_t* ai ) +BLIS_INLINE void bli_auxinfo_set_ps_a( inc_t ps, auxinfo_t* ai ) { - ai->dt_on_output = dt_on_output; + ai->ps_a = ps; +} +BLIS_INLINE void bli_auxinfo_set_ps_b( inc_t ps, auxinfo_t* ai ) +{ + ai->ps_b = ps; +} + +BLIS_INLINE void bli_auxinfo_set_ukr( void_fp ukr, auxinfo_t* ai ) +{ + ai->ukr = ukr; +} +BLIS_INLINE void bli_auxinfo_set_params( void* params, auxinfo_t* ai ) +{ + ai->params = params; } -#endif -#endif +#endif diff --git a/frame/base/bli_blksz.c b/frame/base/bli_blksz.c index c6107ca809..524653d743 100644 --- a/frame/base/bli_blksz.c +++ b/frame/base/bli_blksz.c @@ -42,7 +42,9 @@ blksz_t* bli_blksz_create_ed dim_t b_z, dim_t be_z ) { - blksz_t* b = bli_malloc_intl( sizeof( blksz_t ) ); + err_t r_val; + + blksz_t* b = bli_malloc_intl( sizeof( blksz_t ), &r_val ); bli_blksz_init_ed ( @@ -62,7 +64,9 @@ blksz_t* bli_blksz_create dim_t be_s, dim_t be_d, dim_t be_c, dim_t be_z ) { - blksz_t* b = bli_malloc_intl( sizeof( blksz_t ) ); + err_t r_val; + + blksz_t* b = bli_malloc_intl( sizeof( blksz_t ), &r_val ); bli_blksz_init ( diff --git a/frame/base/bli_blksz.h b/frame/base/bli_blksz.h index a3400b2faf..2e0fefeae9 100644 --- a/frame/base/bli_blksz.h +++ b/frame/base/bli_blksz.h @@ -34,7 +34,7 @@ // blksz_t query -static dim_t bli_blksz_get_def +BLIS_INLINE dim_t bli_blksz_get_def ( num_t dt, blksz_t* b @@ -43,7 +43,7 @@ static dim_t bli_blksz_get_def return b->v[ dt ]; } -static dim_t bli_blksz_get_max +BLIS_INLINE dim_t bli_blksz_get_max ( num_t dt, blksz_t* b @@ -55,7 +55,7 @@ static dim_t bli_blksz_get_max // blksz_t modification -static void bli_blksz_set_def +BLIS_INLINE void bli_blksz_set_def ( dim_t val, num_t dt, @@ -65,7 +65,7 @@ static void bli_blksz_set_def b->v[ dt ] = val; } -static void bli_blksz_set_max +BLIS_INLINE void bli_blksz_set_max ( dim_t val, num_t dt, @@ -75,7 +75,7 @@ static void bli_blksz_set_max b->e[ dt ] = val; } -static void bli_blksz_copy +BLIS_INLINE void bli_blksz_copy ( blksz_t* b_src, blksz_t* b_dst @@ -84,7 +84,7 @@ static void bli_blksz_copy *b_dst = *b_src; } -static void bli_blksz_copy_if_pos +BLIS_INLINE void bli_blksz_copy_if_pos ( blksz_t* b_src, blksz_t* b_dst @@ -114,7 +114,7 @@ static void bli_blksz_copy_if_pos if ( e_z > 0 ) bli_blksz_set_max( e_z, BLIS_DCOMPLEX, b_dst ); } -static void bli_blksz_copy_def_dt +BLIS_INLINE void bli_blksz_copy_def_dt ( num_t dt_src, blksz_t* b_src, num_t dt_dst, blksz_t* b_dst @@ -125,7 +125,7 @@ static void bli_blksz_copy_def_dt bli_blksz_set_def( val, dt_dst, b_dst ); } -static void bli_blksz_copy_max_dt +BLIS_INLINE void bli_blksz_copy_max_dt ( num_t dt_src, blksz_t* b_src, num_t dt_dst, blksz_t* b_dst @@ -136,7 +136,7 @@ static void bli_blksz_copy_max_dt bli_blksz_set_max( val, dt_dst, b_dst ); } -static void bli_blksz_copy_dt +BLIS_INLINE void bli_blksz_copy_dt ( num_t dt_src, blksz_t* b_src, num_t dt_dst, blksz_t* b_dst @@ -146,7 +146,7 @@ static void bli_blksz_copy_dt bli_blksz_copy_max_dt( dt_src, b_src, dt_dst, b_dst ); } -static void bli_blksz_scale_def +BLIS_INLINE void bli_blksz_scale_def ( dim_t num, dim_t den, @@ -159,7 +159,7 @@ static void bli_blksz_scale_def bli_blksz_set_def( ( val * num ) / den, dt, b ); } -static void bli_blksz_scale_max +BLIS_INLINE void bli_blksz_scale_max ( dim_t num, dim_t den, @@ -172,7 +172,7 @@ static void bli_blksz_scale_max bli_blksz_set_max( ( val * num ) / den, dt, b ); } -static void bli_blksz_scale_def_max +BLIS_INLINE void bli_blksz_scale_def_max ( dim_t num, dim_t den, diff --git a/frame/base/bli_check.c b/frame/base/bli_check.c index f5b3aebec1..e76314036f 100644 --- a/frame/base/bli_check.c +++ b/frame/base/bli_check.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -819,22 +819,26 @@ err_t bli_check_if_exhausted_pool( pool_t* pool ) return e_val; } -err_t bli_check_sufficient_stack_buf_size( num_t dt, cntx_t* cntx ) +err_t bli_check_sufficient_stack_buf_size( cntx_t* cntx ) { err_t e_val = BLIS_SUCCESS; + num_t dt; - dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); - dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); - siz_t dt_size = bli_dt_size( dt ); - - // NOTE: For induced methods, we use the size of the complex datatypes - // (rather than the size of the native micro-kernels' datatype) because - // the macro-kernel needs this larger micro-tile footprint, even if the - // virtual micro-kernel implementation will only ever be writing to half - // of it (real or imaginary part) at a time. - - if ( mr * nr * dt_size > BLIS_STACK_BUF_MAX_SIZE ) - e_val = BLIS_INSUFFICIENT_STACK_BUF_SIZE; + for ( dt = BLIS_DT_LO; dt <= BLIS_DT_HI; ++dt ) + { + dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); + siz_t dt_size = bli_dt_size( dt ); + + // NOTE: For induced methods, we use the size of the complex datatypes + // (rather than the size of the native micro-kernels' datatype) because + // the macro-kernel needs this larger micro-tile footprint, even if the + // virtual micro-kernel implementation will only ever be writing to half + // of it (real or imaginary part) at a time. + + if ( mr * nr * dt_size > BLIS_STACK_BUF_MAX_SIZE ) + e_val = BLIS_INSUFFICIENT_STACK_BUF_SIZE; + } return e_val; } @@ -891,6 +895,16 @@ err_t bli_check_valid_arch_id( arch_t id ) return e_val; } +err_t bli_check_initialized_gks_cntx( cntx_t** cntx ) +{ + err_t e_val = BLIS_SUCCESS; + + if ( cntx == NULL ) + e_val = BLIS_UNINITIALIZED_GKS_CNTX; + + return e_val; +} + // -- Architecture-related errors ---------------------------------------------- err_t bli_check_valid_mc_mod_mult( blksz_t* mc, blksz_t* mr ) diff --git a/frame/base/bli_check.h b/frame/base/bli_check.h index 242dc94872..276d276897 100644 --- a/frame/base/bli_check.h +++ b/frame/base/bli_check.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,7 +34,7 @@ */ -err_t bli_check_error_code_helper( gint_t code, char* file, guint_t line ); +BLIS_EXPORT_BLIS err_t bli_check_error_code_helper( gint_t code, char* file, guint_t line ); err_t bli_check_valid_error_level( errlev_t level ); @@ -103,13 +103,14 @@ err_t bli_check_valid_malloc_buf( void* ptr ); err_t bli_check_valid_packbuf( packbuf_t buf_type ); err_t bli_check_if_exhausted_pool( pool_t* pool ); -err_t bli_check_sufficient_stack_buf_size( num_t dt, cntx_t* cntx ); +err_t bli_check_sufficient_stack_buf_size( cntx_t* cntx ); err_t bli_check_alignment_is_power_of_two( size_t align_size ); err_t bli_check_alignment_is_mult_of_ptr_size( size_t align_size ); err_t bli_check_object_alias_of( obj_t* a, obj_t* b ); err_t bli_check_valid_arch_id( arch_t id ); +err_t bli_check_initialized_gks_cntx( cntx_t** cntx ); err_t bli_check_valid_mc_mod_mult( blksz_t* mc, blksz_t* mr ); err_t bli_check_valid_nc_mod_mult( blksz_t* nc, blksz_t* nr ); diff --git a/frame/base/bli_clock.c b/frame/base/bli_clock.c index bd5cd9e82a..62ccee1708 100644 --- a/frame/base/bli_clock.c +++ b/frame/base/bli_clock.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -64,6 +64,18 @@ double bli_clock_min_diff( double time_min, double time_start ) return time_min; } +#ifdef BLIS_DISABLE_SYSTEM +// --- Begin systemless definitions -------------------------------------------- + +double bli_clock_helper() +{ + return 0.0; +} + +// --- End systemless definitions ---------------------------------------------- +#else +// --- Begin system definitions ------------------------------------------------ + #if BLIS_OS_WINDOWS // --- Begin Windows build definitions ----------------------------------------- @@ -135,3 +147,6 @@ double bli_clock_helper() // --- End Linux build definitions --------------------------------------------- #endif +// --- End system definitions -------------------------------------------------- +#endif + diff --git a/frame/base/bli_cntl.c b/frame/base/bli_cntl.c index e24e69125d..f8846198f1 100644 --- a/frame/base/bli_cntl.c +++ b/frame/base/bli_cntl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,7 +40,7 @@ cntl_t* bli_cntl_create_node rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, void* params, cntl_t* sub_node ) @@ -192,7 +192,7 @@ void bli_cntl_free_w_thrinfo printf( "bli_cntl_free_w_thrinfo(): releasing mem pool block.\n" ); #endif - bli_membrk_release( rntm, cntl_pack_mem ); + bli_pba_release( rntm, cntl_pack_mem ); } // Free the current node. @@ -236,7 +236,7 @@ void bli_cntl_free_wo_thrinfo // allocated. if ( bli_mem_is_alloc( cntl_pack_mem ) ) { - bli_membrk_release( rntm, cntl_pack_mem ); + bli_pba_release( rntm, cntl_pack_mem ); } // Free the current node. @@ -360,7 +360,7 @@ dim_t bli_cntl_calc_num_threads_in bszid_t bszid = bli_cntl_bszid( cntl ); dim_t cur_way; - // We assume bszid is in {KR,MR,NR,MC,KC,NR} if it is not + // We assume bszid is in {NC,KC,MC,NR,MR,KR} if it is not // BLIS_NO_PART. if ( bszid != BLIS_NO_PART ) cur_way = bli_rntm_ways_for( bszid, rntm ); diff --git a/frame/base/bli_cntl.h b/frame/base/bli_cntl.h index 1959d3a3d6..67dd02f0c1 100644 --- a/frame/base/bli_cntl.h +++ b/frame/base/bli_cntl.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ struct cntl_s // Basic fields (usually required). opid_t family; bszid_t bszid; - void* var_func; + void_fp var_func; struct cntl_s* sub_prenode; struct cntl_s* sub_node; @@ -65,7 +65,7 @@ BLIS_EXPORT_BLIS cntl_t* bli_cntl_create_node rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, void* params, cntl_t* sub_node ); @@ -127,100 +127,100 @@ dim_t bli_cntl_calc_num_threads_in // cntl_t query (fields only) -static opid_t bli_cntl_family( cntl_t* cntl ) +BLIS_INLINE opid_t bli_cntl_family( cntl_t* cntl ) { return cntl->family; } -static bszid_t bli_cntl_bszid( cntl_t* cntl ) +BLIS_INLINE bszid_t bli_cntl_bszid( cntl_t* cntl ) { return cntl->bszid; } -static void* bli_cntl_var_func( cntl_t* cntl ) +BLIS_INLINE void_fp bli_cntl_var_func( cntl_t* cntl ) { return cntl->var_func; } -static cntl_t* bli_cntl_sub_prenode( cntl_t* cntl ) +BLIS_INLINE cntl_t* bli_cntl_sub_prenode( cntl_t* cntl ) { return cntl->sub_prenode; } -static cntl_t* bli_cntl_sub_node( cntl_t* cntl ) +BLIS_INLINE cntl_t* bli_cntl_sub_node( cntl_t* cntl ) { return cntl->sub_node; } -static void* bli_cntl_params( cntl_t* cntl ) +BLIS_INLINE void* bli_cntl_params( cntl_t* cntl ) { return cntl->params; } -static uint64_t bli_cntl_params_size( cntl_t* cntl ) +BLIS_INLINE uint64_t bli_cntl_params_size( cntl_t* cntl ) { // The first 64 bytes is always the size of the params structure. return *( ( uint64_t* )(cntl->params) ); } -static mem_t* bli_cntl_pack_mem( cntl_t* cntl ) +BLIS_INLINE mem_t* bli_cntl_pack_mem( cntl_t* cntl ) { return &(cntl->pack_mem); } // cntl_t query (complex) -static bool_t bli_cntl_is_null( cntl_t* cntl ) +BLIS_INLINE bool bli_cntl_is_null( cntl_t* cntl ) { - return ( bool_t ) + return ( bool ) ( cntl == NULL ); } -static bool_t bli_cntl_is_leaf( cntl_t* cntl ) +BLIS_INLINE bool bli_cntl_is_leaf( cntl_t* cntl ) { - return ( bool_t ) + return ( bool ) ( bli_cntl_sub_node( cntl ) == NULL ); } -static bool_t bli_cntl_does_part( cntl_t* cntl ) +BLIS_INLINE bool bli_cntl_does_part( cntl_t* cntl ) { - return ( bool_t ) + return ( bool ) ( bli_cntl_bszid( cntl ) != BLIS_NO_PART ); } // cntl_t modification -static void bli_cntl_set_family( opid_t family, cntl_t* cntl ) +BLIS_INLINE void bli_cntl_set_family( opid_t family, cntl_t* cntl ) { cntl->family = family; } -static void bli_cntl_set_bszid( bszid_t bszid, cntl_t* cntl ) +BLIS_INLINE void bli_cntl_set_bszid( bszid_t bszid, cntl_t* cntl ) { cntl->bszid = bszid; } -static void bli_cntl_set_var_func( void* var_func, cntl_t* cntl ) +BLIS_INLINE void bli_cntl_set_var_func( void_fp var_func, cntl_t* cntl ) { cntl->var_func = var_func; } -static void bli_cntl_set_sub_prenode( cntl_t* sub_prenode, cntl_t* cntl ) +BLIS_INLINE void bli_cntl_set_sub_prenode( cntl_t* sub_prenode, cntl_t* cntl ) { cntl->sub_prenode = sub_prenode; } -static void bli_cntl_set_sub_node( cntl_t* sub_node, cntl_t* cntl ) +BLIS_INLINE void bli_cntl_set_sub_node( cntl_t* sub_node, cntl_t* cntl ) { cntl->sub_node = sub_node; } -static void bli_cntl_set_params( void* params, cntl_t* cntl ) +BLIS_INLINE void bli_cntl_set_params( void* params, cntl_t* cntl ) { cntl->params = params; } -static void bli_cntl_set_pack_mem( mem_t* pack_mem, cntl_t* cntl ) +BLIS_INLINE void bli_cntl_set_pack_mem( mem_t* pack_mem, cntl_t* cntl ) { cntl->pack_mem = *pack_mem; } diff --git a/frame/base/bli_cntx.c b/frame/base/bli_cntx.c index 580bb49055..3a698871b1 100644 --- a/frame/base/bli_cntx.c +++ b/frame/base/bli_cntx.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -48,9 +48,8 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) // This function can be called from the bli_cntx_init_*() function for // a particular architecture if the kernel developer wishes to use // non-default blocksizes. It should be called after - // bli_cntx_init_defaults() so that default blocksizes remain - // for any datatypes / register blocksizes that were not targed for - // optimization. + // bli_cntx_init_defaults() so that the context begins with default + // blocksizes across all datatypes. /* Example prototypes: @@ -76,49 +75,37 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) cntx_t* cntx ); */ + va_list args; dim_t i; - - bszid_t* bszids; - blksz_t** blkszs; - bszid_t* bmults; - double* dsclrs; - double* msclrs; - - cntx_t* cntx; - - blksz_t* cntx_blkszs; - bszid_t* cntx_bmults; - + err_t r_val; // Allocate some temporary local arrays. - - #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_blkszs(): " ); #endif - bszids = bli_malloc_intl( n_bs * sizeof( bszid_t ) ); + bszid_t* bszids = bli_malloc_intl( n_bs * sizeof( bszid_t ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_blkszs(): " ); #endif - blkszs = bli_malloc_intl( n_bs * sizeof( blksz_t* ) ); + blksz_t** blkszs = bli_malloc_intl( n_bs * sizeof( blksz_t* ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_blkszs(): " ); #endif - bmults = bli_malloc_intl( n_bs * sizeof( bszid_t ) ); + bszid_t* bmults = bli_malloc_intl( n_bs * sizeof( bszid_t ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_blkszs(): " ); #endif - dsclrs = bli_malloc_intl( n_bs * sizeof( double ) ); + double* dsclrs = bli_malloc_intl( n_bs * sizeof( double ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_blkszs(): " ); #endif - msclrs = bli_malloc_intl( n_bs * sizeof( double ) ); + double* msclrs = bli_malloc_intl( n_bs * sizeof( double ), &r_val ); // -- Begin variable argument section -- @@ -175,7 +162,7 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) } // The last argument should be the context pointer. - cntx = ( cntx_t* )va_arg( args, cntx_t* ); + cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* ); // Shutdown variable argument environment and clean up stack. va_end( args ); @@ -188,8 +175,9 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) // Query the context for the addresses of: // - the blocksize object array // - the blocksize multiple array - cntx_blkszs = bli_cntx_blkszs_buf( cntx ); - cntx_bmults = bli_cntx_bmults_buf( cntx ); + + blksz_t* cntx_blkszs = bli_cntx_blkszs_buf( cntx ); + bszid_t* cntx_bmults = bli_cntx_bmults_buf( cntx ); // Now that we have the context address, we want to copy the values // from the temporary buffers into the corresponding buffers in the @@ -236,12 +224,6 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) double msclr = msclrs[ i ]; blksz_t* blksz = blkszs[ i ]; - // NOTE: This is a bug! We need to grab the actual blocksize - // multiple, which is not at blkszs[i], but rather somewhere else - // in the array. In order to fix this, you probably need to store - // the contents of blkszs (and all the other arrays) by bs_id - // rather than i in the first loop. - blksz_t* bmult = blkszs[ i ]; blksz_t* cntx_blksz = &cntx_blkszs[ bs_id ]; @@ -260,20 +242,6 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) // blocksize object. bli_blksz_scale_def( 1, ( dim_t )dsclr, BLIS_SCOMPLEX, cntx_blksz ); bli_blksz_scale_def( 1, ( dim_t )dsclr, BLIS_DCOMPLEX, cntx_blksz ); - - // Perform rounding to ensure the newly scaled values are still - // multiples of their register blocksize multiples. But only - // perform this rounding when the blocksize id is not equal to - // the blocksize multiple id (ie: we don't round down scaled - // register blocksizes since they are their own multiples). - // Also, we skip the rounding for 1m since it should never need - // such rounding. - if ( bs_id != bm_id && method != BLIS_1M ) - { - // Round the newly-scaled blocksizes down to their multiple. - bli_blksz_reduce_def_to( BLIS_FLOAT, bmult, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_reduce_def_to( BLIS_DOUBLE, bmult, BLIS_DCOMPLEX, cntx_blksz ); - } } // Similarly, if the maximum blocksize scalar is non-unit, we need @@ -284,20 +252,6 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) // blocksize object. bli_blksz_scale_max( 1, ( dim_t )msclr, BLIS_SCOMPLEX, cntx_blksz ); bli_blksz_scale_max( 1, ( dim_t )msclr, BLIS_DCOMPLEX, cntx_blksz ); - - // Perform rounding to ensure the newly scaled values are still - // multiples of their register blocksize multiples. But only - // perform this rounding when the blocksize id is not equal to - // the blocksize multiple id (ie: we don't round down scaled - // register blocksizes since they are their own multiples). - // Also, we skip the rounding for 1m since it should never need - // such rounding. - if ( bs_id != bm_id && method != BLIS_1M ) - { - // Round the newly-scaled blocksizes down to their multiple. - bli_blksz_reduce_max_to( BLIS_FLOAT, bmult, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_reduce_max_to( BLIS_DOUBLE, bmult, BLIS_DCOMPLEX, cntx_blksz ); - } } // Copy the blocksize multiple id into the context. @@ -335,13 +289,14 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) // ----------------------------------------------------------------------------- -void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ) +void bli_cntx_set_ind_blkszs( ind_t method, num_t dt, dim_t n_bs, ... ) { /* Example prototypes: void bli_gks_cntx_set_ind_blkszs ( ind_t method != BLIS_NAT, + num_t dt, dim_t n_bs, bszid_t bs0_id, dim_t def_scalr0, dim_t max_scalr0, bszid_t bs1_id, dim_t def_scalr1, dim_t max_scalr1, @@ -353,14 +308,13 @@ void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ) NOTE: This function modifies an existing context that is presumed to have been initialized for native execution. */ + va_list args; dim_t i; + err_t r_val; - bszid_t* bszids; - double* dsclrs; - double* msclrs; - - cntx_t* cntx; + // Project the given datatype to the real domain. This will be used later on. + num_t dt_real = bli_dt_proj_to_real( dt ); // Return early if called with BLIS_NAT. if ( method == BLIS_NAT ) return; @@ -370,17 +324,17 @@ void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ) #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_ind_blkszs(): " ); #endif - bszids = bli_malloc_intl( n_bs * sizeof( bszid_t ) ); + bszid_t* bszids = bli_malloc_intl( n_bs * sizeof( bszid_t ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_ind_blkszs(): " ); #endif - dsclrs = bli_malloc_intl( n_bs * sizeof( double ) ); + double* dsclrs = bli_malloc_intl( n_bs * sizeof( double ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_ind_blkszs(): " ); #endif - msclrs = bli_malloc_intl( n_bs * sizeof( double ) ); + double* msclrs = bli_malloc_intl( n_bs * sizeof( double ), &r_val ); // -- Begin variable argument section -- @@ -408,7 +362,7 @@ void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ) } // The last argument should be the context pointer. - cntx = ( cntx_t* )va_arg( args, cntx_t* ); + cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* ); // Shutdown variable argument environment and clean up stack. va_end( args ); @@ -434,66 +388,31 @@ void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ) //blksz_t* cntx_blksz = &cntx_blkszs[ bs_id ]; - // Query the blocksize multiple's blocksize id. - bszid_t bm_id = bli_cntx_get_bmult_id( bs_id, cntx ); - // Query the context for the blksz_t object assoicated with the // current blocksize id, and also query the object corresponding // to the blocksize multiple. blksz_t* cntx_blksz = bli_cntx_get_blksz( bs_id, cntx ); - blksz_t* cntx_bmult = bli_cntx_get_bmult( bs_id, cntx ); - // Copy the real domain values of the blksz_t object into the - // the complex domain slots of the same object. - bli_blksz_copy_dt( BLIS_FLOAT, cntx_blksz, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_copy_dt( BLIS_DOUBLE, cntx_blksz, BLIS_DCOMPLEX, cntx_blksz ); + // Copy the real domain value of the blksz_t object into the + // corresponding complex domain slot of the same object. + bli_blksz_copy_dt( dt_real, cntx_blksz, dt, cntx_blksz ); // If the default blocksize scalar is non-unit, we need to scale // the complex domain default blocksizes. if ( dsclr != 1.0 ) { - // Scale the complex domain default blocksize values in the - // blocksize object. - bli_blksz_scale_def( 1, ( dim_t )dsclr, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_scale_def( 1, ( dim_t )dsclr, BLIS_DCOMPLEX, cntx_blksz ); - - // Perform rounding to ensure the newly scaled values are still - // multiples of their register blocksize multiples. But only - // perform this rounding when the blocksize id is not equal to - // the blocksize multiple id (ie: we don't round down scaled - // register blocksizes since they are their own multiples). - // Also, we skip the rounding for 1m since it should never need - // such rounding. - if ( bs_id != bm_id && method != BLIS_1M ) - { - // Round the newly-scaled blocksizes down to their multiple. - bli_blksz_reduce_def_to( BLIS_FLOAT, cntx_bmult, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_reduce_def_to( BLIS_DOUBLE, cntx_bmult, BLIS_DCOMPLEX, cntx_blksz ); - } + // Scale the default blocksize value corresponding to the given + // datatype. + bli_blksz_scale_def( 1, ( dim_t )dsclr, dt, cntx_blksz ); } // Similarly, if the maximum blocksize scalar is non-unit, we need // to scale the complex domain maximum blocksizes. if ( msclr != 1.0 ) { - // Scale the complex domain maximum blocksize values in the - // blocksize object. - bli_blksz_scale_max( 1, ( dim_t )msclr, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_scale_max( 1, ( dim_t )msclr, BLIS_DCOMPLEX, cntx_blksz ); - - // Perform rounding to ensure the newly scaled values are still - // multiples of their register blocksize multiples. But only - // perform this rounding when the blocksize id is not equal to - // the blocksize multiple id (ie: we don't round down scaled - // register blocksizes since they are their own multiples). - // Also, we skip the rounding for 1m since it should never need - // such rounding. - if ( bs_id != bm_id && method != BLIS_1M ) - { - // Round the newly-scaled blocksizes down to their multiple. - bli_blksz_reduce_max_to( BLIS_FLOAT, cntx_bmult, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_reduce_max_to( BLIS_DOUBLE, cntx_bmult, BLIS_DCOMPLEX, cntx_blksz ); - } + // Scale the maximum blocksize value corresponding to the given + // datatype. + bli_blksz_scale_max( 1, ( dim_t )msclr, dt, cntx_blksz ); } } } @@ -523,46 +442,47 @@ void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ) // This function can be called from the bli_cntx_init_*() function for // a particular architecture if the kernel developer wishes to use // non-default level-3 microkernels. It should be called after - // bli_cntx_init_defaults() so that default functions are still called - // for any datatypes / register blocksizes that were not targed for - // optimization. + // bli_cntx_init_defaults() so that the context begins with default + // microkernels across all datatypes. /* Example prototypes: void bli_cntx_set_l3_nat_ukrs ( dim_t n_ukrs, - l3ukr_t ukr0_id, num_t dt0, void* ukr0_fp, bool_t pref0, - l3ukr_t ukr1_id, num_t dt1, void* ukr1_fp, bool_t pref1, - l3ukr_t ukr2_id, num_t dt2, void* ukr2_fp, bool_t pref2, + l3ukr_t ukr0_id, num_t dt0, void_fp ukr0_fp, bool pref0, + l3ukr_t ukr1_id, num_t dt1, void_fp ukr1_fp, bool pref1, + l3ukr_t ukr2_id, num_t dt2, void_fp ukr2_fp, bool pref2, ... cntx_t* cntx ); */ + va_list args; dim_t i; + err_t r_val; // Allocate some temporary local arrays. #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l3_nat_ukrs(): " ); #endif - l3ukr_t* ukr_ids = bli_malloc_intl( n_ukrs * sizeof( l3ukr_t ) ); + l3ukr_t* ukr_ids = bli_malloc_intl( n_ukrs * sizeof( l3ukr_t ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l3_nat_ukrs(): " ); #endif - num_t* ukr_dts = bli_malloc_intl( n_ukrs * sizeof( num_t ) ); + num_t* ukr_dts = bli_malloc_intl( n_ukrs * sizeof( num_t ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l3_nat_ukrs(): " ); #endif - void** ukr_fps = bli_malloc_intl( n_ukrs * sizeof( void* ) ); + void_fp* ukr_fps = bli_malloc_intl( n_ukrs * sizeof( void_fp ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l3_nat_ukrs(): " ); #endif - bool_t* ukr_prefs = bli_malloc_intl( n_ukrs * sizeof( bool_t ) ); + bool* ukr_prefs = bli_malloc_intl( n_ukrs * sizeof( bool ), &r_val ); // -- Begin variable argument section -- @@ -578,7 +498,10 @@ void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ) // - the kernel function pointer, and // - the kernel function storage preference // that we need to store to the context. - // NOTE: The type that we pass into the va_arg() macro for the ukr + + // NOTE: Though bool_t is no longer used, the following comment is + // being kept for historical reasons. + // The type that we pass into the va_arg() macro for the ukr // preference matters. Using 'bool_t' may cause breakage on 64-bit // systems that define int as 32 bits and long int and pointers as // 64 bits. The problem is that TRUE or FALSE are defined as 1 and @@ -590,8 +513,8 @@ void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ) // within a bool_t afterwards. const l3ukr_t ukr_id = ( l3ukr_t )va_arg( args, l3ukr_t ); const num_t ukr_dt = ( num_t )va_arg( args, num_t ); - void* ukr_fp = ( void* )va_arg( args, void* ); - const bool_t ukr_pref = ( bool_t )va_arg( args, int ); + void_fp ukr_fp = ( void_fp )va_arg( args, void_fp ); + const bool ukr_pref = ( bool )va_arg( args, int ); // Store the values in our temporary arrays. ukr_ids[ i ] = ukr_id; @@ -623,12 +546,12 @@ void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ) // Process each blocksize id tuple provided. for ( i = 0; i < n_ukrs; ++i ) { - // Read the current blocksize id, blksz_t* pointer, blocksize - // multiple id, and blocksize scalar. + // Read the current ukernel id, ukernel datatype, ukernel function + // pointer, and ukernel preference. const l3ukr_t ukr_id = ukr_ids[ i ]; const num_t ukr_dt = ukr_dts[ i ]; - void* ukr_fp = ukr_fps[ i ]; - const bool_t ukr_pref = ukr_prefs[ i ]; + void_fp ukr_fp = ukr_fps[ i ]; + const bool ukr_pref = ukr_prefs[ i ]; // Index into the func_t and mbool_t for the current kernel id // being processed. @@ -672,46 +595,659 @@ void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ) // ----------------------------------------------------------------------------- +void bli_cntx_set_l3_vir_ukrs( dim_t n_ukrs, ... ) +{ + // This function can be called from the bli_cntx_init_*() function for + // a particular architecture if the kernel developer wishes to use + // non-default level-3 virtual microkernels. It should be called after + // bli_cntx_init_defaults() so that the context begins with default + // microkernels across all datatypes. + + /* Example prototypes: + + void bli_cntx_set_l3_vir_ukrs + ( + dim_t n_ukrs, + l3ukr_t ukr0_id, num_t dt0, void_fp ukr0_fp, + l3ukr_t ukr1_id, num_t dt1, void_fp ukr1_fp, + l3ukr_t ukr2_id, num_t dt2, void_fp ukr2_fp, + ... + cntx_t* cntx + ); + */ + + va_list args; + dim_t i; + err_t r_val; + + // Allocate some temporary local arrays. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_vir_ukrs(): " ); + #endif + l3ukr_t* ukr_ids = bli_malloc_intl( n_ukrs * sizeof( l3ukr_t ), &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_vir_ukrs(): " ); + #endif + num_t* ukr_dts = bli_malloc_intl( n_ukrs * sizeof( num_t ), &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_vir_ukrs(): " ); + #endif + void_fp* ukr_fps = bli_malloc_intl( n_ukrs * sizeof( void_fp ), &r_val ); + + // -- Begin variable argument section -- + + // Initialize variable argument environment. + va_start( args, n_ukrs ); + + // Process n_ukrs tuples. + for ( i = 0; i < n_ukrs; ++i ) + { + // Here, we query the variable argument list for: + // - the l3ukr_t of the kernel we're about to process, + // - the datatype of the kernel, and + // - the kernel function pointer. + // that we need to store to the context. + const l3ukr_t ukr_id = ( l3ukr_t )va_arg( args, l3ukr_t ); + const num_t ukr_dt = ( num_t )va_arg( args, num_t ); + void_fp ukr_fp = ( void_fp )va_arg( args, void_fp ); + + // Store the values in our temporary arrays. + ukr_ids[ i ] = ukr_id; + ukr_dts[ i ] = ukr_dt; + ukr_fps[ i ] = ukr_fp; + } + + // The last argument should be the context pointer. + cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* ); + + // Shutdown variable argument environment and clean up stack. + va_end( args ); + + // -- End variable argument section -- + + // Query the context for the addresses of: + // - the l3 virtual ukernel func_t array + func_t* cntx_l3_vir_ukrs = bli_cntx_l3_vir_ukrs_buf( cntx ); + + // Now that we have the context address, we want to copy the values + // from the temporary buffers into the corresponding buffers in the + // context. + + // Process each blocksize id tuple provided. + for ( i = 0; i < n_ukrs; ++i ) + { + // Read the current ukernel id, ukernel datatype, ukernel function + // pointer, and ukernel preference. + const l3ukr_t ukr_id = ukr_ids[ i ]; + const num_t ukr_dt = ukr_dts[ i ]; + void_fp ukr_fp = ukr_fps[ i ]; + + // Index into the func_t and mbool_t for the current kernel id + // being processed. + func_t* vukrs = &cntx_l3_vir_ukrs[ ukr_id ]; + + // Store the ukernel function pointer and preference values into + // the context. Notice that we redundantly store the native + // ukernel address in both the native and virtual ukernel slots + // in the context. This is standard practice when creating a + // native context. (Induced method contexts will overwrite the + // virtual function pointer with the address of the appropriate + // virtual ukernel.) + bli_func_set_dt( ukr_fp, ukr_dt, vukrs ); + } + + // Free the temporary local arrays. + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_vir_ukrs(): " ); + #endif + bli_free_intl( ukr_ids ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_vir_ukrs(): " ); + #endif + bli_free_intl( ukr_dts ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_vir_ukrs(): " ); + #endif + bli_free_intl( ukr_fps ); +} + +// ----------------------------------------------------------------------------- + +void bli_cntx_set_l3_sup_thresh( dim_t n_thresh, ... ) +{ + // This function can be called from the bli_cntx_init_*() function for + // a particular architecture if the kernel developer wishes to use + // non-default thresholds for small/unpacked matrix handling. It should + // be called after bli_cntx_init_defaults() so that the context begins + // with default thresholds. + + /* Example prototypes: + + void bli_cntx_set_l3_sup_thresh + ( + dim_t n_thresh, + threshid_t th0_id, blksz_t* blksz0, + threshid_t th1_id, blksz_t* blksz1, + ... + cntx_t* cntx + ); + + */ + + va_list args; + dim_t i; + err_t r_val; + + // Allocate some temporary local arrays. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_thresh(): " ); + #endif + threshid_t* threshids = bli_malloc_intl( n_thresh * sizeof( threshid_t ), &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_thresh(): " ); + #endif + blksz_t** threshs = bli_malloc_intl( n_thresh * sizeof( blksz_t* ), &r_val ); + + // -- Begin variable argument section -- + + // Initialize variable argument environment. + va_start( args, n_thresh ); + + // Process n_thresh tuples. + for ( i = 0; i < n_thresh; ++i ) + { + // Here, we query the variable argument list for: + // - the threshid_t of the threshold we're about to process, + // - the address of the blksz_t object, + threshid_t th_id = ( threshid_t )va_arg( args, threshid_t ); + blksz_t* thresh = ( blksz_t* )va_arg( args, blksz_t* ); + + // Store the values in our temporary arrays. + threshids[ i ] = th_id; + threshs[ i ] = thresh; + } + + // The last argument should be the context pointer. + cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* ); + + // Shutdown variable argument environment and clean up stack. + va_end( args ); + + // -- End variable argument section -- + + // Query the context for the addresses of: + // - the threshold array + blksz_t* cntx_threshs = bli_cntx_l3_sup_thresh_buf( cntx ); + + // Now that we have the context address, we want to copy the values + // from the temporary buffers into the corresponding buffers in the + // context. Notice that the blksz_t* pointers were saved, rather than + // the objects themselves, but we copy the contents of the objects + // when copying into the context. + + // Process each blocksize id tuple provided. + for ( i = 0; i < n_thresh; ++i ) + { + // Read the current blocksize id, blksz_t* pointer, blocksize + // multiple id, and blocksize scalar. + threshid_t th_id = threshids[ i ]; + blksz_t* thresh = threshs[ i ]; + + blksz_t* cntx_thresh = &cntx_threshs[ th_id ]; + + // Copy the blksz_t object contents into the appropriate + // location within the context's blksz_t array. + //cntx_threshs[ th_id ] = *thresh; + //bli_blksz_copy( thresh, cntx_thresh ); + bli_blksz_copy_if_pos( thresh, cntx_thresh ); + } + + // Free the temporary local arrays. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_thresh(): " ); + #endif + bli_free_intl( threshs ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_thresh(): " ); + #endif + bli_free_intl( threshids ); +} + +// ----------------------------------------------------------------------------- + +void bli_cntx_set_l3_sup_handlers( dim_t n_ops, ... ) +{ + // This function can be called from the bli_cntx_init_*() function for + // a particular architecture if the kernel developer wishes to use + // non-default level-3 operation handler for small/unpacked matrices. It + // should be called after bli_cntx_init_defaults() so that the context + // begins with default sup handlers across all datatypes. + + /* Example prototypes: + + void bli_cntx_set_l3_sup_handlers + ( + dim_t n_ops, + opid_t op0_id, void* handler0_fp, + opid_t op1_id, void* handler1_fp, + opid_t op2_id, void* handler2_fp, + ... + cntx_t* cntx + ); + */ + + va_list args; + dim_t i; + err_t r_val; + + // Allocate some temporary local arrays. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_handlers(): " ); + #endif + opid_t* op_ids = bli_malloc_intl( n_ops * sizeof( opid_t ), &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_handlers(): " ); + #endif + void** op_fps = bli_malloc_intl( n_ops * sizeof( void* ), &r_val ); + + // -- Begin variable argument section -- + + // Initialize variable argument environment. + va_start( args, n_ops ); + + // Process n_ukrs tuples. + for ( i = 0; i < n_ops; ++i ) + { + // Here, we query the variable argument list for: + // - the opid_t of the operation we're about to process, + // - the sup handler function pointer + // that we need to store to the context. + const opid_t op_id = ( opid_t )va_arg( args, opid_t ); + void* op_fp = ( void* )va_arg( args, void* ); + + // Store the values in our temporary arrays. + op_ids[ i ] = op_id; + op_fps[ i ] = op_fp; + } + + // The last argument should be the context pointer. + cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* ); + + // Shutdown variable argument environment and clean up stack. + va_end( args ); + + // -- End variable argument section -- + + // Query the context for the addresses of: + // - the l3 small/unpacked handlers array + void** cntx_l3_sup_handlers = bli_cntx_l3_sup_handlers_buf( cntx ); + + // Now that we have the context address, we want to copy the values + // from the temporary buffers into the corresponding buffers in the + // context. + + // Process each operation id tuple provided. + for ( i = 0; i < n_ops; ++i ) + { + // Read the current operation id and handler function pointer. + const opid_t op_id = op_ids[ i ]; + void* op_fp = op_fps[ i ]; + + // Store the sup handler function pointer into the slot for the + // specified operation id. + cntx_l3_sup_handlers[ op_id ] = op_fp; + } + + // Free the temporary local arrays. + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_handlers(): " ); + #endif + bli_free_intl( op_ids ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_handlers(): " ); + #endif + bli_free_intl( op_fps ); +} + +// ----------------------------------------------------------------------------- + +void bli_cntx_set_l3_sup_blkszs( dim_t n_bs, ... ) +{ + // This function can be called from the bli_cntx_init_*() function for + // a particular architecture if the kernel developer wishes to use + // non-default l3 sup blocksizes. It should be called after + // bli_cntx_init_defaults() so that the context begins with default + // blocksizes across all datatypes. + + /* Example prototypes: + + void bli_cntx_set_blkszs + ( + dim_t n_bs, + bszid_t bs0_id, blksz_t* blksz0, + bszid_t bs1_id, blksz_t* blksz1, + bszid_t bs2_id, blksz_t* blksz2, + ... + cntx_t* cntx + ); + */ + + va_list args; + dim_t i; + err_t r_val; + + // Allocate some temporary local arrays. + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_blkszs(): " ); + #endif + bszid_t* bszids = bli_malloc_intl( n_bs * sizeof( bszid_t ), &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_blkszs(): " ); + #endif + blksz_t** blkszs = bli_malloc_intl( n_bs * sizeof( blksz_t* ), &r_val ); + + // -- Begin variable argument section -- + + // Initialize variable argument environment. + va_start( args, n_bs ); + + // Process n_bs tuples. + for ( i = 0; i < n_bs; ++i ) + { + // Here, we query the variable argument list for: + // - the bszid_t of the blocksize we're about to process, + // - the address of the blksz_t object. + bszid_t bs_id = ( bszid_t )va_arg( args, bszid_t ); + blksz_t* blksz = ( blksz_t* )va_arg( args, blksz_t* ); + + // Store the values in our temporary arrays. + bszids[ i ] = bs_id; + blkszs[ i ] = blksz; + } + + // The last argument should be the context pointer. + cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* ); + + // Shutdown variable argument environment and clean up stack. + va_end( args ); + + // -- End variable argument section -- + + // Query the context for the addresses of: + // - the blocksize object array + blksz_t* cntx_l3_sup_blkszs = bli_cntx_l3_sup_blkszs_buf( cntx ); + + // Now that we have the context address, we want to copy the values + // from the temporary buffers into the corresponding buffers in the + // context. Notice that the blksz_t* pointers were saved, rather than + // the objects themselves, but we copy the contents of the objects + // when copying into the context. + + // Process each blocksize id tuple provided. + for ( i = 0; i < n_bs; ++i ) + { + // Read the current blocksize id, blksz_t* pointer, blocksize + // multiple id, and blocksize scalar. + bszid_t bs_id = bszids[ i ]; + blksz_t* blksz = blkszs[ i ]; + + blksz_t* cntx_l3_sup_blksz = &cntx_l3_sup_blkszs[ bs_id ]; + + // Copy the blksz_t object contents into the appropriate + // location within the context's blksz_t array. + //cntx_l3_sup_blkszs[ bs_id ] = *blksz; + //bli_blksz_copy( blksz, cntx_l3_sup_blksz ); + bli_blksz_copy_if_pos( blksz, cntx_l3_sup_blksz ); + } + + // Free the temporary local arrays. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_blkszs(): " ); + #endif + bli_free_intl( blkszs ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_blkszs(): " ); + #endif + bli_free_intl( bszids ); +} + +// ----------------------------------------------------------------------------- + +void bli_cntx_set_l3_sup_kers( dim_t n_ukrs, ... ) +{ + // This function can be called from the bli_cntx_init_*() function for + // a particular architecture if the kernel developer wishes to use + // non-default level-3 microkernels for small/unpacked matrices. It + // should be called after bli_cntx_init_defaults() so that the context + // begins with default sup micro/millikernels across all datatypes. + + /* Example prototypes: + + void bli_cntx_set_l3_sup_kers + ( + dim_t n_ukrs, + stor3_t stor_id0, num_t dt0, void* ukr0_fp, bool pref0, + stor3_t stor_id1, num_t dt1, void* ukr1_fp, bool pref1, + stor3_t stor_id2, num_t dt2, void* ukr2_fp, bool pref2, + ... + cntx_t* cntx + ); + */ + + va_list args; + dim_t i; + err_t r_val; + + // Allocate some temporary local arrays. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + stor3_t* st3_ids = bli_malloc_intl( n_ukrs * sizeof( stor3_t ), &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + num_t* ukr_dts = bli_malloc_intl( n_ukrs * sizeof( num_t ), &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + void** ukr_fps = bli_malloc_intl( n_ukrs * sizeof( void* ), &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + bool* ukr_prefs = bli_malloc_intl( n_ukrs * sizeof( bool ), &r_val ); + + // -- Begin variable argument section -- + + // Initialize variable argument environment. + va_start( args, n_ukrs ); + + // Process n_ukrs tuples. + for ( i = 0; i < n_ukrs; ++i ) + { + // Here, we query the variable argument list for: + // - the stor3_t storage case being assigned to the kernel we're + // about to process, + // - the datatype of the kernel, + // - the kernel function pointer, and + // - the kernel function storage preference + // that we need to store to the context. + const stor3_t st3_id = ( stor3_t )va_arg( args, stor3_t ); + const num_t ukr_dt = ( num_t )va_arg( args, num_t ); + void* ukr_fp = ( void* )va_arg( args, void* ); + const bool ukr_pref = ( bool )va_arg( args, int ); + + // Store the values in our temporary arrays. + st3_ids[ i ] = st3_id; + ukr_dts[ i ] = ukr_dt; + ukr_fps[ i ] = ukr_fp; + ukr_prefs[ i ] = ukr_pref; + } + + // The last argument should be the context pointer. + cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* ); + + // Shutdown variable argument environment and clean up stack. + va_end( args ); + + // -- End variable argument section -- + + // Query the context for the addresses of: + // - the l3 small/unpacked ukernel func_t array + // - the l3 small/unpacked ukernel preferences array + func_t* cntx_l3_sup_kers = bli_cntx_l3_sup_kers_buf( cntx ); + mbool_t* cntx_l3_sup_kers_prefs = bli_cntx_l3_sup_kers_prefs_buf( cntx ); + + // Now that we have the context address, we want to copy the values + // from the temporary buffers into the corresponding buffers in the + // context. + +#if 0 + dim_t sup_map[ BLIS_NUM_LEVEL3_SUP_UKRS ][2]; + + // Create the small/unpacked ukernel mappings: + // - rv -> rrr 0, rcr 2 + // - rg -> rrc 1, rcc 3 + // - cv -> ccr 6, ccc 7 + // - cg -> crr 4, crc 5 + // - rd -> rrc 1 + // - cd -> crc 5 + // - rc -> rcc 3 + // - cr -> crr 4 + // - gx -> xxx 8 + // NOTE: We only need to set one slot in the context l3_sup_kers array + // for the general-stride/generic ukernel type, but since the loop below + // needs to be set up to set two slots to accommodate the RV, RG, CV, and + // CG, ukernel types, we will just be okay with the GX ukernel being set + // redundantly. (The RD, CD, CR, and RC ukernel types are set redundantly + // for the same reason.) + sup_map[ BLIS_GEMMSUP_RV_UKR ][0] = BLIS_RRR; + sup_map[ BLIS_GEMMSUP_RV_UKR ][1] = BLIS_RCR; + sup_map[ BLIS_GEMMSUP_RG_UKR ][0] = BLIS_RRC; + sup_map[ BLIS_GEMMSUP_RG_UKR ][1] = BLIS_RCC; + sup_map[ BLIS_GEMMSUP_CV_UKR ][0] = BLIS_CCR; + sup_map[ BLIS_GEMMSUP_CV_UKR ][1] = BLIS_CCC; + sup_map[ BLIS_GEMMSUP_CG_UKR ][0] = BLIS_CRR; + sup_map[ BLIS_GEMMSUP_CG_UKR ][1] = BLIS_CRC; + + sup_map[ BLIS_GEMMSUP_RD_UKR ][0] = BLIS_RRC; + sup_map[ BLIS_GEMMSUP_RD_UKR ][1] = BLIS_RRC; + sup_map[ BLIS_GEMMSUP_CD_UKR ][0] = BLIS_CRC; + sup_map[ BLIS_GEMMSUP_CD_UKR ][1] = BLIS_CRC; + + sup_map[ BLIS_GEMMSUP_RC_UKR ][0] = BLIS_RCC; + sup_map[ BLIS_GEMMSUP_RC_UKR ][1] = BLIS_RCC; + sup_map[ BLIS_GEMMSUP_CR_UKR ][0] = BLIS_CRR; + sup_map[ BLIS_GEMMSUP_CR_UKR ][1] = BLIS_CRR; + + sup_map[ BLIS_GEMMSUP_GX_UKR ][0] = BLIS_XXX; + sup_map[ BLIS_GEMMSUP_GX_UKR ][1] = BLIS_XXX; +#endif + + // Process each blocksize id tuple provided. + for ( i = 0; i < n_ukrs; ++i ) + { + // Read the current stor3_t id, ukernel datatype, ukernel function + // pointer, and ukernel preference. + const stor3_t st3_id = st3_ids[ i ]; + const num_t ukr_dt = ukr_dts[ i ]; + void* ukr_fp = ukr_fps[ i ]; + const bool ukr_pref = ukr_prefs[ i ]; + + // Index to the func_t and mbool_t for the current stor3_t id + // being processed. + func_t* ukrs = &cntx_l3_sup_kers[ st3_id ]; + mbool_t* prefs = &cntx_l3_sup_kers_prefs[ st3_id ]; + + // Store the ukernel function pointer and preference values into + // the stor3_t location in the context. + bli_func_set_dt( ukr_fp, ukr_dt, ukrs ); + bli_mbool_set_dt( ukr_pref, ukr_dt, prefs ); + } + + // Free the temporary local arrays. + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + bli_free_intl( st3_ids ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + bli_free_intl( ukr_dts ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + bli_free_intl( ukr_fps ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + bli_free_intl( ukr_prefs ); +} + +// ----------------------------------------------------------------------------- + void bli_cntx_set_l1f_kers( dim_t n_kers, ... ) { // This function can be called from the bli_cntx_init_*() function for // a particular architecture if the kernel developer wishes to use // non-default level-1f kernels. It should be called after - // bli_cntx_init_defaults() so that default functions are still called - // for any datatypes / register blocksizes that were not targed for - // optimization. + // bli_cntx_init_defaults() so that the context begins with default l1f + // kernels across all datatypes. /* Example prototypes: void bli_cntx_set_l1f_kers ( dim_t n_ukrs, - l1fkr_t ker0_id, num_t ker0_dt, void* ker0_fp, - l1fkr_t ker1_id, num_t ker1_dt, void* ker1_fp, - l1fkr_t ker2_id, num_t ker2_dt, void* ker2_fp, + l1fkr_t ker0_id, num_t ker0_dt, void_fp ker0_fp, + l1fkr_t ker1_id, num_t ker1_dt, void_fp ker1_fp, + l1fkr_t ker2_id, num_t ker2_dt, void_fp ker2_fp, ... cntx_t* cntx ); */ + va_list args; dim_t i; + err_t r_val; // Allocate some temporary local arrays. #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l1f_kers(): " ); #endif - l1fkr_t* ker_ids = bli_malloc_intl( n_kers * sizeof( l1fkr_t ) ); + l1fkr_t* ker_ids = bli_malloc_intl( n_kers * sizeof( l1fkr_t ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l1f_kers(): " ); #endif - num_t* ker_dts = bli_malloc_intl( n_kers * sizeof( num_t ) ); + num_t* ker_dts = bli_malloc_intl( n_kers * sizeof( num_t ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l1f_kers(): " ); #endif - void** ker_fps = bli_malloc_intl( n_kers * sizeof( void* ) ); + void_fp* ker_fps = bli_malloc_intl( n_kers * sizeof( void_fp ), &r_val ); // -- Begin variable argument section -- @@ -728,7 +1264,7 @@ void bli_cntx_set_l1f_kers( dim_t n_kers, ... ) // that we need to store to the context. const l1fkr_t ker_id = ( l1fkr_t )va_arg( args, l1fkr_t ); const num_t ker_dt = ( num_t )va_arg( args, num_t ); - void* ker_fp = ( void* )va_arg( args, void* ); + void_fp ker_fp = ( void_fp )va_arg( args, void_fp ); // Store the values in our temporary arrays. ker_ids[ i ] = ker_id; @@ -755,11 +1291,11 @@ void bli_cntx_set_l1f_kers( dim_t n_kers, ... ) // Process each blocksize id tuple provided. for ( i = 0; i < n_kers; ++i ) { - // Read the current blocksize id, blksz_t* pointer, blocksize - // multiple id, and blocksize scalar. + // Read the current kernel id, kernel datatype, and kernel function + // pointer. const l1fkr_t ker_id = ker_ids[ i ]; const num_t ker_dt = ker_dts[ i ]; - void* ker_fp = ker_fps[ i ]; + void_fp ker_fp = ker_fps[ i ]; // Index into the func_t and mbool_t for the current kernel id // being processed. @@ -795,41 +1331,42 @@ void bli_cntx_set_l1v_kers( dim_t n_kers, ... ) // This function can be called from the bli_cntx_init_*() function for // a particular architecture if the kernel developer wishes to use // non-default level-1v kernels. It should be called after - // bli_cntx_init_defaults() so that default functions are still called - // for any datatypes / register blocksizes that were not targed for - // optimization. + // bli_cntx_init_defaults() so that the context begins with default l1v + // kernels across all datatypes. /* Example prototypes: void bli_cntx_set_l1v_kers ( dim_t n_ukrs, - l1vkr_t ker0_id, num_t ker0_dt, void* ker0_fp, - l1vkr_t ker1_id, num_t ker1_dt, void* ker1_fp, - l1vkr_t ker2_id, num_t ker2_dt, void* ker2_fp, + l1vkr_t ker0_id, num_t ker0_dt, void_fp ker0_fp, + l1vkr_t ker1_id, num_t ker1_dt, void_fp ker1_fp, + l1vkr_t ker2_id, num_t ker2_dt, void_fp ker2_fp, ... cntx_t* cntx ); */ + va_list args; dim_t i; + err_t r_val; // Allocate some temporary local arrays. #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l1v_kers(): " ); #endif - l1vkr_t* ker_ids = bli_malloc_intl( n_kers * sizeof( l1vkr_t ) ); + l1vkr_t* ker_ids = bli_malloc_intl( n_kers * sizeof( l1vkr_t ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l1v_kers(): " ); #endif - num_t* ker_dts = bli_malloc_intl( n_kers * sizeof( num_t ) ); + num_t* ker_dts = bli_malloc_intl( n_kers * sizeof( num_t ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l1v_kers(): " ); #endif - void** ker_fps = bli_malloc_intl( n_kers * sizeof( void* ) ); + void_fp* ker_fps = bli_malloc_intl( n_kers * sizeof( void_fp ), &r_val ); // -- Begin variable argument section -- @@ -846,7 +1383,7 @@ void bli_cntx_set_l1v_kers( dim_t n_kers, ... ) // that we need to store to the context. const l1vkr_t ker_id = ( l1vkr_t )va_arg( args, l1vkr_t ); const num_t ker_dt = ( num_t )va_arg( args, num_t ); - void* ker_fp = ( void* )va_arg( args, void* ); + void_fp ker_fp = ( void_fp )va_arg( args, void_fp ); // Store the values in our temporary arrays. ker_ids[ i ] = ker_id; @@ -873,11 +1410,11 @@ void bli_cntx_set_l1v_kers( dim_t n_kers, ... ) // Process each blocksize id tuple provided. for ( i = 0; i < n_kers; ++i ) { - // Read the current blocksize id, blksz_t* pointer, blocksize - // multiple id, and blocksize scalar. + // Read the current kernel id, kernel datatype, and kernel function + // pointer. const l1vkr_t ker_id = ker_ids[ i ]; const num_t ker_dt = ker_dts[ i ]; - void* ker_fp = ker_fps[ i ]; + void_fp ker_fp = ker_fps[ i ]; // Index into the func_t and mbool_t for the current kernel id // being processed. @@ -913,41 +1450,42 @@ void bli_cntx_set_packm_kers( dim_t n_kers, ... ) // This function can be called from the bli_cntx_init_*() function for // a particular architecture if the kernel developer wishes to use // non-default packing kernels. It should be called after - // bli_cntx_init_defaults() so that default functions are still called - // for any datatypes / register blocksizes that were not targed for - // optimization. + // bli_cntx_init_defaults() so that the context begins with default packm + // kernels across all datatypes. /* Example prototypes: void bli_cntx_set_packm_kers ( dim_t n_ukrs, - l1mkr_t ker0_id, num_t ker0_dt, void* ker0_fp, - l1mkr_t ker1_id, num_t ker1_dt, void* ker1_fp, - l1mkr_t ker2_id, num_t ker2_dt, void* ker2_fp, + l1mkr_t ker0_id, num_t ker0_dt, void_fp ker0_fp, + l1mkr_t ker1_id, num_t ker1_dt, void_fp ker1_fp, + l1mkr_t ker2_id, num_t ker2_dt, void_fp ker2_fp, ... cntx_t* cntx ); */ + va_list args; dim_t i; + err_t r_val; // Allocate some temporary local arrays. #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_packm_kers(): " ); #endif - l1mkr_t* ker_ids = bli_malloc_intl( n_kers * sizeof( l1mkr_t ) ); + l1mkr_t* ker_ids = bli_malloc_intl( n_kers * sizeof( l1mkr_t ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_packm_kers(): " ); #endif - num_t* ker_dts = bli_malloc_intl( n_kers * sizeof( num_t ) ); + num_t* ker_dts = bli_malloc_intl( n_kers * sizeof( num_t ), &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_packm_kers(): " ); #endif - void** ker_fps = bli_malloc_intl( n_kers * sizeof( void* ) ); + void_fp* ker_fps = bli_malloc_intl( n_kers * sizeof( void_fp ), &r_val ); // -- Begin variable argument section -- @@ -964,7 +1502,7 @@ void bli_cntx_set_packm_kers( dim_t n_kers, ... ) // that we need to store to the context. const l1mkr_t ker_id = ( l1mkr_t )va_arg( args, l1mkr_t ); const num_t ker_dt = ( num_t )va_arg( args, num_t ); - void* ker_fp = ( void* )va_arg( args, void* ); + void_fp ker_fp = ( void_fp )va_arg( args, void_fp ); // Store the values in our temporary arrays. ker_ids[ i ] = ker_id; @@ -991,11 +1529,11 @@ void bli_cntx_set_packm_kers( dim_t n_kers, ... ) // Process each blocksize id tuple provided. for ( i = 0; i < n_kers; ++i ) { - // Read the current blocksize id, blksz_t* pointer, blocksize - // multiple id, and blocksize scalar. + // Read the current kernel id, kernel datatype, and kernel function + // pointer. const l1mkr_t ker_id = ker_ids[ i ]; const num_t ker_dt = ker_dts[ i ]; - void* ker_fp = ker_fps[ i ]; + void_fp ker_fp = ker_fps[ i ]; // Index into the func_t and mbool_t for the current kernel id // being processed. @@ -1061,11 +1599,11 @@ void bli_cntx_print( cntx_t* cntx ) ); } - for ( i = 0; i < BLIS_NUM_LEVEL3_UKRS; ++i ) + for ( i = 0; i < BLIS_NUM_3OP_RC_COMBOS; ++i ) { - func_t* ukr = bli_cntx_get_l3_nat_ukrs( i, cntx ); + func_t* ukr = bli_cntx_get_l3_sup_kers( i, cntx ); - printf( "l3 nat ukr %2lu: %16p %16p %16p %16p\n", + printf( "l3 sup ukr %2lu: %16p %16p %16p %16p\n", ( unsigned long )i, bli_func_get_dt( BLIS_FLOAT, ukr ), bli_func_get_dt( BLIS_DOUBLE, ukr ), diff --git a/frame/base/bli_cntx.h b/frame/base/bli_cntx.h index e87794e905..76350f6bcf 100644 --- a/frame/base/bli_cntx.h +++ b/frame/base/bli_cntx.h @@ -6,6 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -49,6 +50,12 @@ typedef struct cntx_s func_t* l3_nat_ukrs; mbool_t* l3_nat_ukrs_prefs; + blksz_t* l3_sup_thresh; + void** l3_sup_handlers; + blksz_t* l3_sup_blkszs; + func_t* l3_sup_kers; + mbool_t* l3_sup_kers_prefs; + func_t* l1f_kers; func_t* l1v_kers; @@ -56,9 +63,6 @@ typedef struct cntx_s func_t* unpackm_kers; ind_t method; - pack_t schema_a; - pack_t schema_b; - pack_t schema_c; } cntx_t; */ @@ -69,57 +73,65 @@ typedef struct cntx_s // -- cntx_t query (fields only) ----------------------------------------------- // -static blksz_t* bli_cntx_blkszs_buf( cntx_t* cntx ) +BLIS_INLINE blksz_t* bli_cntx_blkszs_buf( cntx_t* cntx ) { return cntx->blkszs; } -static bszid_t* bli_cntx_bmults_buf( cntx_t* cntx ) +BLIS_INLINE bszid_t* bli_cntx_bmults_buf( cntx_t* cntx ) { return cntx->bmults; } -static func_t* bli_cntx_l3_vir_ukrs_buf( cntx_t* cntx ) +BLIS_INLINE func_t* bli_cntx_l3_vir_ukrs_buf( cntx_t* cntx ) { return cntx->l3_vir_ukrs; } -static func_t* bli_cntx_l3_nat_ukrs_buf( cntx_t* cntx ) +BLIS_INLINE func_t* bli_cntx_l3_nat_ukrs_buf( cntx_t* cntx ) { return cntx->l3_nat_ukrs; } -static mbool_t* bli_cntx_l3_nat_ukrs_prefs_buf( cntx_t* cntx ) +BLIS_INLINE mbool_t* bli_cntx_l3_nat_ukrs_prefs_buf( cntx_t* cntx ) { return cntx->l3_nat_ukrs_prefs; } -static func_t* bli_cntx_l1f_kers_buf( cntx_t* cntx ) +BLIS_INLINE blksz_t* bli_cntx_l3_sup_thresh_buf( cntx_t* cntx ) { - return cntx->l1f_kers; + return cntx->l3_sup_thresh; } -static func_t* bli_cntx_l1v_kers_buf( cntx_t* cntx ) +BLIS_INLINE void** bli_cntx_l3_sup_handlers_buf( cntx_t* cntx ) { - return cntx->l1v_kers; + return cntx->l3_sup_handlers; } -static func_t* bli_cntx_packm_kers_buf( cntx_t* cntx ) +BLIS_INLINE blksz_t* bli_cntx_l3_sup_blkszs_buf( cntx_t* cntx ) { - return cntx->packm_kers; + return cntx->l3_sup_blkszs; } -static func_t* bli_cntx_unpackm_kers_buf( cntx_t* cntx ) +BLIS_INLINE func_t* bli_cntx_l3_sup_kers_buf( cntx_t* cntx ) { - return cntx->unpackm_kers; + return cntx->l3_sup_kers; } -static ind_t bli_cntx_method( cntx_t* cntx ) +BLIS_INLINE mbool_t* bli_cntx_l3_sup_kers_prefs_buf( cntx_t* cntx ) { - return cntx->method; + return cntx->l3_sup_kers_prefs; } -static pack_t bli_cntx_schema_a_block( cntx_t* cntx ) +BLIS_INLINE func_t* bli_cntx_l1f_kers_buf( cntx_t* cntx ) { - return cntx->schema_a_block; + return cntx->l1f_kers; } -static pack_t bli_cntx_schema_b_panel( cntx_t* cntx ) +BLIS_INLINE func_t* bli_cntx_l1v_kers_buf( cntx_t* cntx ) { - return cntx->schema_b_panel; + return cntx->l1v_kers; } -static pack_t bli_cntx_schema_c_panel( cntx_t* cntx ) +BLIS_INLINE func_t* bli_cntx_packm_kers_buf( cntx_t* cntx ) { - return cntx->schema_c_panel; + return cntx->packm_kers; +} +BLIS_INLINE func_t* bli_cntx_unpackm_kers_buf( cntx_t* cntx ) +{ + return cntx->unpackm_kers; +} +BLIS_INLINE ind_t bli_cntx_method( cntx_t* cntx ) +{ + return cntx->method; } // ----------------------------------------------------------------------------- @@ -128,27 +140,10 @@ static pack_t bli_cntx_schema_c_panel( cntx_t* cntx ) // -- cntx_t modification (fields only) ---------------------------------------- // -static void bli_cntx_set_method( ind_t method, cntx_t* cntx ) +BLIS_INLINE void bli_cntx_set_method( ind_t method, cntx_t* cntx ) { cntx->method = method; } -static void bli_cntx_set_schema_a_block( pack_t schema, cntx_t* cntx ) -{ - cntx->schema_a_block = schema; -} -static void bli_cntx_set_schema_b_panel( pack_t schema, cntx_t* cntx ) -{ - cntx->schema_b_panel = schema; -} -static void bli_cntx_set_schema_c_panel( pack_t schema, cntx_t* cntx ) -{ - cntx->schema_c_panel = schema; -} -static void bli_cntx_set_schema_ab_blockpanel( pack_t sa, pack_t sb, cntx_t* cntx ) -{ - bli_cntx_set_schema_a_block( sa, cntx ); - bli_cntx_set_schema_b_panel( sb, cntx ); -} // ----------------------------------------------------------------------------- @@ -156,7 +151,7 @@ static void bli_cntx_set_schema_ab_blockpanel( pack_t sa, pack_t sb, cntx_t* cnt // -- cntx_t query (complex) --------------------------------------------------- // -static blksz_t* bli_cntx_get_blksz( bszid_t bs_id, cntx_t* cntx ) +BLIS_INLINE blksz_t* bli_cntx_get_blksz( bszid_t bs_id, cntx_t* cntx ) { blksz_t* blkszs = bli_cntx_blkszs_buf( cntx ); blksz_t* blksz = &blkszs[ bs_id ]; @@ -165,7 +160,7 @@ static blksz_t* bli_cntx_get_blksz( bszid_t bs_id, cntx_t* cntx ) return blksz; } -static dim_t bli_cntx_get_blksz_def_dt( num_t dt, bszid_t bs_id, cntx_t* cntx ) +BLIS_INLINE dim_t bli_cntx_get_blksz_def_dt( num_t dt, bszid_t bs_id, cntx_t* cntx ) { blksz_t* blksz = bli_cntx_get_blksz( bs_id, cntx ); dim_t bs_dt = bli_blksz_get_def( dt, blksz ); @@ -174,7 +169,7 @@ static dim_t bli_cntx_get_blksz_def_dt( num_t dt, bszid_t bs_id, cntx_t* cntx ) return bs_dt; } -static dim_t bli_cntx_get_blksz_max_dt( num_t dt, bszid_t bs_id, cntx_t* cntx ) +BLIS_INLINE dim_t bli_cntx_get_blksz_max_dt( num_t dt, bszid_t bs_id, cntx_t* cntx ) { blksz_t* blksz = bli_cntx_get_blksz( bs_id, cntx ); dim_t bs_dt = bli_blksz_get_max( dt, blksz ); @@ -183,7 +178,7 @@ static dim_t bli_cntx_get_blksz_max_dt( num_t dt, bszid_t bs_id, cntx_t* cntx ) return bs_dt; } -static bszid_t bli_cntx_get_bmult_id( bszid_t bs_id, cntx_t* cntx ) +BLIS_INLINE bszid_t bli_cntx_get_bmult_id( bszid_t bs_id, cntx_t* cntx ) { bszid_t* restrict bmults = bli_cntx_bmults_buf( cntx ); bszid_t bm_id = bmults[ bs_id ]; @@ -191,7 +186,7 @@ static bszid_t bli_cntx_get_bmult_id( bszid_t bs_id, cntx_t* cntx ) return bm_id; } -static blksz_t* bli_cntx_get_bmult( bszid_t bs_id, cntx_t* cntx ) +BLIS_INLINE blksz_t* bli_cntx_get_bmult( bszid_t bs_id, cntx_t* cntx ) { bszid_t bm_id = bli_cntx_get_bmult_id( bs_id, cntx ); blksz_t* restrict bmult = bli_cntx_get_blksz( bm_id, cntx ); @@ -199,7 +194,7 @@ static blksz_t* bli_cntx_get_bmult( bszid_t bs_id, cntx_t* cntx ) return bmult; } -static dim_t bli_cntx_get_bmult_dt( num_t dt, bszid_t bs_id, cntx_t* cntx ) +BLIS_INLINE dim_t bli_cntx_get_bmult_dt( num_t dt, bszid_t bs_id, cntx_t* cntx ) { blksz_t* bmult = bli_cntx_get_bmult( bs_id, cntx ); dim_t bm_dt = bli_blksz_get_def( dt, bmult ); @@ -209,7 +204,7 @@ static dim_t bli_cntx_get_bmult_dt( num_t dt, bszid_t bs_id, cntx_t* cntx ) // ----------------------------------------------------------------------------- -static func_t* bli_cntx_get_l3_vir_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE func_t* bli_cntx_get_l3_vir_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) { func_t* funcs = bli_cntx_l3_vir_ukrs_buf( cntx ); func_t* func = &funcs[ ukr_id ]; @@ -217,14 +212,14 @@ static func_t* bli_cntx_get_l3_vir_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) return func; } -static void* bli_cntx_get_l3_vir_ukr_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE void_fp bli_cntx_get_l3_vir_ukr_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) { func_t* func = bli_cntx_get_l3_vir_ukrs( ukr_id, cntx ); return bli_func_get_dt( dt, func ); } -static func_t* bli_cntx_get_l3_nat_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE func_t* bli_cntx_get_l3_nat_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) { func_t* funcs = bli_cntx_l3_nat_ukrs_buf( cntx ); func_t* func = &funcs[ ukr_id ]; @@ -232,7 +227,7 @@ static func_t* bli_cntx_get_l3_nat_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) return func; } -static void* bli_cntx_get_l3_nat_ukr_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE void_fp bli_cntx_get_l3_nat_ukr_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) { func_t* func = bli_cntx_get_l3_nat_ukrs( ukr_id, cntx ); @@ -241,7 +236,7 @@ static void* bli_cntx_get_l3_nat_ukr_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx // ----------------------------------------------------------------------------- -static mbool_t* bli_cntx_get_l3_nat_ukr_prefs( l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE mbool_t* bli_cntx_get_l3_nat_ukr_prefs( l3ukr_t ukr_id, cntx_t* cntx ) { mbool_t* mbools = bli_cntx_l3_nat_ukrs_prefs_buf( cntx ); mbool_t* mbool = &mbools[ ukr_id ]; @@ -249,16 +244,118 @@ static mbool_t* bli_cntx_get_l3_nat_ukr_prefs( l3ukr_t ukr_id, cntx_t* cntx ) return mbool; } -static bool_t bli_cntx_get_l3_nat_ukr_prefs_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE bool bli_cntx_get_l3_nat_ukr_prefs_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) { mbool_t* mbool = bli_cntx_get_l3_nat_ukr_prefs( ukr_id, cntx ); - return bli_mbool_get_dt( dt, mbool ); + return ( bool )bli_mbool_get_dt( dt, mbool ); +} + +// ----------------------------------------------------------------------------- + +BLIS_INLINE blksz_t* bli_cntx_get_l3_sup_thresh( threshid_t thresh_id, cntx_t* cntx ) +{ + blksz_t* threshs = bli_cntx_l3_sup_thresh_buf( cntx ); + blksz_t* thresh = &threshs[ thresh_id ]; + + // Return the address of the blksz_t identified by thresh_id. + return thresh; +} + +BLIS_INLINE dim_t bli_cntx_get_l3_sup_thresh_dt( num_t dt, threshid_t thresh_id, cntx_t* cntx ) +{ + blksz_t* threshs = bli_cntx_get_l3_sup_thresh( thresh_id, cntx ); + dim_t thresh_dt = bli_blksz_get_def( dt, threshs ); + + // Return the main (default) threshold value for the datatype given. + return thresh_dt; +} + +BLIS_INLINE bool bli_cntx_l3_sup_thresh_is_met( num_t dt, dim_t m, dim_t n, dim_t k, cntx_t* cntx ) +{ + if ( m < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ) ) return TRUE; + if ( n < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ) ) return TRUE; + if ( k < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ) ) return TRUE; + + return FALSE; +} + +// ----------------------------------------------------------------------------- + +BLIS_INLINE void* bli_cntx_get_l3_sup_handler( opid_t op, cntx_t* cntx ) +{ + void** funcs = bli_cntx_l3_sup_handlers_buf( cntx ); + void* func = funcs[ op ]; + + return func; +} + +// ----------------------------------------------------------------------------- + +BLIS_INLINE blksz_t* bli_cntx_get_l3_sup_blksz( bszid_t bs_id, cntx_t* cntx ) +{ + blksz_t* blkszs = bli_cntx_l3_sup_blkszs_buf( cntx ); + blksz_t* blksz = &blkszs[ bs_id ]; + + // Return the address of the blksz_t identified by bs_id. + return blksz; +} + +BLIS_INLINE dim_t bli_cntx_get_l3_sup_blksz_def_dt( num_t dt, bszid_t bs_id, cntx_t* cntx ) +{ + blksz_t* blksz = bli_cntx_get_l3_sup_blksz( bs_id, cntx ); + dim_t bs_dt = bli_blksz_get_def( dt, blksz ); + + // Return the main (default) blocksize value for the datatype given. + return bs_dt; +} + +BLIS_INLINE dim_t bli_cntx_get_l3_sup_blksz_max_dt( num_t dt, bszid_t bs_id, cntx_t* cntx ) +{ + blksz_t* blksz = bli_cntx_get_l3_sup_blksz( bs_id, cntx ); + dim_t bs_dt = bli_blksz_get_max( dt, blksz ); + + // Return the auxiliary (maximum) blocksize value for the datatype given. + return bs_dt; +} + +// ----------------------------------------------------------------------------- + +BLIS_INLINE func_t* bli_cntx_get_l3_sup_kers( stor3_t stor_id, cntx_t* cntx ) +{ + func_t* funcs = bli_cntx_l3_sup_kers_buf( cntx ); + func_t* func = &funcs[ stor_id ]; + + return func; +} + +BLIS_INLINE void* bli_cntx_get_l3_sup_ker_dt( num_t dt, stor3_t stor_id, cntx_t* cntx ) +{ + func_t* func = bli_cntx_get_l3_sup_kers( stor_id, cntx ); + + return bli_func_get_dt( dt, func ); } // ----------------------------------------------------------------------------- -static func_t* bli_cntx_get_l1f_kers( l1fkr_t ker_id, cntx_t* cntx ) +BLIS_INLINE mbool_t* bli_cntx_get_l3_sup_ker_prefs( stor3_t stor_id, cntx_t* cntx ) +{ + mbool_t* mbools = bli_cntx_l3_sup_kers_prefs_buf( cntx ); + mbool_t* mbool = &mbools[ stor_id ]; + + return mbool; +} + +BLIS_INLINE bool bli_cntx_get_l3_sup_ker_prefs_dt( num_t dt, stor3_t stor_id, cntx_t* cntx ) +{ + mbool_t* mbool = bli_cntx_get_l3_sup_ker_prefs( stor_id, cntx ); + + return ( bool )bli_mbool_get_dt( dt, mbool ); +} + +// ----------------------------------------------------------------------------- + +BLIS_INLINE func_t* bli_cntx_get_l1f_kers( l1fkr_t ker_id, cntx_t* cntx ) { func_t* funcs = bli_cntx_l1f_kers_buf( cntx ); func_t* func = &funcs[ ker_id ]; @@ -266,7 +363,7 @@ static func_t* bli_cntx_get_l1f_kers( l1fkr_t ker_id, cntx_t* cntx ) return func; } -static void* bli_cntx_get_l1f_ker_dt( num_t dt, l1fkr_t ker_id, cntx_t* cntx ) +BLIS_INLINE void_fp bli_cntx_get_l1f_ker_dt( num_t dt, l1fkr_t ker_id, cntx_t* cntx ) { func_t* func = bli_cntx_get_l1f_kers( ker_id, cntx ); @@ -275,7 +372,7 @@ static void* bli_cntx_get_l1f_ker_dt( num_t dt, l1fkr_t ker_id, cntx_t* cntx ) // ----------------------------------------------------------------------------- -static func_t* bli_cntx_get_l1v_kers( l1vkr_t ker_id, cntx_t* cntx ) +BLIS_INLINE func_t* bli_cntx_get_l1v_kers( l1vkr_t ker_id, cntx_t* cntx ) { func_t* funcs = bli_cntx_l1v_kers_buf( cntx ); func_t* func = &funcs[ ker_id ]; @@ -283,7 +380,7 @@ static func_t* bli_cntx_get_l1v_kers( l1vkr_t ker_id, cntx_t* cntx ) return func; } -static void* bli_cntx_get_l1v_ker_dt( num_t dt, l1vkr_t ker_id, cntx_t* cntx ) +BLIS_INLINE void_fp bli_cntx_get_l1v_ker_dt( num_t dt, l1vkr_t ker_id, cntx_t* cntx ) { func_t* func = bli_cntx_get_l1v_kers( ker_id, cntx ); @@ -292,7 +389,7 @@ static void* bli_cntx_get_l1v_ker_dt( num_t dt, l1vkr_t ker_id, cntx_t* cntx ) // ----------------------------------------------------------------------------- -static func_t* bli_cntx_get_packm_kers( l1mkr_t ker_id, cntx_t* cntx ) +BLIS_INLINE func_t* bli_cntx_get_packm_kers( l1mkr_t ker_id, cntx_t* cntx ) { func_t* func = NULL; @@ -309,9 +406,9 @@ static func_t* bli_cntx_get_packm_kers( l1mkr_t ker_id, cntx_t* cntx ) return func; } -static void* bli_cntx_get_packm_ker_dt( num_t dt, l1mkr_t ker_id, cntx_t* cntx ) +BLIS_INLINE void_fp bli_cntx_get_packm_ker_dt( num_t dt, l1mkr_t ker_id, cntx_t* cntx ) { - void* fp = NULL; + void_fp fp = NULL; // Only query the context for the packm func_t (and then extract the // datatype-specific function pointer) if the packm kernel being @@ -327,7 +424,7 @@ static void* bli_cntx_get_packm_ker_dt( num_t dt, l1mkr_t ker_id, cntx_t* cntx ) return fp; } -static func_t* bli_cntx_get_unpackm_kers( l1mkr_t ker_id, cntx_t* cntx ) +BLIS_INLINE func_t* bli_cntx_get_unpackm_kers( l1mkr_t ker_id, cntx_t* cntx ) { func_t* func = NULL; @@ -344,9 +441,9 @@ static func_t* bli_cntx_get_unpackm_kers( l1mkr_t ker_id, cntx_t* cntx ) return func; } -static void* bli_cntx_get_unpackm_ker_dt( num_t dt, l1mkr_t ker_id, cntx_t* cntx ) +BLIS_INLINE void_fp bli_cntx_get_unpackm_ker_dt( num_t dt, l1mkr_t ker_id, cntx_t* cntx ) { - void* fp = NULL; + void_fp fp = NULL; // Only query the context for the unpackm func_t (and then extract the // datatype-specific function pointer) if the unpackm kernel being @@ -364,34 +461,34 @@ static void* bli_cntx_get_unpackm_ker_dt( num_t dt, l1mkr_t ker_id, cntx_t* cntx // ----------------------------------------------------------------------------- -static bool_t bli_cntx_l3_nat_ukr_prefers_rows_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE bool bli_cntx_l3_nat_ukr_prefers_rows_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) { - bool_t prefs = bli_cntx_get_l3_nat_ukr_prefs_dt( dt, ukr_id, cntx ); + const bool prefs = bli_cntx_get_l3_nat_ukr_prefs_dt( dt, ukr_id, cntx ); // A ukernel preference of TRUE means the ukernel prefers row storage. - return ( bool_t ) + return ( bool ) ( prefs == TRUE ); } -static bool_t bli_cntx_l3_nat_ukr_prefers_cols_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE bool bli_cntx_l3_nat_ukr_prefers_cols_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) { - bool_t prefs = bli_cntx_get_l3_nat_ukr_prefs_dt( dt, ukr_id, cntx ); + const bool prefs = bli_cntx_get_l3_nat_ukr_prefs_dt( dt, ukr_id, cntx ); // A ukernel preference of FALSE means the ukernel prefers column storage. - return ( bool_t ) + return ( bool ) ( prefs == FALSE ); } -static bool_t bli_cntx_l3_nat_ukr_prefers_storage_of( obj_t* obj, l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE bool bli_cntx_l3_nat_ukr_prefers_storage_of( obj_t* obj, l3ukr_t ukr_id, cntx_t* cntx ) { // Note that we use the computation datatype, which may differ from the // storage datatype of C (when performing a mixed datatype operation). - const num_t dt = bli_obj_comp_dt( obj ); - const bool_t ukr_prefers_rows - = bli_cntx_l3_nat_ukr_prefers_rows_dt( dt, ukr_id, cntx ); - const bool_t ukr_prefers_cols - = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, ukr_id, cntx ); - bool_t r_val = FALSE; + const num_t dt = bli_obj_comp_dt( obj ); + const bool ukr_prefers_rows + = bli_cntx_l3_nat_ukr_prefers_rows_dt( dt, ukr_id, cntx ); + const bool ukr_prefers_cols + = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, ukr_id, cntx ); + bool r_val = FALSE; if ( bli_obj_is_row_stored( obj ) && ukr_prefers_rows ) r_val = TRUE; else if ( bli_obj_is_col_stored( obj ) && ukr_prefers_cols ) r_val = TRUE; @@ -399,15 +496,15 @@ static bool_t bli_cntx_l3_nat_ukr_prefers_storage_of( obj_t* obj, l3ukr_t ukr_id return r_val; } -static bool_t bli_cntx_l3_nat_ukr_dislikes_storage_of( obj_t* obj, l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE bool bli_cntx_l3_nat_ukr_dislikes_storage_of( obj_t* obj, l3ukr_t ukr_id, cntx_t* cntx ) { - return ( bool_t ) + return ( bool ) !bli_cntx_l3_nat_ukr_prefers_storage_of( obj, ukr_id, cntx ); } // ----------------------------------------------------------------------------- -static bool_t bli_cntx_l3_vir_ukr_prefers_rows_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE bool bli_cntx_l3_vir_ukr_prefers_rows_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) { // For induced methods, return the ukernel storage preferences of the // corresponding real micro-kernel. @@ -420,7 +517,7 @@ static bool_t bli_cntx_l3_vir_ukr_prefers_rows_dt( num_t dt, l3ukr_t ukr_id, cnt return bli_cntx_l3_nat_ukr_prefers_rows_dt( dt, ukr_id, cntx ); } -static bool_t bli_cntx_l3_vir_ukr_prefers_cols_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE bool bli_cntx_l3_vir_ukr_prefers_cols_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) { // For induced methods, return the ukernel storage preferences of the // corresponding real micro-kernel. @@ -433,16 +530,16 @@ static bool_t bli_cntx_l3_vir_ukr_prefers_cols_dt( num_t dt, l3ukr_t ukr_id, cnt return bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, ukr_id, cntx ); } -static bool_t bli_cntx_l3_vir_ukr_prefers_storage_of( obj_t* obj, l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE bool bli_cntx_l3_vir_ukr_prefers_storage_of( obj_t* obj, l3ukr_t ukr_id, cntx_t* cntx ) { // Note that we use the computation datatype, which may differ from the // storage datatype of C (when performing a mixed datatype operation). - const num_t dt = bli_obj_comp_dt( obj ); - const bool_t ukr_prefers_rows - = bli_cntx_l3_vir_ukr_prefers_rows_dt( dt, ukr_id, cntx ); - const bool_t ukr_prefers_cols - = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, ukr_id, cntx ); - bool_t r_val = FALSE; + const num_t dt = bli_obj_comp_dt( obj ); + const bool ukr_prefers_rows + = bli_cntx_l3_vir_ukr_prefers_rows_dt( dt, ukr_id, cntx ); + const bool ukr_prefers_cols + = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, ukr_id, cntx ); + bool r_val = FALSE; if ( bli_obj_is_row_stored( obj ) && ukr_prefers_rows ) r_val = TRUE; else if ( bli_obj_is_col_stored( obj ) && ukr_prefers_cols ) r_val = TRUE; @@ -450,19 +547,67 @@ static bool_t bli_cntx_l3_vir_ukr_prefers_storage_of( obj_t* obj, l3ukr_t ukr_id return r_val; } -static bool_t bli_cntx_l3_vir_ukr_dislikes_storage_of( obj_t* obj, l3ukr_t ukr_id, cntx_t* cntx ) +BLIS_INLINE bool bli_cntx_l3_vir_ukr_dislikes_storage_of( obj_t* obj, l3ukr_t ukr_id, cntx_t* cntx ) { - return ( bool_t ) + return ( bool ) !bli_cntx_l3_vir_ukr_prefers_storage_of( obj, ukr_id, cntx ); } // ----------------------------------------------------------------------------- +BLIS_INLINE bool bli_cntx_l3_sup_ker_prefers_rows_dt( num_t dt, stor3_t stor_id, cntx_t* cntx ) +{ + const bool prefs = bli_cntx_get_l3_sup_ker_prefs_dt( dt, stor_id, cntx ); + + // A ukernel preference of TRUE means the ukernel prefers row storage. + return ( bool ) + ( prefs == TRUE ); +} + +BLIS_INLINE bool bli_cntx_l3_sup_ker_prefers_cols_dt( num_t dt, stor3_t stor_id, cntx_t* cntx ) +{ + const bool prefs = bli_cntx_get_l3_sup_ker_prefs_dt( dt, stor_id, cntx ); + + // A ukernel preference of FALSE means the ukernel prefers column storage. + return ( bool ) + ( prefs == FALSE ); +} + +#if 0 +// NOTE: These static functions aren't needed yet. + +BLIS_INLINE bool bli_cntx_l3_sup_ker_prefers_storage_of( obj_t* obj, stor3_t stor_id, cntx_t* cntx ) +{ + const num_t dt = bli_obj_dt( obj ); + const bool ukr_prefers_rows + = bli_cntx_l3_sup_ker_prefers_rows_dt( dt, stor_id, cntx ); + const bool ukr_prefers_cols + = bli_cntx_l3_sup_ker_prefers_cols_dt( dt, stor_id, cntx ); + bool r_val = FALSE; + + if ( bli_obj_is_row_stored( obj ) && ukr_prefers_rows ) r_val = TRUE; + else if ( bli_obj_is_col_stored( obj ) && ukr_prefers_cols ) r_val = TRUE; + + return r_val; +} + +BLIS_INLINE bool bli_cntx_l3_sup_ker_dislikes_storage_of( obj_t* obj, stor3_t stor_id, cntx_t* cntx ) +{ + return ( bool ) + !bli_cntx_l3_sup_ker_prefers_storage_of( obj, stor_id, cntx ); +} +#endif + +// ----------------------------------------------------------------------------- + // // -- cntx_t modification (complex) -------------------------------------------- // -static void bli_cntx_set_blksz( bszid_t bs_id, blksz_t* blksz, bszid_t mult_id, cntx_t* cntx ) +// NOTE: The framework does not use any of the following functions. We provide +// them in order to facilitate creating/modifying custom contexts. + +BLIS_INLINE void bli_cntx_set_blksz( bszid_t bs_id, blksz_t* blksz, bszid_t mult_id, cntx_t* cntx ) { blksz_t* blkszs = bli_cntx_blkszs_buf( cntx ); bszid_t* bmults = bli_cntx_bmults_buf( cntx ); @@ -471,63 +616,79 @@ static void bli_cntx_set_blksz( bszid_t bs_id, blksz_t* blksz, bszid_t mult_id, bmults[ bs_id ] = mult_id; } -static void bli_cntx_set_l3_vir_ukr( l3ukr_t ukr_id, func_t* func, cntx_t* cntx ) +BLIS_INLINE void bli_cntx_set_blksz_def_dt( num_t dt, bszid_t bs_id, dim_t bs, cntx_t* cntx ) +{ + blksz_t* blkszs = bli_cntx_blkszs_buf( cntx ); + blksz_t* blksz = &blkszs[ bs_id ]; + + bli_blksz_set_def( bs, dt, blksz ); +} + +BLIS_INLINE void bli_cntx_set_blksz_max_dt( num_t dt, bszid_t bs_id, dim_t bs, cntx_t* cntx ) +{ + blksz_t* blkszs = bli_cntx_blkszs_buf( cntx ); + blksz_t* blksz = &blkszs[ bs_id ]; + + bli_blksz_set_max( bs, dt, blksz ); +} + +BLIS_INLINE void bli_cntx_set_l3_vir_ukr( l3ukr_t ukr_id, func_t* func, cntx_t* cntx ) { func_t* funcs = bli_cntx_l3_vir_ukrs_buf( cntx ); funcs[ ukr_id ] = *func; } -static void bli_cntx_set_l3_nat_ukr( l3ukr_t ukr_id, func_t* func, cntx_t* cntx ) +BLIS_INLINE void bli_cntx_set_l3_nat_ukr( l3ukr_t ukr_id, func_t* func, cntx_t* cntx ) { func_t* funcs = bli_cntx_l3_nat_ukrs_buf( cntx ); funcs[ ukr_id ] = *func; } -static void bli_cntx_set_l3_nat_ukr_prefs( l3ukr_t ukr_id, mbool_t* prefs, cntx_t* cntx ) +BLIS_INLINE void bli_cntx_set_l3_nat_ukr_prefs( l3ukr_t ukr_id, mbool_t* prefs, cntx_t* cntx ) { mbool_t* mbools = bli_cntx_l3_nat_ukrs_prefs_buf( cntx ); mbools[ ukr_id ] = *prefs; } -static void bli_cntx_set_l1f_ker( l1fkr_t ker_id, func_t* func, cntx_t* cntx ) +BLIS_INLINE void bli_cntx_set_l1f_ker( l1fkr_t ker_id, func_t* func, cntx_t* cntx ) { func_t* funcs = bli_cntx_l1f_kers_buf( cntx ); funcs[ ker_id ] = *func; } -static void bli_cntx_set_l1v_ker( l1vkr_t ker_id, func_t* func, cntx_t* cntx ) +BLIS_INLINE void bli_cntx_set_l1v_ker( l1vkr_t ker_id, func_t* func, cntx_t* cntx ) { func_t* funcs = bli_cntx_l1v_kers_buf( cntx ); funcs[ ker_id ] = *func; } -static void bli_cntx_set_packm_ker( l1mkr_t ker_id, func_t* func, cntx_t* cntx ) +BLIS_INLINE void bli_cntx_set_packm_ker( l1mkr_t ker_id, func_t* func, cntx_t* cntx ) { func_t* funcs = bli_cntx_get_packm_kers( ker_id, cntx ); funcs[ ker_id ] = *func; } -static void bli_cntx_set_packm_ker_dt( void* fp, num_t dt, l1mkr_t ker_id, cntx_t* cntx ) +BLIS_INLINE void bli_cntx_set_packm_ker_dt( void_fp fp, num_t dt, l1mkr_t ker_id, cntx_t* cntx ) { func_t* func = ( func_t* )bli_cntx_get_packm_kers( ker_id, cntx ); bli_func_set_dt( fp, dt, func ); } -static void bli_cntx_set_unpackm_ker( l1mkr_t ker_id, func_t* func, cntx_t* cntx ) +BLIS_INLINE void bli_cntx_set_unpackm_ker( l1mkr_t ker_id, func_t* func, cntx_t* cntx ) { func_t* funcs = bli_cntx_get_unpackm_kers( ker_id, cntx ); funcs[ ker_id ] = *func; } -static void bli_cntx_set_unpackm_ker_dt( void* fp, num_t dt, l1mkr_t ker_id, cntx_t* cntx ) +BLIS_INLINE void bli_cntx_set_unpackm_ker_dt( void_fp fp, num_t dt, l1mkr_t ker_id, cntx_t* cntx ) { func_t* func = ( func_t* )bli_cntx_get_unpackm_kers( ker_id, cntx ); @@ -538,18 +699,25 @@ static void bli_cntx_set_unpackm_ker_dt( void* fp, num_t dt, l1mkr_t ker_id, cnt // Function prototypes -BLIS_EXPORT_BLIS void bli_cntx_clear( cntx_t* cntx ); +BLIS_EXPORT_BLIS void bli_cntx_clear( cntx_t* cntx ); + +BLIS_EXPORT_BLIS void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ); + +BLIS_EXPORT_BLIS void bli_cntx_set_ind_blkszs( ind_t method, num_t dt, dim_t n_bs, ... ); -BLIS_EXPORT_BLIS void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_l3_vir_ukrs( dim_t n_ukrs, ... ); -BLIS_EXPORT_BLIS void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_l3_sup_thresh( dim_t n_thresh, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_l3_sup_handlers( dim_t n_ops, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_l3_sup_blkszs( dim_t n_bs, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_l3_sup_kers( dim_t n_ukrs, ... ); -BLIS_EXPORT_BLIS void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ); -BLIS_EXPORT_BLIS void bli_cntx_set_l1f_kers( dim_t n_kers, ... ); -BLIS_EXPORT_BLIS void bli_cntx_set_l1v_kers( dim_t n_kers, ... ); -BLIS_EXPORT_BLIS void bli_cntx_set_packm_kers( dim_t n_kers, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_l1f_kers( dim_t n_kers, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_l1v_kers( dim_t n_kers, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_packm_kers( dim_t n_kers, ... ); -BLIS_EXPORT_BLIS void bli_cntx_print( cntx_t* cntx ); +BLIS_EXPORT_BLIS void bli_cntx_print( cntx_t* cntx ); #endif diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index f5c53fc296..ff0f386e65 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018-2020, Advanced Micro Devices, Inc. + Copyright (C) 2019, Dave Love, University of Manchester Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -45,19 +46,42 @@ #define __arm__ #endif -#ifndef BLIS_CONFIGURETIME_CPUID - #include "blis.h" -#else +#ifdef BLIS_CONFIGURETIME_CPUID + + // NOTE: If you need to make any changes to this cpp branch, it's probably + // the case that you also need to modify bli_arch.c, bli_cpuid.c, and + // bli_env.c. Don't forget to update these other files as needed! + + // The BLIS_ENABLE_SYSTEM macro must be defined so that the correct cpp + // branch in bli_system.h is processed. (This macro is normally defined in + // bli_config.h.) + #define BLIS_ENABLE_SYSTEM + + // Use C-style static inline functions for any static inline functions that + // happen to be defined by the headers below. (This macro is normally defined + // in bli_config_macro_defs.h.) + #define BLIS_INLINE static + + // Since we're not building a shared library, we can forgo the use of the + // BLIS_EXPORT_BLIS annotations by #defining them to be nothing. (This macro + // is normally defined in bli_config_macro_defs.h.) #define BLIS_EXPORT_BLIS + #include "bli_system.h" #include "bli_type_defs.h" + #include "bli_arch.h" #include "bli_cpuid.h" + //#include "bli_env.h" +#else + #include "blis.h" #endif // ----------------------------------------------------------------------------- #if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) +#include "cpuid.h" + arch_t bli_cpuid_query_id( void ) { uint32_t vendor, family, model, features; @@ -67,6 +91,14 @@ arch_t bli_cpuid_query_id( void ) // vendor. vendor = bli_cpuid_query( &family, &model, &features ); +#if 0 + printf( "vendor = %s\n", vendor==1 ? "AMD": "INTEL" ); + printf("family = %x\n", family ); + printf( "model = %x\n", model ); + + printf( "features = %x\n", features ); +#endif + if ( vendor == VENDOR_INTEL ) { // Check for each Intel configuration that is enabled, check for that @@ -100,6 +132,14 @@ arch_t bli_cpuid_query_id( void ) // Check for each AMD configuration that is enabled, check for that // microarchitecture. We check from most recent to most dated. +#ifdef BLIS_CONFIG_ZEN3 + if ( bli_cpuid_is_zen3( family, model, features ) ) + return BLIS_ARCH_ZEN3; +#endif +#ifdef BLIS_CONFIG_ZEN2 + if ( bli_cpuid_is_zen2( family, model, features ) ) + return BLIS_ARCH_ZEN2; +#endif #ifdef BLIS_CONFIG_ZEN if ( bli_cpuid_is_zen( family, model, features ) ) return BLIS_ARCH_ZEN; @@ -134,7 +174,7 @@ arch_t bli_cpuid_query_id( void ) // ----------------------------------------------------------------------------- -bool_t bli_cpuid_is_skx +bool bli_cpuid_is_skx ( uint32_t family, uint32_t model, @@ -153,13 +193,28 @@ bool_t bli_cpuid_is_skx int nvpu = vpu_count(); - if ( !bli_cpuid_has_features( features, expected ) || nvpu != 2 ) + if ( bli_cpuid_has_features( features, expected ) ) + { + switch ( nvpu ) + { + case 1: + bli_arch_log( "Hardware has 1 FMA unit; using 'haswell' (not 'skx') sub-config.\n" ); + return FALSE; + case 2: + bli_arch_log( "Hardware has 2 FMA units; using 'skx' sub-config.\n" ); + return TRUE; + default: + bli_arch_log( "Number of FMA units unknown; using 'haswell' (not 'skx') config.\n" ); + return FALSE; + } + } + else return FALSE; return TRUE; } -bool_t bli_cpuid_is_knl +bool bli_cpuid_is_knl ( uint32_t family, uint32_t model, @@ -178,7 +233,7 @@ bool_t bli_cpuid_is_knl return TRUE; } -bool_t bli_cpuid_is_haswell +bool bli_cpuid_is_haswell ( uint32_t family, uint32_t model, @@ -195,7 +250,7 @@ bool_t bli_cpuid_is_haswell return TRUE; } -bool_t bli_cpuid_is_sandybridge +bool bli_cpuid_is_sandybridge ( uint32_t family, uint32_t model, @@ -210,7 +265,7 @@ bool_t bli_cpuid_is_sandybridge return TRUE; } -bool_t bli_cpuid_is_penryn +bool bli_cpuid_is_penryn ( uint32_t family, uint32_t model, @@ -228,7 +283,66 @@ bool_t bli_cpuid_is_penryn // ----------------------------------------------------------------------------- -bool_t bli_cpuid_is_zen +bool bli_cpuid_is_zen3 + ( + uint32_t family, + uint32_t model, + uint32_t features + ) +{ + // Check for expected CPU features. + const uint32_t expected = FEATURE_AVX | + FEATURE_FMA3 | + FEATURE_AVX2; + + if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; + + // All Zen3 cores have a family of 0x19. + if ( family != 0x19 ) return FALSE; + + // Finally, check for specific models: + // - 0x00 ~ 0xff + // NOTE: We accept any model because the family 25 (0x19) is unique. + const bool is_arch + = + ( 0x00 <= model && model <= 0xff ); + + if ( !is_arch ) return FALSE; + + return TRUE; +} + +bool bli_cpuid_is_zen2 + ( + uint32_t family, + uint32_t model, + uint32_t features + ) +{ + // Check for expected CPU features. + const uint32_t expected = FEATURE_AVX | + FEATURE_FMA3 | + FEATURE_AVX2; + + if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; + + // All Zen2 cores have a family of 0x17. + if ( family != 0x17 ) return FALSE; + + // Finally, check for specific models: + // - 0x30 ~ 0xff + // NOTE: We must check model because the family 23 (0x17) is shared with + // zen. + const bool is_arch + = + ( 0x30 <= model && model <= 0xff ); + + if ( !is_arch ) return FALSE; + + return TRUE; +} + +bool bli_cpuid_is_zen ( uint32_t family, uint32_t model, @@ -246,17 +360,19 @@ bool_t bli_cpuid_is_zen if ( family != 0x17 ) return FALSE; // Finally, check for specific models: - // - 0x00-0xff (THIS NEEDS UPDATING) - const bool_t is_arch + // - 0x00 ~ 0x2f + // NOTE: We must check model because the family 23 (0x17) is shared with + // zen2. + const bool is_arch = - ( 0x00 <= model && model <= 0xff ); + ( 0x00 <= model && model <= 0x2f ); if ( !is_arch ) return FALSE; return TRUE; } -bool_t bli_cpuid_is_excavator +bool bli_cpuid_is_excavator ( uint32_t family, uint32_t model, @@ -274,8 +390,8 @@ bool_t bli_cpuid_is_excavator if ( family != 0x15 ) return FALSE; // Finally, check for specific models: - // - 0x60-0x7f - const bool_t is_arch + // - 0x60 ~ 0x7f + const bool is_arch = ( 0x60 <= model && model <= 0x7f ); @@ -284,7 +400,7 @@ bool_t bli_cpuid_is_excavator return TRUE; } -bool_t bli_cpuid_is_steamroller +bool bli_cpuid_is_steamroller ( uint32_t family, uint32_t model, @@ -302,8 +418,8 @@ bool_t bli_cpuid_is_steamroller if ( family != 0x15 ) return FALSE; // Finally, check for specific models: - // - 0x30-0x3f - const bool_t is_arch + // - 0x30 ~ 0x3f + const bool is_arch = ( 0x30 <= model && model <= 0x3f ); @@ -312,7 +428,7 @@ bool_t bli_cpuid_is_steamroller return TRUE; } -bool_t bli_cpuid_is_piledriver +bool bli_cpuid_is_piledriver ( uint32_t family, uint32_t model, @@ -331,8 +447,8 @@ bool_t bli_cpuid_is_piledriver // Finally, check for specific models: // - 0x02 - // - 0x10-0x1f - const bool_t is_arch + // - 0x10 ~ 0x1f + const bool is_arch = model == 0x02 || ( 0x10 <= model && model <= 0x1f ); @@ -341,7 +457,7 @@ bool_t bli_cpuid_is_piledriver return TRUE; } -bool_t bli_cpuid_is_bulldozer +bool bli_cpuid_is_bulldozer ( uint32_t family, uint32_t model, @@ -360,7 +476,7 @@ bool_t bli_cpuid_is_bulldozer // Finally, check for specific models: // - 0x00 // - 0x01 - const bool_t is_arch + const bool is_arch = ( model == 0x00 || model == 0x01 ); @@ -375,30 +491,24 @@ arch_t bli_cpuid_query_id( void ) { uint32_t vendor, model, part, features; - // Call the CPUID instruction and parse its results into a model id, - // part id, and a feature bit field. The return value encodes the - // vendor. vendor = bli_cpuid_query( &model, &part, &features ); - //printf( "vendor = %u\n", vendor ); - //printf( "model = %u\n", model ); - //printf( "part = 0x%x\n", part ); - //printf( "features = %u\n", features ); +#if 0 + printf( "vendor = %u\n", vendor ); + printf( "model = %u\n", model ); + printf( "part = 0x%x\n", part ); + printf( "features = %u\n", features ); +#endif + + if ( vendor == VENDOR_ARM ) { if ( model == MODEL_ARMV8 ) { + return part; // Check for each ARMv8 configuration that is enabled, check for that // microarchitecture. We check from most recent to most dated. -#ifdef BLIS_CONFIG_THUNDERX2 - if ( bli_cpuid_is_thunderx2( model, part, features ) ) - return BLIS_ARCH_THUNDERX2; -#endif -#ifdef BLIS_CONFIG_CORTEXA57 - if ( bli_cpuid_is_cortexa57( model, part, features ) ) - return BLIS_ARCH_CORTEXA57; -#endif // If none of the other sub-configurations were detected, return // the 'generic' arch_t id value. return BLIS_ARCH_GENERIC; @@ -428,52 +538,7 @@ arch_t bli_cpuid_query_id( void ) return BLIS_ARCH_GENERIC; } -bool_t bli_cpuid_is_thunderx2 - ( - uint32_t family, - uint32_t model, - uint32_t features - ) -{ - // Check for expected CPU features. - const uint32_t expected = FEATURE_NEON; - - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; - - return TRUE; -} - -bool_t bli_cpuid_is_cortexa57 - ( - uint32_t family, - uint32_t model, - uint32_t features - ) -{ - // Check for expected CPU features. - const uint32_t expected = FEATURE_NEON; - - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; - - return TRUE; -} - -bool_t bli_cpuid_is_cortexa53 - ( - uint32_t family, - uint32_t model, - uint32_t features - ) -{ - // Check for expected CPU features. - const uint32_t expected = FEATURE_NEON; - - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; - - return TRUE; -} - -bool_t bli_cpuid_is_cortexa15 +bool bli_cpuid_is_cortexa15 ( uint32_t family, uint32_t model, @@ -483,12 +548,10 @@ bool_t bli_cpuid_is_cortexa15 // Check for expected CPU features. const uint32_t expected = FEATURE_NEON; - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; - - return TRUE; + return bli_cpuid_has_features( features, expected ) && model == 0xc0f; } -bool_t bli_cpuid_is_cortexa9 +bool bli_cpuid_is_cortexa9 ( uint32_t family, uint32_t model, @@ -498,9 +561,7 @@ bool_t bli_cpuid_is_cortexa9 // Check for expected CPU features. const uint32_t expected = FEATURE_NEON; - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; - - return TRUE; + return bli_cpuid_has_features( features, expected ) && model == 0xc09; } #endif @@ -517,7 +578,7 @@ bool_t bli_cpuid_is_cortexa9 Copyright (C) 2017, The University of Texas at Austin Copyright (C) 2017, Devin Matthews - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -847,6 +908,10 @@ void get_cpu_name( char *cpu_name ) *( uint32_t* )&cpu_name[32+12] = edx; } +// Return the number of FMA units _assuming avx512 is supported_. +// This needs updating for new processor types, sigh. +// See https://ark.intel.com/content/www/us/en/ark.html#@Processors +// and also https://github.com/jeffhammond/vpu-count int vpu_count( void ) { char cpu_name[48] = {}; @@ -858,56 +923,272 @@ int vpu_count( void ) if ( strstr( cpu_name, "Intel(R) Xeon(R)" ) != NULL ) { - loc = strstr( cpu_name, "Platinum" ); + if (( loc = strstr( cpu_name, "Platinum" ) )) + return 2; if ( loc == NULL ) - loc = strstr( cpu_name, "Gold" ); + loc = strstr( cpu_name, "Gold" ); // 1 or 2, tested below if ( loc == NULL ) - loc = strstr( cpu_name, "Silver" ); + if (( loc = strstr( cpu_name, "Silver" ) )) + return 1; if ( loc == NULL ) - loc = strstr( cpu_name, "Bronze" ); + if (( loc = strstr( cpu_name, "Bronze" ) )) + return 1; if ( loc == NULL ) loc = strstr( cpu_name, "W" ); + if ( loc == NULL ) + if (( loc = strstr( cpu_name, "D" ) )) + // Fixme: May be wrong + // + return 1; if ( loc == NULL ) return -1; - loc = strstr( loc+1, " " ); + // We may have W-nnnn rather than, say, Gold nnnn + if ( 'W' == *loc && '-' == *(loc+1) ) + loc++; + else + loc = strstr( loc+1, " " ); if ( loc == NULL ) return -1; strncpy( model_num, loc+1, 4 ); - model_num[4] = '\0'; + model_num[4] = '\0'; // Things like i9-10900X matched above sku = atoi( model_num ); + // These were derived from ARK listings as of 2019-10-09, but + // may not be complete, especially as the ARK Skylake listing + // seems to be limited. if ( 8199 >= sku && sku >= 8100 ) return 2; else if ( 6199 >= sku && sku >= 6100 ) return 2; else if ( sku == 5122 ) return 2; + else if ( 6299 >= sku && sku >= 6200 ) return 2; // Cascade Lake Gold + else if ( 5299 >= sku && sku >= 5200 ) return 1; // Cascade Lake Gold else if ( 5199 >= sku && sku >= 5100 ) return 1; else if ( 4199 >= sku && sku >= 4100 ) return 1; else if ( 3199 >= sku && sku >= 3100 ) return 1; + else if ( 3299 >= sku && sku >= 3200 ) return 2; // Cascade Lake W + else if ( 2299 >= sku && sku >= 2200 ) return 2; // Cascade Lake W else if ( 2199 >= sku && sku >= 2120 ) return 2; + else if ( 2102 == sku || sku == 2104 ) return 2; // Gold exceptions else if ( 2119 >= sku && sku >= 2100 ) return 1; else return -1; } - else if ( strstr( cpu_name, "Intel(R) Core(TM) i9" ) != NULL ) + else if ( strstr( cpu_name, "Intel(R) Core(TM)" ) != NULL ) + return 2; // All i7/i9 with avx512? + else { - return 1; + return -1; } - else if ( strstr( cpu_name, "Intel(R) Core(TM) i7" ) != NULL ) +} + +#elif defined(__aarch64__) + +#ifdef __linux__ +// This is adapted from OpenBLAS. See +// https://www.kernel.org/doc/html/latest/arm64/cpu-feature-registers.html +// for the mechanism, but not the magic numbers. + +// Fixme: Could these be missing in older Linux? +#include +#include + +#ifndef HWCAP_CPUID +#define HWCAP_CPUID (1 << 11) +#endif +/* From https://www.kernel.org/doc/html/latest/arm64/sve.html and the + aarch64 hwcap.h */ +#ifndef HWCAP_SVE +#define HWCAP_SVE (1 << 22) +#endif +/* Maybe also for AT_HWCAP2 +#define HWCAP2_SVE2(1 << 1) +et al +) */ + +#endif //__linux__ + +#ifdef __APPLE__ +#include +// #include +#endif + +static uint32_t get_coretype + ( + uint32_t* features + ) +{ + int implementer = 0x00, part = 0x000; + *features = FEATURE_NEON; + +#ifdef __linux__ + if ( getauxval( AT_HWCAP ) & HWCAP_CPUID ) { - if ( strstr( cpu_name, "7800X" ) != NULL || - strstr( cpu_name, "7820X" ) != NULL ) - return 1; - else - return -1; + // Also available from + // /sys/devices/system/cpu/cpu0/regs/identification/midr_el1 + // and split out in /proc/cpuinfo (with a tab before the colon): + // CPU part : 0x0a1 + + uint64_t midr_el1; + __asm("mrs %0, MIDR_EL1" : "=r" (midr_el1)); + /* + * MIDR_EL1 + * + * 31 24 23 20 19 16 15 4 3 0 + * ----------------------------------------------------------------- + * | Implementer | Variant | Architecture | Part Number | Revision | + * ----------------------------------------------------------------- + */ + implementer = (midr_el1 >> 24) & 0xFF; + part = (midr_el1 >> 4) & 0xFFF; } - else + + bool has_sve = getauxval( AT_HWCAP ) & HWCAP_SVE; + if (has_sve) + *features |= FEATURE_SVE; +#endif //__linux__ + +#ifdef __APPLE__ + // Better values could be obtained from sysctlbyname() + implementer = 0x61; //Apple + part = 0x023; //Firestorm +#endif //__APPLE__ + + // From Linux arch/arm64/include/asm/cputype.h + // ARM_CPU_IMP_ARM 0x41 + // ARM_CPU_IMP_APM 0x50 + // ARM_CPU_IMP_CAVIUM 0x43 + // ARM_CPU_IMP_BRCM 0x42 + // ARM_CPU_IMP_QCOM 0x51 + // ARM_CPU_IMP_NVIDIA 0x4E + // ARM_CPU_IMP_FUJITSU 0x46 + // ARM_CPU_IMP_HISI 0x48 + // ARM_CPU_IMP_APPLE 0x61 + // + // ARM_CPU_PART_AEM_V8 0xD0F + // ARM_CPU_PART_FOUNDATION 0xD00 + // ARM_CPU_PART_CORTEX_A57 0xD07 + // ARM_CPU_PART_CORTEX_A72 0xD08 + // ARM_CPU_PART_CORTEX_A53 0xD03 + // ARM_CPU_PART_CORTEX_A73 0xD09 + // ARM_CPU_PART_CORTEX_A75 0xD0A + // ARM_CPU_PART_CORTEX_A35 0xD04 + // ARM_CPU_PART_CORTEX_A55 0xD05 + // ARM_CPU_PART_CORTEX_A76 0xD0B + // ARM_CPU_PART_NEOVERSE_N1 0xD0C + // ARM_CPU_PART_CORTEX_A77 0xD0D + // from GCC: + // ARM_CPU_PART_CORTEX_A78 0xd41 + // ARM_CPU_PART_CORTEX_X1 0xd44 + // ARM_CPU_PART_CORTEX_V1 0xd40 + // ARM_CPU_PART_CORTEX_N2 0xd49 + // ARM_CPU_PART_CORTEX_R82 0xd15 + // + // APM_CPU_PART_POTENZA 0x000 + // + // CAVIUM_CPU_PART_THUNDERX 0x0A1 + // CAVIUM_CPU_PART_THUNDERX_81XX 0x0A2 + // CAVIUM_CPU_PART_THUNDERX_83XX 0x0A3 + // CAVIUM_CPU_PART_THUNDERX2 0x0AF + // CAVIUM_CPU_PART_THUNDERX3 0x0B8 // taken from OpenBLAS + // + // BRCM_CPU_PART_BRAHMA_B53 0x100 + // BRCM_CPU_PART_VULCAN 0x516 + // + // QCOM_CPU_PART_FALKOR_V1 0x800 + // QCOM_CPU_PART_FALKOR 0xC00 + // QCOM_CPU_PART_KRYO 0x200 + // QCOM_CPU_PART_KRYO_3XX_SILVER 0x803 + // QCOM_CPU_PART_KRYO_4XX_GOLD 0x804 + // QCOM_CPU_PART_KRYO_4XX_SILVER 0x805 + // + // NVIDIA_CPU_PART_DENVER 0x003 + // NVIDIA_CPU_PART_CARMEL 0x004 + // + // FUJITSU_CPU_PART_A64FX 0x001 + // + // HISI_CPU_PART_TSV110 0xD01 + + // APPLE_CPU_PART_M1_ICESTORM 0x022 + // APPLE_CPU_PART_M1_FIRESTORM 0x023 + + // Fixme: After merging the vpu_count branch we could report the + // part here with bli_dolog. + switch(implementer) { - return -1; + case 0x41: // ARM + switch (part) + { +#ifdef BLIS_CONFIG_CORTEXA57 + case 0xd07: // Cortex A57 + return BLIS_ARCH_CORTEXA57; +#endif +#ifdef BLIS_CONFIG_CORTEXA53 + case 0xd03: // Cortex A53 + return BLIS_ARCH_CORTEXA53; +#endif +#ifdef BLIS_CONFIG_THUNDERX2 + case 0xd0c: // Neoverse N1 (and Graviton G2?) + return BLIS_ARCH_THUNDERX2; //placeholder for N1 +#endif + } + break; + case 0x42: // Broadcom + switch (part) + { +#ifdef BLIS_CONFIG_THUNDERX2 + case 0x516: // Vulcan + return BLIS_ARCH_THUNDERX2; +#endif + } + break; + case 0x43: // Cavium + switch (part) + { +#ifdef BLIS_CONFIG_THUNDERX2 + case 0x0af: // ThunderX2 + case 0x0b8: // ThunderX3 + return BLIS_ARCH_THUNDERX2; +#endif + } + break; + case 0x46: // Fujitsu + switch (part) + { +#ifdef BLIS_CONFIG_A64FX + case 0x001: // A64FX + return BLIS_ARCH_A64FX; +#endif + } + break; + case 0x61: // Apple + switch (part) + { +#ifdef BLIS_CONFIG_FIRESTORM + case 0x022: // Icestorm (M1.LITTLE) + case 0x023: // Firestorm (M1.big) + return BLIS_ARCH_FIRESTORM; +#endif + } + break; } -} -#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) +#ifdef BLIS_CONFIG_ARMSVE + if (has_sve) + return BLIS_ARCH_ARMSVE; +#endif + +// Can't use #if defined(...) here because of parsing done for autoconfiguration +#ifdef BLIS_CONFIG_CORTEXA57 + return BLIS_ARCH_CORTEXA57; +#else +#ifdef BLIS_CONFIG_CORTEXA53 + return BLIS_ARCH_CORTEXA53; +#else + return BLIS_ARCH_GENERIC; +#endif +#endif +} uint32_t bli_cpuid_query ( @@ -916,99 +1197,79 @@ uint32_t bli_cpuid_query uint32_t* features ) { - *model = MODEL_UNKNOWN; - *part = 0; - *features = 0; - -#if 1 - const char* grep_str1 = "grep -m 1 Processor /proc/cpuinfo"; - const char* grep_str2 = "grep -m 1 'CPU part' /proc/cpuinfo"; - const char* grep_str3 = "grep -m 1 Features /proc/cpuinfo"; -#else - const char* grep_str1 = "grep -m 1 Processor ./proc_cpuinfo"; - const char* grep_str2 = "grep -m 1 'CPU part' ./proc_cpuinfo"; - const char* grep_str3 = "grep -m 1 Features ./proc_cpuinfo"; -#endif + *model = MODEL_ARMV8; + *part = get_coretype(features); - FILE *fd1 = popen( grep_str1, "r"); - if ( !fd1 ) - { - //printf("popen 1 failed\n"); - return VENDOR_ARM; - } - FILE *fd2 = popen( grep_str2, "r"); - if (!fd2) - { - //printf("popen 2 failed\n"); - pclose(fd1); - return VENDOR_ARM; - } - FILE *fd3 = popen( grep_str3, "r"); - if (!fd3) - { - //printf("popen 3 failed\n"); - pclose(fd1); - pclose(fd2); - return VENDOR_ARM; - } + return VENDOR_ARM; +} - uint32_t n1, n2, n3; - int c; +#elif defined(__arm__) || defined(_M_ARM) - // First, discover how many chars are in each stream. - for ( n1 = 0; (c = fgetc(fd1)) != EOF; ++n1 ) continue; - for ( n2 = 0; (c = fgetc(fd2)) != EOF; ++n2 ) continue; - for ( n3 = 0; (c = fgetc(fd3)) != EOF; ++n3 ) continue; +/* + I can't easily find documentation to do this as for aarch64, though + it presumably could be unearthed from Linux code. However, on + Linux 5.2 (and Androids's 3.4), /proc/cpuinfo has this sort of + thing, used below: - //printf( "n1, n2, n3 = %u %u %u\n", n1, n2, n3 ); + CPU implementer : 0x41 + CPU architecture: 7 + CPU variant : 0x3 + CPU part : 0xc09 - // Close the streams. - pclose( fd1 ); - pclose( fd2 ); - pclose( fd3 ); + The complication for family selection is that Neon is optional for + CortexA9, for instance. That's tested in bli_cpuid_is_cortexa9. + */ - // Allocate the correct amount of memory for each stream. - char* proc_str = malloc( ( size_t )( n1 + 1 ) ); - char* ptno_str = malloc( ( size_t )( n2 + 1 ) ); - char* feat_str = malloc( ( size_t )( n3 + 1 ) ); - *proc_str = 0; - *ptno_str = 0; - *feat_str = 0; +#define TEMP_BUFFER_SIZE 200 - // Re-open the streams. Note that there is no need to check for errors - // this time since we're assumign that the contents of /proc/cpuinfo - // will be the same as before. - fd1 = popen( grep_str1, "r"); - fd2 = popen( grep_str2, "r"); - fd3 = popen( grep_str3, "r"); +uint32_t bli_cpuid_query + ( + uint32_t* model, + uint32_t* part, + uint32_t* features + ) +{ + *model = MODEL_UNKNOWN; + *part = 0; + *features = 0; + char* pci_str = "/proc/cpuinfo"; + + char proc_str[ TEMP_BUFFER_SIZE ]; + char ptno_str[ TEMP_BUFFER_SIZE ]; + char feat_str[ TEMP_BUFFER_SIZE ]; char* r_val; - // Now read each stream in its entirety. Nothing should go wrong, but - // if it does, bail out. - r_val = fgets( proc_str, n1, fd1 ); - if ( n1 && r_val == NULL ) bli_abort(); + //printf( "bli_cpuid_query(): beginning search\n" ); - r_val = fgets( ptno_str, n2, fd2 ); - if ( n2 && r_val == NULL ) bli_abort(); + // Search /proc/cpuinfo for the 'Processor' entry. + r_val = find_string_in( "Processor", proc_str, TEMP_BUFFER_SIZE, pci_str ); + if ( r_val == NULL ) return VENDOR_ARM; - r_val = fgets( feat_str, n3, fd3 ); - if ( n3 && r_val == NULL ) bli_abort(); + // Search /proc/cpuinfo for the 'CPU part' entry. + r_val = find_string_in( "CPU part", ptno_str, TEMP_BUFFER_SIZE, pci_str ); + if ( r_val == NULL ) return VENDOR_ARM; - //printf( "proc_str: %s\n", proc_str ); - //printf( "ptno_str: %s\n", ptno_str ); - //printf( "feat_str: %s\n", feat_str ); + // Search /proc/cpuinfo for the 'Features' entry. + r_val = find_string_in( "Features", feat_str, TEMP_BUFFER_SIZE, pci_str ); + if ( r_val == NULL ) return VENDOR_ARM; - // Close the streams. - pclose( fd1 ); - pclose( fd2 ); - pclose( fd3 ); +#if 0 + printf( "bli_cpuid_query(): full processor string: %s\n", proc_str ); + printf( "bli_cpuid_query(): full part num string: %s\n", ptno_str ); + printf( "bli_cpuid_query(): full features string: %s\n", feat_str ); +#endif // Parse the feature string to check for SIMD features. if ( strstr( feat_str, "neon" ) != NULL || strstr( feat_str, "asimd" ) != NULL ) *features |= FEATURE_NEON; - //printf( "features var: %u\n", *features ); + + // Parse the feature string to check for SVE features. + if ( strstr( feat_str, "sve" ) != NULL ) + *features |= FEATURE_SVE; + + //printf( "bli_cpuid_query(): features var: %u\n", *features ); // Parse the processor string to uncover the model. if ( strstr( proc_str, "ARMv7" ) != NULL ) @@ -1016,7 +1277,8 @@ uint32_t bli_cpuid_query else if ( strstr( proc_str, "AArch64" ) != NULL || strstr( proc_str, "ARMv8" ) ) *model = MODEL_ARMV8; - //printf( "model: %u\n", *model ); + + //printf( "bli_cpuid_query(): model: %u\n", *model ); // Parse the part number string. r_val = strstr( ptno_str, "0x" ); @@ -1024,9 +1286,69 @@ uint32_t bli_cpuid_query { *part = strtol( r_val, NULL, 16 ); } - //printf( "part#: %x\n", *part ); + //printf( "bli_cpuid_query(): part#: %x\n", *part ); return VENDOR_ARM; } +char* find_string_in( char* target, char* buffer, size_t buf_len, char* filepath ) +{ + // This function searches for the first line of the file located at + // 'filepath' that contains the string 'target' and then copies that + // line (actually, the substring of the line starting with 'target') + // to 'buffer', which is 'buf_len' bytes long. + + char* r_val = NULL; + + // Allocate a temporary local buffer equal to the size of buffer. + char* buf_local = malloc( buf_len * sizeof( char ) ); + + // Open the file stream. + FILE* stream = fopen( filepath, "r" ); + + // Repeatedly read in a line from the stream, storing the contents of + // the stream into buf_local. + while ( !feof( stream ) ) + { + // Read in the current line, up to buf_len-1 bytes. + r_val = fgets( buf_local, buf_len-1, stream ); + + //printf( "read line: %s", buf_local ); + + // fgets() returns the pointer specified by the first argument (in + // this case, buf_local) on success and NULL on error. + if ( r_val == NULL ) break; + + // Since fgets() was successful, we can search for the target string + // within the current line, as captured in buf_local. + r_val = strstr( buf_local, target ); + + // If the target string was found in buf_local, we save it to buffer. + if ( r_val != NULL ) + { + //printf( " found match to '%s'\n", target ); + + // Copy the string read by fgets() to the caller's buffer. + strncpy( buffer, buf_local, buf_len ); + + // Make sure that we have a terminating null character by the + // end of the buffer. + if ( buf_len > 0 ) buffer[ buf_len - 1 ] = '\0'; + + // Leave the loop since we found the target string. + break; + } + } + + // Close the file stream. + fclose( stream ); + + // Free the temporary local buffer. + free( buf_local ); + + // Return r_val so the caller knows if we failed. + return r_val; +} + #endif + diff --git a/frame/base/bli_cpuid.h b/frame/base/bli_cpuid.h index e609dcbd24..3fea78e5a3 100644 --- a/frame/base/bli_cpuid.h +++ b/frame/base/bli_cpuid.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018-2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -43,35 +44,39 @@ BLIS_ARCH_CORTEXA9 = 12, BLIS_ARCH_GENERIC = 13 } arch_t; - typedef uint64_t bool_t; + typedef uint64_t bool; #define bli_abort abort #endif #ifndef BLIS_CPUID_H #define BLIS_CPUID_H -arch_t bli_cpuid_query_id( void ); +arch_t bli_cpuid_query_id( void ); // Intel -bool_t bli_cpuid_is_skx( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_knl( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_haswell( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_sandybridge( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_penryn( uint32_t family, uint32_t model, uint32_t features ); +bool bli_cpuid_is_skx( uint32_t family, uint32_t model, uint32_t features ); +bool bli_cpuid_is_knl( uint32_t family, uint32_t model, uint32_t features ); +bool bli_cpuid_is_haswell( uint32_t family, uint32_t model, uint32_t features ); +bool bli_cpuid_is_sandybridge( uint32_t family, uint32_t model, uint32_t features ); +bool bli_cpuid_is_penryn( uint32_t family, uint32_t model, uint32_t features ); // AMD -bool_t bli_cpuid_is_zen( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_excavator( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_steamroller( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_piledriver( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_bulldozer( uint32_t family, uint32_t model, uint32_t features ); +bool bli_cpuid_is_zen3( uint32_t family, uint32_t model, uint32_t features ); +bool bli_cpuid_is_zen2( uint32_t family, uint32_t model, uint32_t features ); +bool bli_cpuid_is_zen( uint32_t family, uint32_t model, uint32_t features ); +bool bli_cpuid_is_excavator( uint32_t family, uint32_t model, uint32_t features ); +bool bli_cpuid_is_steamroller( uint32_t family, uint32_t model, uint32_t features ); +bool bli_cpuid_is_piledriver( uint32_t family, uint32_t model, uint32_t features ); +bool bli_cpuid_is_bulldozer( uint32_t family, uint32_t model, uint32_t features ); // ARM -bool_t bli_cpuid_is_thunderx2( uint32_t model, uint32_t part, uint32_t features ); -bool_t bli_cpuid_is_cortexa57( uint32_t model, uint32_t part, uint32_t features ); -bool_t bli_cpuid_is_cortexa53( uint32_t model, uint32_t part, uint32_t features ); -bool_t bli_cpuid_is_cortexa15( uint32_t model, uint32_t part, uint32_t features ); -bool_t bli_cpuid_is_cortexa9( uint32_t model, uint32_t part, uint32_t features ); +bool bli_cpuid_is_thunderx2( uint32_t model, uint32_t part, uint32_t features ); +bool bli_cpuid_is_cortexa57( uint32_t model, uint32_t part, uint32_t features ); +bool bli_cpuid_is_cortexa53( uint32_t model, uint32_t part, uint32_t features ); +bool bli_cpuid_is_armsve( uint32_t model, uint32_t part, uint32_t features ); +bool bli_cpuid_is_a64fx( uint32_t model, uint32_t part, uint32_t features ); +bool bli_cpuid_is_cortexa15( uint32_t model, uint32_t part, uint32_t features ); +bool bli_cpuid_is_cortexa9( uint32_t model, uint32_t part, uint32_t features ); uint32_t bli_cpuid_query( uint32_t* family, uint32_t* model, uint32_t* features ); @@ -114,7 +119,7 @@ uint32_t bli_cpuid_query( uint32_t* family, uint32_t* model, uint32_t* features */ -static bool_t bli_cpuid_has_features( uint32_t have, uint32_t want ) +BLIS_INLINE bool bli_cpuid_has_features( uint32_t have, uint32_t want ) { return ( have & want ) == want; } @@ -123,7 +128,9 @@ static bool_t bli_cpuid_has_features( uint32_t have, uint32_t want ) #if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) -#include "cpuid.h" +// cpuid.h is now #included in bli_cpuid.c instead of here. See issue #393 +// for more information why this move was made. +//#include "cpuid.h" void get_cpu_name( char *cpu_name ); int vpu_count( void ); @@ -156,6 +163,8 @@ enum #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) +char* find_string_in( char* target, char* buffer, size_t buf_len, char* filepath ); + enum { VENDOR_ARM = 0, @@ -169,7 +178,8 @@ enum }; enum { - FEATURE_NEON = 0x1 + FEATURE_NEON = 0x01, + FEATURE_SVE = 0x02 }; #endif diff --git a/frame/base/bli_env.c b/frame/base/bli_env.c new file mode 100644 index 0000000000..92aba69700 --- /dev/null +++ b/frame/base/bli_env.c @@ -0,0 +1,127 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef BLIS_CONFIGURETIME_CPUID + + // NOTE: If you need to make any changes to this cpp branch, it's probably + // the case that you also need to modify bli_arch.c, bli_cpuid.c, and + // bli_env.c. Don't forget to update these other files as needed! + + // The BLIS_ENABLE_SYSTEM macro must be defined so that the correct cpp + // branch in bli_system.h is processed. (This macro is normally defined in + // bli_config.h.) + #define BLIS_ENABLE_SYSTEM + + // Use C-style static inline functions for any static inline functions that + // happen to be defined by the headers below. (This macro is normally defined + // in bli_config_macro_defs.h.) + #define BLIS_INLINE static + + // Since we're not building a shared library, we can forgo the use of the + // BLIS_EXPORT_BLIS annotations by #defining them to be nothing. (This macro + // is normally defined in bli_config_macro_defs.h.) + #define BLIS_EXPORT_BLIS + + #include "bli_system.h" + #include "bli_type_defs.h" + //#include "bli_arch.h" + //#include "bli_cpuid.h" + #include "bli_env.h" +#else + #include "blis.h" +#endif + +// ----------------------------------------------------------------------------- + +gint_t bli_env_get_var( const char* env, gint_t fallback ) +{ + gint_t r_val; + char* str; + + // Query the environment variable and store the result in str. + str = getenv( env ); + + // Set the return value based on the string obtained from getenv(). + if ( str != NULL ) + { + // If there was no error, convert the string to an integer and + // prepare to return that integer. + r_val = ( gint_t )strtol( str, NULL, 10 ); + } + else + { + // If there was an error, use the "fallback" as the return value. + r_val = fallback; + } + + return r_val; +} + +#if 0 +#ifdef _MSC_VER +#define strerror_r(errno,buf,len) strerror_s(buf,len,errno) +#endif + +void bli_env_set_var( const char* env, dim_t value ) +{ + dim_t r_val; + char value_str[32]; + const char* fs_32 = "%u"; + const char* fs_64 = "%lu"; + + // Convert the string to an integer, but vary the format specifier + // depending on the integer type size. + if ( bli_info_get_int_type_size() == 32 ) sprintf( value_str, fs_32, value ); + else sprintf( value_str, fs_64, value ); + + // Set the environment variable using the string we just wrote to via + // sprintf(). (The 'TRUE' argument means we want to overwrite the current + // value if the environment variable already exists.) + r_val = bli_setenv( env, value_str, TRUE ); + + // Check the return value in case something went horribly wrong. + if ( r_val == -1 ) + { + char err_str[128]; + + // Query the human-readable error string corresponding to errno. + strerror_r( errno, err_str, 128 ); + + // Print the error message. + bli_print_msg( err_str, __FILE__, __LINE__ ); + } +} +#endif + diff --git a/frame/3/bli_l3_packm.h b/frame/base/bli_env.h similarity index 88% rename from frame/3/bli_l3_packm.h rename to frame/base/bli_env.h index 37b1db1058..de86fadff0 100644 --- a/frame/3/bli_l3_packm.h +++ b/frame/base/bli_env.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2016, Hewlett Packard Enterprise Development LP Copyright (C) 2018, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without @@ -33,13 +34,11 @@ */ -void bli_l3_packm - ( - obj_t* x, - obj_t* x_pack, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ); +#ifndef BLIS_ENV_H +#define BLIS_ENV_H + +gint_t bli_env_get_var( const char* env, gint_t fallback ); +//void bli_env_set_var( const char* env, dim_t value ); + +#endif diff --git a/frame/base/bli_error.c b/frame/base/bli_error.c index 8ed386af53..37add3b674 100644 --- a/frame/base/bli_error.c +++ b/frame/base/bli_error.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,7 +36,7 @@ #include "blis.h" // Internal array to hold error strings. -static char bli_error_string[BLIS_MAX_NUM_ERR_MSGS][BLIS_MAX_ERR_MSG_LENGTH] = +static char *bli_error_string[-BLIS_ERROR_CODE_MAX] = { [-BLIS_INVALID_ERROR_CHECKING_LEVEL] = "Invalid error checking level.", [-BLIS_UNDEFINED_ERROR_CODE] = "Undefined error code.", @@ -104,6 +104,7 @@ static char bli_error_string[BLIS_MAX_NUM_ERR_MSGS][BLIS_MAX_ERR_MSG_LENGTH] = [-BLIS_EXPECTED_OBJECT_ALIAS] = "Expected object to be alias.", [-BLIS_INVALID_ARCH_ID] = "Invalid architecture id value.", + [-BLIS_UNINITIALIZED_GKS_CNTX] = "Accessed uninitialized context in gks; BLIS_ARCH_TYPE is probably set to an invalid architecture id.", [-BLIS_MC_DEF_NONMULTIPLE_OF_MR] = "Default MC is non-multiple of MR for one or more datatypes.", [-BLIS_MC_MAX_NONMULTIPLE_OF_MR] = "Maximum MC is non-multiple of MR for one or more datatypes.", @@ -132,11 +133,8 @@ void bli_abort( void ) // ----------------------------------------------------------------------------- -// A mutex to allow synchronous access to bli_err_chk_level. -static bli_pthread_mutex_t err_mutex = BLIS_PTHREAD_MUTEX_INITIALIZER; - // Current error checking level. -static errlev_t bli_err_chk_level = BLIS_FULL_ERROR_CHECKING; +static BLIS_THREAD_LOCAL errlev_t bli_err_chk_level = BLIS_FULL_ERROR_CHECKING; errlev_t bli_error_checking_level( void ) { @@ -150,20 +148,10 @@ void bli_error_checking_level_set( errlev_t new_level ) e_val = bli_check_valid_error_level( new_level ); bli_check_error_code( e_val ); - // Acquire the mutex protecting bli_err_chk_level. - bli_pthread_mutex_lock( &err_mutex ); - - // BEGIN CRITICAL SECTION - { - bli_err_chk_level = new_level; - } - // END CRITICAL SECTION - - // Release the mutex protecting bli_err_chk_level. - bli_pthread_mutex_unlock( &err_mutex ); + bli_err_chk_level = new_level; } -bool_t bli_error_checking_is_enabled( void ) +bool bli_error_checking_is_enabled( void ) { return bli_error_checking_level() != BLIS_NO_ERROR_CHECKING; } diff --git a/frame/base/bli_error.h b/frame/base/bli_error.h index e04c6784d5..e6e6f35dde 100644 --- a/frame/base/bli_error.h +++ b/frame/base/bli_error.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,10 +37,10 @@ BLIS_EXPORT_BLIS errlev_t bli_error_checking_level( void ); BLIS_EXPORT_BLIS void bli_error_checking_level_set( errlev_t new_level ); -BLIS_EXPORT_BLIS bool_t bli_error_checking_is_enabled( void ); +BLIS_EXPORT_BLIS bool bli_error_checking_is_enabled( void ); -void bli_print_msg( char* str, char* file, guint_t line ); -void bli_abort( void ); +void bli_print_msg( char* str, char* file, guint_t line ); +BLIS_EXPORT_BLIS void bli_abort( void ); -char* bli_error_string_for_code( gint_t code ); +char* bli_error_string_for_code( gint_t code ); diff --git a/frame/base/bli_func.c b/frame/base/bli_func.c index 435bd81def..477710ff00 100644 --- a/frame/base/bli_func.c +++ b/frame/base/bli_func.c @@ -37,15 +37,16 @@ func_t* bli_func_create ( - void* ptr_s, - void* ptr_d, - void* ptr_c, - void* ptr_z + void_fp ptr_s, + void_fp ptr_d, + void_fp ptr_c, + void_fp ptr_z ) { func_t* f; + err_t r_val; - f = ( func_t* ) bli_malloc_intl( sizeof(func_t) ); + f = ( func_t* )bli_malloc_intl( sizeof( func_t ), &r_val ); bli_func_init ( @@ -62,10 +63,10 @@ func_t* bli_func_create void bli_func_init ( func_t* f, - void* ptr_s, - void* ptr_d, - void* ptr_c, - void* ptr_z + void_fp ptr_s, + void_fp ptr_d, + void_fp ptr_c, + void_fp ptr_z ) { bli_func_set_dt( ptr_s, BLIS_FLOAT, f ); @@ -92,16 +93,16 @@ void bli_func_free( func_t* f ) // ----------------------------------------------------------------------------- -bool_t bli_func_is_null_dt( num_t dt, - func_t* f ) +bool bli_func_is_null_dt( num_t dt, + func_t* f ) { return ( bli_func_get_dt( dt, f ) == NULL ); } -bool_t bli_func_is_null( func_t* f ) +bool bli_func_is_null( func_t* f ) { - bool_t r_val = TRUE; - num_t dt; + bool r_val = TRUE; + num_t dt; // Iterate over all floating-point datatypes. If any is non-null, // return FALSE. Otherwise, if they are all null, return TRUE. diff --git a/frame/base/bli_func.h b/frame/base/bli_func.h index 0f927ad81a..7bdd1ab10e 100644 --- a/frame/base/bli_func.h +++ b/frame/base/bli_func.h @@ -36,7 +36,7 @@ // func_t query -static void* bli_func_get_dt +BLIS_INLINE void_fp bli_func_get_dt ( num_t dt, func_t* func @@ -47,9 +47,9 @@ static void* bli_func_get_dt // func_t modification -static void bli_func_set_dt +BLIS_INLINE void bli_func_set_dt ( - void* fp, + void_fp fp, num_t dt, func_t* func ) @@ -57,13 +57,13 @@ static void bli_func_set_dt func->ptr[ dt ] = fp; } -static void bli_func_copy_dt +BLIS_INLINE void bli_func_copy_dt ( num_t dt_src, func_t* func_src, num_t dt_dst, func_t* func_dst ) { - void* fp = bli_func_get_dt( dt_src, func_src ); + void_fp fp = bli_func_get_dt( dt_src, func_src ); bli_func_set_dt( fp, dt_dst, func_dst ); } @@ -72,19 +72,19 @@ static void bli_func_copy_dt func_t* bli_func_create ( - void* ptr_s, - void* ptr_d, - void* ptr_c, - void* ptr_z + void_fp ptr_s, + void_fp ptr_d, + void_fp ptr_c, + void_fp ptr_z ); void bli_func_init ( func_t* f, - void* ptr_s, - void* ptr_d, - void* ptr_c, - void* ptr_z + void_fp ptr_s, + void_fp ptr_d, + void_fp ptr_c, + void_fp ptr_z ); void bli_func_init_null @@ -96,7 +96,7 @@ void bli_func_free( func_t* f ); // ----------------------------------------------------------------------------- -bool_t bli_func_is_null_dt( num_t dt, - func_t* f ); -bool_t bli_func_is_null( func_t* f ); +bool bli_func_is_null_dt( num_t dt, + func_t* f ); +bool bli_func_is_null( func_t* f ); diff --git a/frame/base/bli_getopt.c b/frame/base/bli_getopt.c index 2222234848..184439db59 100644 --- a/frame/base/bli_getopt.c +++ b/frame/base/bli_getopt.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/base/bli_gks.c b/frame/base/bli_gks.c index 29902e8a4c..cc17b33ffb 100644 --- a/frame/base/bli_gks.c +++ b/frame/base/bli_gks.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018-2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,16 +41,16 @@ static cntx_t** gks[ BLIS_NUM_ARCHS ]; // The array of function pointers holding the registered context initialization // functions for induced methods. -static void* cntx_ind_init[ BLIS_NUM_ARCHS ]; +static void_fp cntx_ind_init[ BLIS_NUM_ARCHS ]; // The array of function pointers holding the registered context initialization // functions for reference kernels. -static void* cntx_ref_init[ BLIS_NUM_ARCHS ]; +static void_fp cntx_ref_init[ BLIS_NUM_ARCHS ]; // Define a function pointer type for context initialization functions. typedef void (*nat_cntx_init_ft)( cntx_t* cntx ); typedef void (*ref_cntx_init_ft)( cntx_t* cntx ); -typedef void (*ind_cntx_init_ft)( ind_t method, num_t dt, cntx_t* cntx ); +typedef void (*ind_cntx_init_ft)( ind_t method, cntx_t* cntx ); // ----------------------------------------------------------------------------- @@ -97,6 +97,16 @@ void bli_gks_init( void ) #endif // AMD architectures +#ifdef BLIS_CONFIG_ZEN3 + bli_gks_register_cntx( BLIS_ARCH_ZEN3, bli_cntx_init_zen3, + bli_cntx_init_zen3_ref, + bli_cntx_init_zen3_ind ); +#endif +#ifdef BLIS_CONFIG_ZEN2 + bli_gks_register_cntx( BLIS_ARCH_ZEN2, bli_cntx_init_zen2, + bli_cntx_init_zen2_ref, + bli_cntx_init_zen2_ind ); +#endif #ifdef BLIS_CONFIG_ZEN bli_gks_register_cntx( BLIS_ARCH_ZEN, bli_cntx_init_zen, bli_cntx_init_zen_ref, @@ -124,6 +134,11 @@ void bli_gks_init( void ) #endif // ARM architectures +#ifdef BLIS_CONFIG_A64FX + bli_gks_register_cntx( BLIS_ARCH_A64FX, bli_cntx_init_a64fx, + bli_cntx_init_a64fx_ref, + bli_cntx_init_a64fx_ind ); +#endif #ifdef BLIS_CONFIG_THUNDERX2 bli_gks_register_cntx( BLIS_ARCH_THUNDERX2, bli_cntx_init_thunderx2, bli_cntx_init_thunderx2_ref, @@ -135,10 +150,25 @@ void bli_gks_init( void ) bli_cntx_init_cortexa57_ind ); #endif #ifdef BLIS_CONFIG_CORTEXA53 - bli_gks_register_cntx( BLIS_ARCH_CORTEXA57, bli_cntx_init_cortexa53, + bli_gks_register_cntx( BLIS_ARCH_CORTEXA53, bli_cntx_init_cortexa53, bli_cntx_init_cortexa53_ref, bli_cntx_init_cortexa53_ind ); #endif +#ifdef BLIS_CONFIG_ARMSVE + bli_gks_register_cntx( BLIS_ARCH_ARMSVE, bli_cntx_init_armsve, + bli_cntx_init_armsve_ref, + bli_cntx_init_armsve_ind ); +#endif +#ifdef BLIS_CONFIG_A64FX + bli_gks_register_cntx( BLIS_ARCH_A64FX, bli_cntx_init_a64fx, + bli_cntx_init_a64fx_ref, + bli_cntx_init_a64fx_ind ); +#endif +#ifdef BLIS_CONFIG_FIRESTORM + bli_gks_register_cntx( BLIS_ARCH_FIRESTORM, bli_cntx_init_firestorm, + bli_cntx_init_firestorm_ref, + bli_cntx_init_firestorm_ind ); +#endif #ifdef BLIS_CONFIG_CORTEXA15 bli_gks_register_cntx( BLIS_ARCH_CORTEXA15, bli_cntx_init_cortexa15, bli_cntx_init_cortexa15_ref, @@ -151,6 +181,11 @@ void bli_gks_init( void ) #endif // IBM architectures +#ifdef BLIS_CONFIG_POWER10 + bli_gks_register_cntx( BLIS_ARCH_POWER10, bli_cntx_init_power10, + bli_cntx_init_power10_ref, + bli_cntx_init_power10_ind ); +#endif #ifdef BLIS_CONFIG_POWER9 bli_gks_register_cntx( BLIS_ARCH_POWER9, bli_cntx_init_power9, bli_cntx_init_power9_ref, @@ -235,7 +270,7 @@ void bli_gks_init_index( void ) // architecture id elements of the internal arrays to NULL. const size_t gks_size = sizeof( cntx_t* ) * BLIS_NUM_ARCHS; - const size_t fpa_size = sizeof( void* ) * BLIS_NUM_ARCHS; + const size_t fpa_size = sizeof( void_fp ) * BLIS_NUM_ARCHS; // Set every entry in gks and context init function pointer arrays to // zero/NULL. This is done so that later on we know which ones were @@ -290,14 +325,35 @@ cntx_t* bli_gks_lookup_ind_cntx // ----------------------------------------------------------------------------- +cntx_t** bli_gks_lookup_id + ( + arch_t id + ) +{ + // Return the address of the array of context pointers for a given + // architecture id. This function is only used for sanity check purposes + // to ensure that the underlying data structures for a particular id are + // initialized. + + // Index into the array of context pointers for the given architecture id. + cntx_t** restrict gks_id = gks[ id ]; + + // Return the context pointer at gks_id_ind. + return gks_id; +} + +// ----------------------------------------------------------------------------- + void bli_gks_register_cntx ( - arch_t id, - void* nat_fp, - void* ref_fp, - void* ind_fp + arch_t id, + void_fp nat_fp, + void_fp ref_fp, + void_fp ind_fp ) { + err_t r_val; + // This function is called by bli_gks_init() for each architecture that // will be supported by BLIS. It takes an architecture id and three // function pointers, one to a function that initializes a native context @@ -346,7 +402,7 @@ void bli_gks_register_cntx // needs to be allocated. Allocate the memory and initialize it to // zeros/NULL, storing the address of the alloacted memory at the element // for the current architecture id. - gks[ id ] = bli_calloc_intl( sizeof( cntx_t* ) * BLIS_NUM_IND_METHODS ); + gks[ id ] = bli_calloc_intl( sizeof( cntx_t* ) * BLIS_NUM_IND_METHODS, &r_val ); // Alias the allocated array for readability. cntx_t** restrict gks_id = gks[ id ]; @@ -358,7 +414,7 @@ void bli_gks_register_cntx // Allocate memory for a single context and store the address at // the element in the gks[ id ] array that is reserved for native // execution. - gks_id[ BLIS_NAT ] = bli_calloc_intl( sizeof( cntx_t ) ); + gks_id[ BLIS_NAT ] = bli_calloc_intl( sizeof( cntx_t ), &r_val ); // Alias the allocated context address for readability. cntx_t* restrict gks_id_nat = gks_id[ BLIS_NAT ]; @@ -398,6 +454,11 @@ void bli_gks_register_cntx e_val = bli_check_valid_mc_mod_mult( mc, nr ); bli_check_error_code( e_val ); e_val = bli_check_valid_nc_mod_mult( nc, mr ); bli_check_error_code( e_val ); #endif + + // Verify that the register blocksizes in the context are sufficiently large + // relative to the maximum stack buffer size defined at configure-time. + e_val = bli_check_sufficient_stack_buf_size( gks_id_nat ); + bli_check_error_code( e_val ); } // ----------------------------------------------------------------------------- @@ -455,6 +516,7 @@ cntx_t* bli_gks_query_ind_cntx bli_init_once(); cntx_t* gks_id_ind; + err_t r_val; // Return the address of a context that will be suited for executing a // level-3 operation via the requested induced method (and datatype) for @@ -513,7 +575,7 @@ cntx_t* bli_gks_query_ind_cntx // If gks_id_ind is NULL, then we know we must allocate and then // initialize the context, storing its address back to // gks_id[ ind ]. - gks_id_ind = bli_calloc_intl( sizeof( cntx_t ) ); + gks_id_ind = bli_calloc_intl( sizeof( cntx_t ), &r_val ); gks_id[ ind ] = gks_id_ind; // Before we can call the induced method context initialization @@ -530,7 +592,7 @@ cntx_t* bli_gks_query_ind_cntx // function for the current induced method. (That function assumes // that the context is pre- initialized with values for native // execution.) - f( ind, dt, gks_id_ind ); + f( ind, gks_id_ind ); } } // END CRITICAL SECTION @@ -570,7 +632,7 @@ void bli_gks_init_ref_cntx // ----------------------------------------------------------------------------- -bool_t bli_gks_cntx_l3_nat_ukr_is_ref +bool bli_gks_cntx_l3_nat_ukr_is_ref ( num_t dt, l3ukr_t ukr_id, @@ -585,8 +647,8 @@ bool_t bli_gks_cntx_l3_nat_ukr_is_ref // Query each context for the micro-kernel function pointer for the // specified datatype. - void* ref_fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr_id, &ref_cntx ); - void* fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr_id, cntx ); + void_fp ref_fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr_id, &ref_cntx ); + void_fp fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr_id, cntx ); // Return the result. return fp == ref_fp; @@ -614,7 +676,7 @@ char* bli_gks_l3_ukr_impl_string( l3ukr_t ukr, ind_t method, num_t dt ) // then query the ukernel function pointer for the given datatype from // that context. cntx_t* cntx = bli_gks_query_ind_cntx( method, dt ); - void* fp = bli_cntx_get_l3_vir_ukr_dt( dt, ukr, cntx ); + void_fp fp = bli_cntx_get_l3_vir_ukr_dt( dt, ukr, cntx ); // Check whether the ukernel function pointer is NULL for the given // datatype. If it is NULL, return the string for not applicable. @@ -693,8 +755,8 @@ kimpl_t bli_gks_l3_ukr_impl_type( l3ukr_t ukr, ind_t method, num_t dt ) // Query the native ukernel func_t from both the native and reference // contexts. - void* nat_fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr, nat_cntx ); - void* ref_fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr, &ref_cntx_l ); + void_fp nat_fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr, nat_cntx ); + void_fp ref_fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr, &ref_cntx_l ); if ( nat_fp == ref_fp ) return BLIS_REFERENCE_UKERNEL; else return BLIS_OPTIMIZED_UKERNEL; diff --git a/frame/base/bli_gks.h b/frame/base/bli_gks.h index fde4e4ec01..188dcd5075 100644 --- a/frame/base/bli_gks.h +++ b/frame/base/bli_gks.h @@ -42,18 +42,19 @@ void bli_gks_init_index( void ); cntx_t* bli_gks_lookup_nat_cntx( arch_t id ); cntx_t* bli_gks_lookup_ind_cntx( arch_t id, ind_t ind ); -void bli_gks_register_cntx( arch_t id, void* nat_fp, void* ref_fp, void* ind_fp ); +cntx_t** bli_gks_lookup_id( arch_t id ); +void bli_gks_register_cntx( arch_t id, void_fp nat_fp, void_fp ref_fp, void_fp ind_fp ); BLIS_EXPORT_BLIS cntx_t* bli_gks_query_cntx( void ); BLIS_EXPORT_BLIS cntx_t* bli_gks_query_nat_cntx( void ); cntx_t* bli_gks_query_cntx_noinit( void ); -cntx_t* bli_gks_query_ind_cntx( ind_t ind, num_t dt ); +BLIS_EXPORT_BLIS cntx_t* bli_gks_query_ind_cntx( ind_t ind, num_t dt ); BLIS_EXPORT_BLIS void bli_gks_init_ref_cntx( cntx_t* cntx ); -bool_t bli_gks_cntx_l3_nat_ukr_is_ref( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ); +bool bli_gks_cntx_l3_nat_ukr_is_ref( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ); BLIS_EXPORT_BLIS char* bli_gks_l3_ukr_impl_string( l3ukr_t ukr, ind_t method, num_t dt ); BLIS_EXPORT_BLIS kimpl_t bli_gks_l3_ukr_impl_type( l3ukr_t ukr, ind_t method, num_t dt ); diff --git a/frame/ind/bli_ind.c b/frame/base/bli_ind.c similarity index 76% rename from frame/ind/bli_ind.c rename to frame/base/bli_ind.c index 41419c6ce8..a359e89a38 100644 --- a/frame/ind/bli_ind.c +++ b/frame/base/bli_ind.c @@ -36,11 +36,6 @@ static char* bli_ind_impl_str[BLIS_NUM_IND_METHODS] = { -/* 3mh */ "3mh", -/* 3m1 */ "3m1", -/* 4mh */ "4mh", -/* 4m1b */ "4m1b", -/* 4m1a */ "4m1a", /* 1m */ "1m", /* nat */ "native", }; @@ -49,20 +44,24 @@ static char* bli_ind_impl_str[BLIS_NUM_IND_METHODS] = void bli_ind_init( void ) { - // Enable the default induced method (1m) if one or both complex domain - // gemm micro-kernels are unoptimized in the native context. - // NOTE: Instead of calling bli_gks_query_cntx(), we call // bli_gks_query_cntx_noinit() to avoid the call to bli_init_once(). cntx_t* cntx = bli_gks_query_cntx_noinit(); - bool_t c_is_ref = bli_gks_cntx_l3_nat_ukr_is_ref - ( BLIS_SCOMPLEX, BLIS_GEMM_UKR, cntx ); - bool_t z_is_ref = bli_gks_cntx_l3_nat_ukr_is_ref - ( BLIS_DCOMPLEX, BLIS_GEMM_UKR, cntx ); + // For each precision, enable the default induced method (1m) if both of + // the following conditions are met: + // - the complex domain kernel is the (unoptimized) reference kernel + // - the real domain kernel is NOT the (unoptimized) reference kernel + // The second condition means that BLIS will not bother to use an induced + // method if both the real and complex domain kernels are reference. + + bool s_is_ref = bli_gks_cntx_l3_nat_ukr_is_ref( BLIS_FLOAT, BLIS_GEMM_UKR, cntx ); + bool d_is_ref = bli_gks_cntx_l3_nat_ukr_is_ref( BLIS_DOUBLE, BLIS_GEMM_UKR, cntx ); + bool c_is_ref = bli_gks_cntx_l3_nat_ukr_is_ref( BLIS_SCOMPLEX, BLIS_GEMM_UKR, cntx ); + bool z_is_ref = bli_gks_cntx_l3_nat_ukr_is_ref( BLIS_DCOMPLEX, BLIS_GEMM_UKR, cntx ); - if ( c_is_ref ) bli_ind_enable_dt( BLIS_1M, BLIS_SCOMPLEX ); - if ( z_is_ref ) bli_ind_enable_dt( BLIS_1M, BLIS_DCOMPLEX ); + if ( c_is_ref && !s_is_ref ) bli_ind_enable_dt( BLIS_1M, BLIS_SCOMPLEX ); + if ( z_is_ref && !d_is_ref ) bli_ind_enable_dt( BLIS_1M, BLIS_DCOMPLEX ); } void bli_ind_finalize( void ) @@ -137,14 +136,15 @@ void bli_ind_oper_enable_only( opid_t oper, ind_t method, num_t dt ) // ----------------------------------------------------------------------------- -bool_t bli_ind_oper_is_impl( opid_t oper, ind_t method ) +bool bli_ind_oper_is_impl( opid_t oper, ind_t method ) { - bool_t is_impl = FALSE; + bool is_impl = FALSE; if ( bli_opid_is_level3( oper ) ) { - // Look up whether its func_t pointer in the table is NULL. - is_impl = ( bli_l3_ind_oper_get_func( oper, method ) != NULL ); + // Look up whether the operation is implemented for the given induced + // method id. + is_impl = bli_l3_ind_oper_is_impl( oper, method ); } else { @@ -158,39 +158,6 @@ bool_t bli_ind_oper_is_impl( opid_t oper, ind_t method ) return is_impl; } -#if 0 -bool_t bli_ind_oper_has_avail( opid_t oper, num_t dt ) -{ - ind_t method = bli_ind_oper_find_avail( oper, dt ); - - if ( method == BLIS_NAT ) return FALSE; - else return TRUE; -} -#endif - -void* bli_ind_oper_get_avail( opid_t oper, num_t dt ) -{ - void* func_p; - - if ( bli_opid_is_level3( oper ) ) - { - ind_t method = bli_ind_oper_find_avail( oper, dt ); - - func_p = bli_l3_ind_oper_get_func( oper, method ); - } - else - { - // Currently, any operation that is not level-3 does not - // have induced method implementations. (This should actually - // assign the pointer to be the native front-end, but for - // now there are no calls to bli_ind_oper_get_avail() in the - // context of level-2 operations. - func_p = NULL; - } - - return func_p; -} - ind_t bli_ind_oper_find_avail( opid_t oper, num_t dt ) { ind_t method; diff --git a/frame/ind/bli_ind.h b/frame/base/bli_ind.h similarity index 66% rename from frame/ind/bli_ind.h rename to frame/base/bli_ind.h index e0f1e9608e..85cad648e9 100644 --- a/frame/ind/bli_ind.h +++ b/frame/base/bli_ind.h @@ -38,34 +38,22 @@ // level-3 induced method management #include "bli_l3_ind.h" -// level-3 object APIs -#include "bli_l3_ind_oapi.h" - -// level-3 typed APIs -#include "bli_l3_ind_tapi.h" - -// level-3 cntx initialization -#include "bli_cntx_ind_stage.h" - - void bli_ind_init( void ); void bli_ind_finalize( void ); -BLIS_EXPORT_BLIS void bli_ind_enable( ind_t method ); -BLIS_EXPORT_BLIS void bli_ind_disable( ind_t method ); -BLIS_EXPORT_BLIS void bli_ind_disable_all( void ); +BLIS_EXPORT_BLIS void bli_ind_enable( ind_t method ); +BLIS_EXPORT_BLIS void bli_ind_disable( ind_t method ); +BLIS_EXPORT_BLIS void bli_ind_disable_all( void ); -BLIS_EXPORT_BLIS void bli_ind_enable_dt( ind_t method, num_t dt ); -BLIS_EXPORT_BLIS void bli_ind_disable_dt( ind_t method, num_t dt ); -BLIS_EXPORT_BLIS void bli_ind_disable_all_dt( num_t dt ); +BLIS_EXPORT_BLIS void bli_ind_enable_dt( ind_t method, num_t dt ); +BLIS_EXPORT_BLIS void bli_ind_disable_dt( ind_t method, num_t dt ); +BLIS_EXPORT_BLIS void bli_ind_disable_all_dt( num_t dt ); -BLIS_EXPORT_BLIS void bli_ind_oper_enable_only( opid_t oper, ind_t method, num_t dt ); +BLIS_EXPORT_BLIS void bli_ind_oper_enable_only( opid_t oper, ind_t method, num_t dt ); -BLIS_EXPORT_BLIS bool_t bli_ind_oper_is_impl( opid_t oper, ind_t method ); -//bool_t bli_ind_oper_has_avail( opid_t oper, num_t dt ); -BLIS_EXPORT_BLIS void* bli_ind_oper_get_avail( opid_t oper, num_t dt ); -BLIS_EXPORT_BLIS ind_t bli_ind_oper_find_avail( opid_t oper, num_t dt ); -BLIS_EXPORT_BLIS char* bli_ind_oper_get_avail_impl_string( opid_t oper, num_t dt ); +BLIS_EXPORT_BLIS bool bli_ind_oper_is_impl( opid_t oper, ind_t method ); +BLIS_EXPORT_BLIS ind_t bli_ind_oper_find_avail( opid_t oper, num_t dt ); +BLIS_EXPORT_BLIS char* bli_ind_oper_get_avail_impl_string( opid_t oper, num_t dt ); char* bli_ind_get_impl_string( ind_t method ); num_t bli_ind_map_cdt_to_index( num_t dt ); diff --git a/frame/base/bli_info.c b/frame/base/bli_info.c index 76844ec239..bfa5ca9a38 100644 --- a/frame/base/bli_info.c +++ b/frame/base/bli_info.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -43,33 +43,32 @@ static char* bli_version_str = BLIS_VERSION_STRING; static char* bli_int_type_size_str = STRINGIFY_INT( BLIS_INT_TYPE_SIZE ); -char* bli_info_get_version_str( void ) { return bli_version_str; } -char* bli_info_get_int_type_size_str( void ) { return bli_int_type_size_str; } +char* bli_info_get_version_str( void ) { return bli_version_str; } +char* bli_info_get_int_type_size_str( void ) { return bli_int_type_size_str; } // -- General configuration-related -------------------------------------------- -gint_t bli_info_get_int_type_size( void ) { return BLIS_INT_TYPE_SIZE; } -gint_t bli_info_get_num_fp_types( void ) { return BLIS_NUM_FP_TYPES; } -gint_t bli_info_get_max_type_size( void ) { return BLIS_MAX_TYPE_SIZE; } -gint_t bli_info_get_page_size( void ) { return BLIS_PAGE_SIZE; } -gint_t bli_info_get_simd_num_registers( void ) { return BLIS_SIMD_NUM_REGISTERS; } -gint_t bli_info_get_simd_size( void ) { return BLIS_SIMD_SIZE; } -gint_t bli_info_get_simd_align_size( void ) { return BLIS_SIMD_ALIGN_SIZE; } -gint_t bli_info_get_stack_buf_max_size( void ) { return BLIS_STACK_BUF_MAX_SIZE; } -gint_t bli_info_get_stack_buf_align_size( void ) { return BLIS_STACK_BUF_ALIGN_SIZE; } -gint_t bli_info_get_heap_addr_align_size( void ) { return BLIS_HEAP_ADDR_ALIGN_SIZE; } -gint_t bli_info_get_heap_stride_align_size( void ) { return BLIS_HEAP_STRIDE_ALIGN_SIZE; } -gint_t bli_info_get_pool_addr_align_size( void ) { return BLIS_POOL_ADDR_ALIGN_SIZE; } -gint_t bli_info_get_enable_stay_auto_init( void ) -{ -#ifdef BLIS_ENABLE_STAY_AUTO_INITIALIZED - return 1; -#else - return 0; -#endif -} +gint_t bli_info_get_int_type_size( void ) { return BLIS_INT_TYPE_SIZE; } +gint_t bli_info_get_num_fp_types( void ) { return BLIS_NUM_FP_TYPES; } +gint_t bli_info_get_max_type_size( void ) { return BLIS_MAX_TYPE_SIZE; } +gint_t bli_info_get_page_size( void ) { return BLIS_PAGE_SIZE; } +gint_t bli_info_get_simd_num_registers( void ) { return BLIS_SIMD_MAX_NUM_REGISTERS; } +gint_t bli_info_get_simd_size( void ) { return BLIS_SIMD_MAX_SIZE; } +gint_t bli_info_get_simd_align_size( void ) { return BLIS_SIMD_ALIGN_SIZE; } +gint_t bli_info_get_stack_buf_max_size( void ) { return BLIS_STACK_BUF_MAX_SIZE; } +gint_t bli_info_get_stack_buf_align_size( void ) { return BLIS_STACK_BUF_ALIGN_SIZE; } +gint_t bli_info_get_heap_addr_align_size( void ) { return BLIS_HEAP_ADDR_ALIGN_SIZE; } +gint_t bli_info_get_heap_stride_align_size( void ) { return BLIS_HEAP_STRIDE_ALIGN_SIZE; } +gint_t bli_info_get_pool_addr_align_size_a( void ) { return BLIS_POOL_ADDR_ALIGN_SIZE_A; } +gint_t bli_info_get_pool_addr_align_size_b( void ) { return BLIS_POOL_ADDR_ALIGN_SIZE_B; } +gint_t bli_info_get_pool_addr_align_size_c( void ) { return BLIS_POOL_ADDR_ALIGN_SIZE_C; } +gint_t bli_info_get_pool_addr_align_size_gen( void ) { return BLIS_POOL_ADDR_ALIGN_SIZE_GEN; } +gint_t bli_info_get_pool_addr_offset_size_a( void ) { return BLIS_POOL_ADDR_OFFSET_SIZE_A; } +gint_t bli_info_get_pool_addr_offset_size_b( void ) { return BLIS_POOL_ADDR_OFFSET_SIZE_B; } +gint_t bli_info_get_pool_addr_offset_size_c( void ) { return BLIS_POOL_ADDR_OFFSET_SIZE_C; } +gint_t bli_info_get_pool_addr_offset_size_gen( void ) { return BLIS_POOL_ADDR_OFFSET_SIZE_GEN; } gint_t bli_info_get_enable_blas( void ) { #ifdef BLIS_ENABLE_BLAS @@ -181,12 +180,13 @@ char* bli_info_get_trsm_u_ukr_impl_string( ind_t method, num_t dt ) // -- BLIS implementation query (level-3) -------------------------------------- char* bli_info_get_gemm_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_GEMM, dt ); } +char* bli_info_get_gemmt_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_GEMMT, dt ); } char* bli_info_get_hemm_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_HEMM, dt ); } -char* bli_info_get_herk_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_HERK, dt ); } -char* bli_info_get_her2k_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_HER2K, dt ); } +char* bli_info_get_herk_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_GEMMT, dt ); } +char* bli_info_get_her2k_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_GEMMT, dt ); } char* bli_info_get_symm_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_SYMM, dt ); } -char* bli_info_get_syrk_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_SYRK, dt ); } -char* bli_info_get_syr2k_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_SYR2K, dt ); } +char* bli_info_get_syrk_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_GEMMT, dt ); } +char* bli_info_get_syr2k_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_GEMMT, dt ); } char* bli_info_get_trmm_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_TRMM, dt ); } char* bli_info_get_trmm3_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_TRMM3, dt ); } char* bli_info_get_trsm_impl_string( num_t dt ) { return bli_ind_oper_get_avail_impl_string( BLIS_TRSM, dt ); } diff --git a/frame/base/bli_info.h b/frame/base/bli_info.h index be078fd7b3..99c7d000db 100644 --- a/frame/base/bli_info.h +++ b/frame/base/bli_info.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -53,7 +53,14 @@ BLIS_EXPORT_BLIS gint_t bli_info_get_stack_buf_max_size( void ); BLIS_EXPORT_BLIS gint_t bli_info_get_stack_buf_align_size( void ); BLIS_EXPORT_BLIS gint_t bli_info_get_heap_addr_align_size( void ); BLIS_EXPORT_BLIS gint_t bli_info_get_heap_stride_align_size( void ); -BLIS_EXPORT_BLIS gint_t bli_info_get_pool_addr_align_size( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_pool_addr_align_size_a( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_pool_addr_align_size_b( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_pool_addr_align_size_c( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_pool_addr_align_size_gen( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_pool_addr_offset_size_a( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_pool_addr_offset_size_b( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_pool_addr_offset_size_c( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_pool_addr_offset_size_gen( void ); BLIS_EXPORT_BLIS gint_t bli_info_get_enable_stay_auto_init( void ); BLIS_EXPORT_BLIS gint_t bli_info_get_enable_blas( void ); BLIS_EXPORT_BLIS gint_t bli_info_get_enable_cblas( void ); @@ -84,6 +91,7 @@ BLIS_EXPORT_BLIS char* bli_info_get_trsm_u_ukr_impl_string( ind_t method, num_t // -- BLIS implementation query (level-3) -------------------------------------- BLIS_EXPORT_BLIS char* bli_info_get_gemm_impl_string( num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_gemmt_impl_string( num_t dt ); BLIS_EXPORT_BLIS char* bli_info_get_hemm_impl_string( num_t dt ); BLIS_EXPORT_BLIS char* bli_info_get_herk_impl_string( num_t dt ); BLIS_EXPORT_BLIS char* bli_info_get_her2k_impl_string( num_t dt ); diff --git a/frame/base/bli_init.c b/frame/base/bli_init.c index 1180f1c37d..e616ac2d7b 100644 --- a/frame/base/bli_init.c +++ b/frame/base/bli_init.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -56,18 +56,28 @@ void bli_init_auto( void ) void bli_finalize_auto( void ) { -#ifdef BLIS_ENABLE_STAY_AUTO_INITIALIZED + // The _auto() functions are used when initializing the BLAS compatibility + // layer. It would not make much sense to automatically initialize and + // finalize for every BLAS routine call; therefore, we remain initialized + // unless and until the application explicitly calls bli_finalize(). +} - // If BLIS was configured to stay initialized after being automatically - // initialized, we honor the configuration request and do nothing. - // BLIS will remain initialized unless and until the user explicitly - // calls bli_finalize(). +// ----------------------------------------------------------------------------- -#else +// A pthread_once_t variable is a pthread structure used in pthread_once(). +// pthread_once() is guaranteed to execute exactly once among all threads that +// pass in this control object (until/unless the variable is reset). +static bli_pthread_once_t once_init = BLIS_PTHREAD_ONCE_INIT; +static bli_pthread_once_t once_finalize = BLIS_PTHREAD_ONCE_INIT; - bli_finalize_once(); +void bli_init_once( void ) +{ + bli_pthread_once( &once_init, bli_init_apis ); +} -#endif +void bli_finalize_once( void ) +{ + bli_pthread_once( &once_finalize, bli_finalize_apis ); } // ----------------------------------------------------------------------------- @@ -78,34 +88,35 @@ void bli_init_apis( void ) bli_gks_init(); bli_ind_init(); bli_thread_init(); + bli_pack_init(); bli_memsys_init(); + + // Reset the control variable that will allow finalization. + // NOTE: We must initialize a fresh pthread_once_t object and THEN copy the + // contents to the static control variable because some implementations of + // pthreads define pthread_once_t as a struct and BLIS_PTHREAD_ONCE_INIT as + // a struct initializer expression (i.e. { ... }), which cannot be used in + // post-declaration struct assignment in strict C99. + const bli_pthread_once_t once_new = BLIS_PTHREAD_ONCE_INIT; + once_finalize = once_new; } void bli_finalize_apis( void ) { // Finalize various sub-APIs. bli_memsys_finalize(); + bli_pack_finalize(); bli_thread_finalize(); - bli_gks_finalize(); bli_ind_finalize(); -} - -// ----------------------------------------------------------------------------- - -// A pthread_once_t variable is a pthread structure used in pthread_once(). -// pthread_once() is guaranteed to execute exactly once among all threads that -// pass in this control object. Thus, we need one for initialization and a -// separate one for finalization. -static bli_pthread_once_t once_init = BLIS_PTHREAD_ONCE_INIT; -static bli_pthread_once_t once_finalize = BLIS_PTHREAD_ONCE_INIT; - -void bli_init_once( void ) -{ - bli_pthread_once( &once_init, bli_init_apis ); -} + bli_gks_finalize(); -void bli_finalize_once( void ) -{ - bli_pthread_once( &once_finalize, bli_finalize_apis ); + // Reset the control variable that will allow (re-)initialization. + // NOTE: We must initialize a fresh pthread_once_t object and THEN copy the + // contents to the static control variable because some implementations of + // pthreads define pthread_once_t as a struct and BLIS_PTHREAD_ONCE_INIT as + // a struct initializer expression (i.e. { ... }), which cannot be used in + // post-declaration struct assignment in strict C99. + const bli_pthread_once_t once_new = BLIS_PTHREAD_ONCE_INIT; + once_init = once_new; } diff --git a/frame/base/bli_machval.c b/frame/base/bli_machval.c index e26c5a4d8d..1aaf604d80 100644 --- a/frame/base/bli_machval.c +++ b/frame/base/bli_machval.c @@ -80,7 +80,7 @@ void PASTEMAC(chv,opname) \ { \ static ctype_vr pvals[ BLIS_NUM_MACH_PARAMS ]; \ \ - static bool_t first_time = TRUE; \ + static bool first_time = TRUE; \ \ dim_t val_i = mval - BLIS_MACH_PARAM_FIRST; \ ctype_v* v_cast = v; \ diff --git a/frame/base/bli_malloc.c b/frame/base/bli_malloc.c index 25ebeb1e0b..f1993f62e3 100644 --- a/frame/base/bli_malloc.c +++ b/frame/base/bli_malloc.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -71,7 +71,7 @@ void bli_free_pool( void* p ) // ----------------------------------------------------------------------------- -void* bli_malloc_user( size_t size ) +void* bli_malloc_user( size_t size, err_t* r_val ) { const malloc_ft malloc_fp = BLIS_MALLOC_USER; const size_t align_size = BLIS_HEAP_ADDR_ALIGN_SIZE; @@ -82,7 +82,9 @@ void* bli_malloc_user( size_t size ) fflush( stdout ); #endif - return bli_fmalloc_align( malloc_fp, size, align_size ); + void* p = bli_fmalloc_align( malloc_fp, size, align_size, r_val ); + + return p; } void bli_free_user( void* p ) @@ -97,7 +99,7 @@ void bli_free_user( void* p ) // ----------------------------------------------------------------------------- -void* bli_malloc_intl( size_t size ) +void* bli_malloc_intl( size_t size, err_t* r_val ) { const malloc_ft malloc_fp = BLIS_MALLOC_INTL; @@ -106,18 +108,21 @@ void* bli_malloc_intl( size_t size ) fflush( stdout ); #endif - return bli_fmalloc_noalign( malloc_fp, size ); + void* p = bli_fmalloc_noalign( malloc_fp, size, r_val ); + + return p; } -void* bli_calloc_intl( size_t size ) +void* bli_calloc_intl( size_t size, err_t* r_val ) { #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_calloc_intl(): " ); #endif - void* p = bli_malloc_intl( size ); + void* p = bli_malloc_intl( size, r_val ); - memset( p, 0, size ); + if ( bli_is_success( *r_val ) ) + memset( p, 0, size ); return p; } @@ -138,7 +143,8 @@ void* bli_fmalloc_align ( malloc_ft f, size_t size, - size_t align_size + size_t align_size, + err_t* r_val ) { const size_t ptr_size = sizeof( void* ); @@ -165,6 +171,9 @@ void* bli_fmalloc_align if ( bli_error_checking_is_enabled() ) bli_fmalloc_post_check( p_orig ); + // The pseudo-return value isn't used yet. + *r_val = BLIS_SUCCESS; + // Advance the pointer by one pointer element. p_byte = p_orig; p_byte += ptr_size; @@ -226,7 +235,8 @@ void bli_ffree_align void* bli_fmalloc_noalign ( malloc_ft f, - size_t size + size_t size, + err_t* r_val ) { void* p = f( size ); @@ -235,6 +245,9 @@ void* bli_fmalloc_noalign if ( bli_error_checking_is_enabled() ) bli_fmalloc_post_check( p ); + // The pseudo-return value isn't used yet. + *r_val = BLIS_SUCCESS; + return p; } diff --git a/frame/base/bli_malloc.h b/frame/base/bli_malloc.h index e7d523a32f..488124045f 100644 --- a/frame/base/bli_malloc.h +++ b/frame/base/bli_malloc.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,8 +34,8 @@ */ // Typedef function pointer types for malloc() and free() substitutes. -typedef void* (*malloc_ft) ( size_t size ); -typedef void (*free_ft) ( void* p ); +//typedef void* (*malloc_ft) ( size_t size ); +//typedef void (*free_ft) ( void* p ); // ----------------------------------------------------------------------------- @@ -44,19 +44,19 @@ BLIS_EXPORT_BLIS void* bli_malloc_pool( size_t size ); BLIS_EXPORT_BLIS void bli_free_pool( void* p ); #endif -void* bli_malloc_intl( size_t size ); -void* bli_calloc_intl( size_t size ); +void* bli_malloc_intl( size_t size, err_t* r_val ); +void* bli_calloc_intl( size_t size, err_t* r_val ); void bli_free_intl( void* p ); -BLIS_EXPORT_BLIS void* bli_malloc_user( size_t size ); +BLIS_EXPORT_BLIS void* bli_malloc_user( size_t size, err_t* r_val ); BLIS_EXPORT_BLIS void bli_free_user( void* p ); // ----------------------------------------------------------------------------- -void* bli_fmalloc_align( malloc_ft f, size_t size, size_t align_size ); +void* bli_fmalloc_align( malloc_ft f, size_t size, size_t align_size, err_t* r_val ); void bli_ffree_align( free_ft f, void* p ); -void* bli_fmalloc_noalign( malloc_ft f, size_t size ); +void* bli_fmalloc_noalign( malloc_ft f, size_t size, err_t* r_val ); void bli_ffree_noalign( free_ft f, void* p ); void bli_fmalloc_align_check( malloc_ft f, size_t size, size_t align_size ); diff --git a/frame/base/bli_mbool.c b/frame/base/bli_mbool.c index 879cc5b8f6..d0b78dacd8 100644 --- a/frame/base/bli_mbool.c +++ b/frame/base/bli_mbool.c @@ -37,15 +37,16 @@ mbool_t* bli_mbool_create ( - bool_t b_s, - bool_t b_d, - bool_t b_c, - bool_t b_z + bool b_s, + bool b_d, + bool b_c, + bool b_z ) { mbool_t* b; + err_t r_val; - b = ( mbool_t* ) bli_malloc_intl( sizeof(mbool_t) ); + b = ( mbool_t* ) bli_malloc_intl( sizeof( mbool_t ), &r_val ); bli_mbool_init ( @@ -62,10 +63,10 @@ mbool_t* bli_mbool_create void bli_mbool_init ( mbool_t* b, - bool_t b_s, - bool_t b_d, - bool_t b_c, - bool_t b_z + bool b_s, + bool b_d, + bool b_c, + bool b_z ) { bli_mbool_set_dt( b_s, BLIS_FLOAT, b ); diff --git a/frame/base/bli_mbool.h b/frame/base/bli_mbool.h index 4cd4a78fb7..6a989590b2 100644 --- a/frame/base/bli_mbool.h +++ b/frame/base/bli_mbool.h @@ -36,14 +36,14 @@ // mbool_t query -static bool_t bli_mbool_get_dt( num_t dt, mbool_t* mb ) +BLIS_INLINE bool bli_mbool_get_dt( num_t dt, mbool_t* mb ) { - return mb->v[ dt ]; + return ( bool )( mb->v[ dt ] ); } // mbool_t modification -static void bli_mbool_set_dt( bool_t val, num_t dt, mbool_t* mb ) +BLIS_INLINE void bli_mbool_set_dt( bool val, num_t dt, mbool_t* mb ) { mb->v[ dt ] = val; } @@ -52,19 +52,19 @@ static void bli_mbool_set_dt( bool_t val, num_t dt, mbool_t* mb ) mbool_t* bli_mbool_create ( - bool_t b_s, - bool_t b_d, - bool_t b_c, - bool_t b_z + bool b_s, + bool b_d, + bool b_c, + bool b_z ); void bli_mbool_init ( mbool_t* b, - bool_t b_s, - bool_t b_d, - bool_t b_c, - bool_t b_z + bool b_s, + bool b_d, + bool b_c, + bool b_z ); void bli_mbool_free( mbool_t* b ); diff --git a/frame/base/bli_mem.h b/frame/base/bli_mem.h index 5f56f98c0a..d61e970214 100644 --- a/frame/base/bli_mem.h +++ b/frame/base/bli_mem.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,80 +34,127 @@ */ + #ifndef BLIS_MEM_H #define BLIS_MEM_H -// Mem entry query +// mem_t object type (defined in bli_type_defs.h) + +/* +typedef struct mem_s +{ + pblk_t pblk; + packbuf_t buf_type; + pool_t* pool; + siz_t size; +} mem_t; -static pblk_t* bli_mem_pblk( mem_t* mem ) +typedef struct +{ + void* buf; + siz_t block_size; +} pblk_t; +*/ + +// +// -- mem_t query -------------------------------------------------------------- +// + +BLIS_INLINE pblk_t* bli_mem_pblk( mem_t* mem ) { return &(mem->pblk); } -static void* bli_mem_buffer( mem_t* mem ) +BLIS_INLINE void* bli_mem_buffer( mem_t* mem ) { return bli_pblk_buf( bli_mem_pblk( mem ) ); } -static packbuf_t bli_mem_buf_type( mem_t* mem ) +BLIS_INLINE packbuf_t bli_mem_buf_type( mem_t* mem ) { return mem->buf_type; } -static pool_t* bli_mem_pool( mem_t* mem ) +BLIS_INLINE pool_t* bli_mem_pool( mem_t* mem ) { return mem->pool; } -static siz_t bli_mem_size( mem_t* mem ) +BLIS_INLINE siz_t bli_mem_size( mem_t* mem ) { return mem->size; } -static bool_t bli_mem_is_alloc( mem_t* mem ) +BLIS_INLINE bool bli_mem_is_alloc( mem_t* mem ) { - return ( bool_t ) + return ( bool ) ( bli_mem_buffer( mem ) != NULL ); } -static bool_t bli_mem_is_unalloc( mem_t* mem ) +BLIS_INLINE bool bli_mem_is_unalloc( mem_t* mem ) { - return ( bool_t ) + return ( bool ) ( bli_mem_buffer( mem ) == NULL ); } -// Mem entry modification +// +// -- mem_t modification ------------------------------------------------------- +// -static void bli_mem_set_pblk( pblk_t* pblk, mem_t* mem ) +BLIS_INLINE void bli_mem_set_pblk( pblk_t* pblk, mem_t* mem ) { mem->pblk = *pblk; } -static void bli_mem_set_buffer( void* buf, mem_t* mem ) +BLIS_INLINE void bli_mem_set_buffer( void* buf, mem_t* mem ) { bli_pblk_set_buf( buf, &(mem->pblk) ); } -static void bli_mem_set_buf_type( packbuf_t buf_type, mem_t* mem ) +BLIS_INLINE void bli_mem_set_buf_type( packbuf_t buf_type, mem_t* mem ) { mem->buf_type = buf_type; } -static void bli_mem_set_pool( pool_t* pool, mem_t* mem ) +BLIS_INLINE void bli_mem_set_pool( pool_t* pool, mem_t* mem ) { mem->pool = pool; } -static void bli_mem_set_size( siz_t size, mem_t* mem ) +BLIS_INLINE void bli_mem_set_size( siz_t size, mem_t* mem ) { mem->size = size; } -static void bli_mem_clear( mem_t* mem ) +// +// -- mem_t initialization ----------------------------------------------------- +// + +// NOTE: This initializer macro must be updated whenever fields are added or +// removed from the mem_t type definition. An alternative to the initializer is +// calling bli_mem_clear() at runtime. + +#define BLIS_MEM_INITIALIZER \ + { \ + .pblk = BLIS_PBLK_INITIALIZER, \ + .buf_type = -1, \ + .pool = NULL, \ + .size = 0, \ + } \ + +BLIS_INLINE void bli_mem_clear( mem_t* mem ) { bli_mem_set_buffer( NULL, mem ); +#ifdef __cplusplus + const packbuf_t pb = BLIS_BUFFER_FOR_GEN_USE; + // When using C++, which is strongly typed, we avoid use of -1 as a + // packbuf_t value since it will result in a compile-time error. + bli_mem_set_buf_type( pb, mem ); +#else + bli_mem_set_buf_type( ( packbuf_t )-1, mem ); +#endif bli_mem_set_pool( NULL, mem ); bli_mem_set_size( 0, mem ); } diff --git a/frame/base/bli_memsys.c b/frame/base/bli_memsys.c index 888eb764d1..ca3c46f998 100644 --- a/frame/base/bli_memsys.c +++ b/frame/base/bli_memsys.c @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,7 +39,7 @@ void bli_memsys_init( void ) { // Query a native context so we have something to pass into - // bli_membrk_init_pools(). We use BLIS_DOUBLE for the datatype, + // bli_pba_init_pools(). We use BLIS_DOUBLE for the datatype, // but the dt argument is actually only used when initializing // contexts for induced methods. // NOTE: Instead of calling bli_gks_query_cntx(), we call @@ -47,7 +47,7 @@ void bli_memsys_init( void ) cntx_t* cntx_p = bli_gks_query_cntx_noinit(); // Initialize the packing block allocator and its data structures. - bli_membrk_init( cntx_p ); + bli_pba_init( cntx_p ); // Initialize the small block allocator and its data structures. bli_sba_init(); @@ -58,7 +58,7 @@ void bli_memsys_finalize( void ) // Finalize the small block allocator and its data structures. bli_sba_finalize(); - // Finalize the global membrk_t object and its data structures. - bli_membrk_finalize(); + // Finalize the packing block allocator and its data structures. + bli_pba_finalize(); } diff --git a/frame/base/bli_memsys.h b/frame/base/bli_memsys.h index 306819c030..be0d48e35b 100644 --- a/frame/base/bli_memsys.h +++ b/frame/base/bli_memsys.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/base/bli_obj.c b/frame/base/bli_obj.c index 44fdb1f140..23fbb4cd10 100644 --- a/frame/base/bli_obj.c +++ b/frame/base/bli_obj.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -118,6 +118,11 @@ void bli_obj_create_without_buffer bli_obj_set_offs( 0, 0, obj ); bli_obj_set_diag_offset( 0, obj ); + bli_obj_set_pack_fn( NULL, obj ); + bli_obj_set_pack_params( NULL, obj ); + bli_obj_set_ker_fn( NULL, obj ); + bli_obj_set_ker_params( NULL, obj ); + // Set the internal scalar to 1.0. bli_obj_set_scalar_dt( dt, obj ); s = bli_obj_internal_scalar_buffer( obj ); @@ -147,6 +152,7 @@ void bli_obj_alloc_buffer siz_t elem_size; siz_t buffer_size; void* p; + err_t r_val; bli_init_once(); @@ -195,7 +201,7 @@ void bli_obj_alloc_buffer buffer_size = ( siz_t )n_elem * elem_size; // Allocate the buffer. - p = bli_malloc_user( buffer_size ); + p = bli_malloc_user( buffer_size, &r_val ); // Set individual fields. bli_obj_set_buffer( p, obj ); @@ -355,7 +361,7 @@ void bli_obj_free buf_a = bli_obj_buffer_at_off( a ); - bli_zzsets( 0.0, 0.0, value ); + bli_zzsets( 0.0, 0.0, value ); if ( bli_obj_is_float( a ) ) { @@ -405,7 +411,8 @@ void bli_adjust_strides // matrix). if ( m == 0 || n == 0 ) return; - // Interpret rs = cs = 0 as request for column storage. + // Interpret rs = cs = 0 as request for column storage and -1 as a request + // for row storage. if ( *rs == 0 && *cs == 0 && ( *is == 0 || *is == 1 ) ) { // First we handle the 1x1 scalar case explicitly. @@ -414,8 +421,9 @@ void bli_adjust_strides *rs = 1; *cs = 1; } - // We use column-major storage, except when m == 1, because we don't - // want both strides to be unit. + // We use column-major storage, except when m == 1, in which case we + // use what amounts to row-major storage because we don't want both + // strides to be unit. else if ( m == 1 && n > 1 ) { *rs = n; @@ -445,6 +453,46 @@ void bli_adjust_strides BLIS_HEAP_STRIDE_ALIGN_SIZE ); } } + else if ( *rs == -1 && *cs == -1 && ( *is == 0 || *is == 1 ) ) + { + // First we handle the 1x1 scalar case explicitly. + if ( m == 1 && n == 1 ) + { + *rs = 1; + *cs = 1; + } + // We use row-major storage, except when n == 1, in which case we + // use what amounts to column-major storage because we don't want both + // strides to be unit. + else if ( n == 1 && m > 1 ) + { + *rs = 1; + *cs = m; + } + else + { + *rs = n; + *cs = 1; + } + + // Use default complex storage. + *is = 1; + + // Align the strides depending on the tilt of the matrix. Note that + // scalars are neither row nor column tilted. Also note that alignment + // is only done for rs = cs = -1, and any user-supplied row and column + // strides are preserved. + if ( bli_is_col_tilted( m, n, *rs, *cs ) ) + { + *cs = bli_align_dim_to_size( *cs, elem_size, + BLIS_HEAP_STRIDE_ALIGN_SIZE ); + } + else if ( bli_is_row_tilted( m, n, *rs, *cs ) ) + { + *rs = bli_align_dim_to_size( *rs, elem_size, + BLIS_HEAP_STRIDE_ALIGN_SIZE ); + } + } else if ( *rs == 1 && *cs == 1 ) { // If both strides are unit, this is probably a "lazy" request for a @@ -457,7 +505,7 @@ void bli_adjust_strides // Set the column stride to indicate that this is a column vector // stored in column-major order. This is done for legacy reasons, // because we at one time we had to satisify the error checking - // in the underlying BLAS library, which expects the leading + // in the underlying BLAS library, which expects the leading // dimension to be set to at least m, even if it will never be // used for indexing since it is a vector and thus only has one // column of data. diff --git a/frame/base/bli_obj_scalar.c b/frame/base/bli_obj_scalar.c index 48e255cd29..e28d4fda98 100644 --- a/frame/base/bli_obj_scalar.c +++ b/frame/base/bli_obj_scalar.c @@ -206,12 +206,12 @@ void bli_obj_scalar_reset //bli_obj_scalar_attach( BLIS_NO_CONJUGATE, &BLIS_ONE, a ); } -bool_t bli_obj_scalar_has_nonzero_imag +bool bli_obj_scalar_has_nonzero_imag ( obj_t* a ) { - bool_t r_val = FALSE; + bool r_val = FALSE; num_t dt = bli_obj_scalar_dt( a ); void* scalar_a = bli_obj_internal_scalar_buffer( a ); @@ -234,14 +234,14 @@ bool_t bli_obj_scalar_has_nonzero_imag return r_val; } -bool_t bli_obj_scalar_equals +bool bli_obj_scalar_equals ( obj_t* a, obj_t* beta ) { - obj_t scalar_a; - bool_t r_val; + obj_t scalar_a; + bool r_val; bli_obj_scalar_detach( a, &scalar_a ); diff --git a/frame/base/bli_obj_scalar.h b/frame/base/bli_obj_scalar.h index f655ff46e6..86b699659b 100644 --- a/frame/base/bli_obj_scalar.h +++ b/frame/base/bli_obj_scalar.h @@ -76,12 +76,12 @@ BLIS_EXPORT_BLIS void bli_obj_scalar_reset obj_t* a ); -BLIS_EXPORT_BLIS bool_t bli_obj_scalar_has_nonzero_imag +BLIS_EXPORT_BLIS bool bli_obj_scalar_has_nonzero_imag ( obj_t* a ); -BLIS_EXPORT_BLIS bool_t bli_obj_scalar_equals +BLIS_EXPORT_BLIS bool bli_obj_scalar_equals ( obj_t* a, obj_t* beta diff --git a/frame/base/bli_opid.h b/frame/base/bli_opid.h index 542b6a31e1..b7c547ddf9 100644 --- a/frame/base/bli_opid.h +++ b/frame/base/bli_opid.h @@ -32,9 +32,9 @@ */ -static bool_t bli_opid_is_level3( opid_t opid ) +BLIS_INLINE bool bli_opid_is_level3( opid_t opid ) { - return ( bool_t ) + return ( bool ) ( BLIS_GEMM <= opid && opid <= BLIS_TRSM ); } diff --git a/frame/base/bli_pack.c b/frame/base/bli_pack.c new file mode 100644 index 0000000000..c5ce9cc6c9 --- /dev/null +++ b/frame/base/bli_pack.c @@ -0,0 +1,157 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// The global rntm_t structure. (The definition resides in bli_rntm.c.) +extern rntm_t global_rntm; + +// A mutex to allow synchronous access to global_rntm. (The definition +// resides in bli_rntm.c.) +extern bli_pthread_mutex_t global_rntm_mutex; + +// ----------------------------------------------------------------------------- + +void bli_pack_init( void ) +{ + // Read the environment variables and use them to initialize the + // global runtime object. + bli_pack_init_rntm_from_env( &global_rntm ); +} + +void bli_pack_finalize( void ) +{ +} + +// ----------------------------------------------------------------------------- + +void bli_pack_get_pack_a( bool* pack_a ) +{ + // We must ensure that global_rntm has been initialized. + bli_init_once(); + + *pack_a = bli_rntm_pack_a( &global_rntm ); +} + +// ----------------------------------------------------------------------------- + +void bli_pack_get_pack_b( bool* pack_b ) +{ + // We must ensure that global_rntm has been initialized. + bli_init_once(); + + *pack_b = bli_rntm_pack_b( &global_rntm ); +} + +// ---------------------------------------------------------------------------- + +void bli_pack_set_pack_a( bool pack_a ) +{ + // We must ensure that global_rntm has been initialized. + bli_init_once(); + + // Acquire the mutex protecting global_rntm. + bli_pthread_mutex_lock( &global_rntm_mutex ); + + bli_rntm_set_pack_a( pack_a, &global_rntm ); + + // Release the mutex protecting global_rntm. + bli_pthread_mutex_unlock( &global_rntm_mutex ); +} + +// ---------------------------------------------------------------------------- + +void bli_pack_set_pack_b( bool pack_b ) +{ + // We must ensure that global_rntm has been initialized. + bli_init_once(); + + // Acquire the mutex protecting global_rntm. + bli_pthread_mutex_lock( &global_rntm_mutex ); + + bli_rntm_set_pack_b( pack_b, &global_rntm ); + + // Release the mutex protecting global_rntm. + bli_pthread_mutex_unlock( &global_rntm_mutex ); +} + +// ---------------------------------------------------------------------------- + +void bli_pack_init_rntm_from_env + ( + rntm_t* rntm + ) +{ + // NOTE: We don't need to acquire the global_rntm_mutex here because this + // function is only called from bli_pack_init(), which is only called + // by bli_init_once(). + + bool pack_a; + bool pack_b; + +#if 1 //def BLIS_ENABLE_SELECTIVE_PACKING + + // Try to read BLIS_PACK_A and BLIS_PACK_B. For each variable, default to + // -1 if it is unset. + gint_t pack_a_env = bli_env_get_var( "BLIS_PACK_A", -1 ); + gint_t pack_b_env = bli_env_get_var( "BLIS_PACK_B", -1 ); + + // Enforce the default behavior first, then check for affirmative FALSE, and + // finally assume anything else is TRUE. + if ( pack_a_env == -1 ) pack_a = FALSE; // default behavior + else if ( pack_a_env == 0 ) pack_a = FALSE; // zero is FALSE + else pack_a = TRUE; // anything else is TRUE + + if ( pack_b_env == -1 ) pack_b = FALSE; // default behavior + else if ( pack_b_env == 0 ) pack_b = FALSE; // zero is FALSE + else pack_b = TRUE; // anything else is TRUE + +#else + + pack_a = TRUE; + pack_b = TRUE; + +#endif + + // Save the results back in the runtime object. + bli_rntm_set_pack_a( pack_a, rntm ); + bli_rntm_set_pack_b( pack_b, rntm ); + +#if 0 + printf( "bli_pack_init_rntm_from_env()\n" ); + bli_rntm_print( rntm ); +#endif +} + diff --git a/frame/base/bli_pack.h b/frame/base/bli_pack.h new file mode 100644 index 0000000000..c12740148c --- /dev/null +++ b/frame/base/bli_pack.h @@ -0,0 +1,49 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_PACK_H +#define BLIS_PACK_H + +void bli_pack_init( void ); +void bli_pack_finalize( void ); + +BLIS_EXPORT_BLIS void bli_pack_get_pack_a( bool* pack_a ); +BLIS_EXPORT_BLIS void bli_pack_get_pack_b( bool* pack_b ); +BLIS_EXPORT_BLIS void bli_pack_set_pack_a( bool pack_a ); +BLIS_EXPORT_BLIS void bli_pack_set_pack_b( bool pack_b ); + +void bli_pack_init_rntm_from_env( rntm_t* rntm ); + +#endif + diff --git a/frame/base/bli_param_map.c b/frame/base/bli_param_map.c index de877f686a..d20eece43c 100644 --- a/frame/base/bli_param_map.c +++ b/frame/base/bli_param_map.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -98,61 +99,8 @@ void bli_param_map_blis_to_netlib_machval( machval_t machval, char* blas_machval // --- BLAS/LAPACK to BLIS mappings -------------------------------------------- -void bli_param_map_netlib_to_blis_side( char side, side_t* blis_side ) -{ - if ( side == 'l' || side == 'L' ) *blis_side = BLIS_LEFT; - else if ( side == 'r' || side == 'R' ) *blis_side = BLIS_RIGHT; - else - { - // Instead of reporting an error to the framework, default to - // an arbitrary value. This is needed because this function is - // called by the BLAS compatibility layer AFTER it has already - // checked errors and called xerbla(). If the application wants - // to override the BLAS compatibility layer's xerbla--which - // responds to errors with abort()--we need to also NOT call - // abort() here, since either way it has already been dealt - // with. - //bli_check_error_code( BLIS_INVALID_SIDE ); - *blis_side = BLIS_LEFT; - } -} - -void bli_param_map_netlib_to_blis_uplo( char uplo, uplo_t* blis_uplo ) -{ - if ( uplo == 'l' || uplo == 'L' ) *blis_uplo = BLIS_LOWER; - else if ( uplo == 'u' || uplo == 'U' ) *blis_uplo = BLIS_UPPER; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - //bli_check_error_code( BLIS_INVALID_UPLO ); - *blis_uplo = BLIS_LOWER; - } -} - -void bli_param_map_netlib_to_blis_trans( char trans, trans_t* blis_trans ) -{ - if ( trans == 'n' || trans == 'N' ) *blis_trans = BLIS_NO_TRANSPOSE; - else if ( trans == 't' || trans == 'T' ) *blis_trans = BLIS_TRANSPOSE; - else if ( trans == 'c' || trans == 'C' ) *blis_trans = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - //bli_check_error_code( BLIS_INVALID_TRANS ); - *blis_trans = BLIS_NO_TRANSPOSE; - } -} - -void bli_param_map_netlib_to_blis_diag( char diag, diag_t* blis_diag ) -{ - if ( diag == 'n' || diag == 'N' ) *blis_diag = BLIS_NONUNIT_DIAG; - else if ( diag == 'u' || diag == 'U' ) *blis_diag = BLIS_UNIT_DIAG; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - //bli_check_error_code( BLIS_INVALID_DIAG ); - *blis_diag = BLIS_NONUNIT_DIAG; - } -} +// NOTE: These functions were converted into static functions. Please see this +// file's corresponding header for those definitions. // --- BLIS char to BLIS mappings ---------------------------------------------- diff --git a/frame/base/bli_param_map.h b/frame/base/bli_param_map.h index 9a9601f916..58f179d006 100644 --- a/frame/base/bli_param_map.h +++ b/frame/base/bli_param_map.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -44,10 +45,64 @@ BLIS_EXPORT_BLIS void bli_param_map_blis_to_netlib_machval( machval_t machval, c // --- BLAS/LAPACK to BLIS mappings -------------------------------------------- -BLIS_EXPORT_BLIS void bli_param_map_netlib_to_blis_side( char side, side_t* blis_side ); -BLIS_EXPORT_BLIS void bli_param_map_netlib_to_blis_uplo( char uplo, uplo_t* blis_uplo ); -BLIS_EXPORT_BLIS void bli_param_map_netlib_to_blis_trans( char trans, trans_t* blis_trans ); -BLIS_EXPORT_BLIS void bli_param_map_netlib_to_blis_diag( char diag, diag_t* blis_diag ); +// NOTE: These static functions were converted from regular functions in order +// to reduce function call overhead within the BLAS compatibility layer. + +BLIS_INLINE void bli_param_map_netlib_to_blis_side( char side, side_t* blis_side ) +{ + if ( side == 'l' || side == 'L' ) *blis_side = BLIS_LEFT; + else if ( side == 'r' || side == 'R' ) *blis_side = BLIS_RIGHT; + else + { + // Instead of reporting an error to the framework, default to + // an arbitrary value. This is needed because this function is + // called by the BLAS compatibility layer AFTER it has already + // checked errors and called xerbla(). If the application wants + // to override the BLAS compatibility layer's xerbla--which + // responds to errors with abort()--we need to also NOT call + // abort() here, since either way it has already been dealt + // with. + //bli_check_error_code( BLIS_INVALID_SIDE ); + *blis_side = BLIS_LEFT; + } +} + +BLIS_INLINE void bli_param_map_netlib_to_blis_uplo( char uplo, uplo_t* blis_uplo ) +{ + if ( uplo == 'l' || uplo == 'L' ) *blis_uplo = BLIS_LOWER; + else if ( uplo == 'u' || uplo == 'U' ) *blis_uplo = BLIS_UPPER; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_UPLO ); + *blis_uplo = BLIS_LOWER; + } +} + +BLIS_INLINE void bli_param_map_netlib_to_blis_trans( char trans, trans_t* blis_trans ) +{ + if ( trans == 'n' || trans == 'N' ) *blis_trans = BLIS_NO_TRANSPOSE; + else if ( trans == 't' || trans == 'T' ) *blis_trans = BLIS_TRANSPOSE; + else if ( trans == 'c' || trans == 'C' ) *blis_trans = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_TRANS ); + *blis_trans = BLIS_NO_TRANSPOSE; + } +} + +BLIS_INLINE void bli_param_map_netlib_to_blis_diag( char diag, diag_t* blis_diag ) +{ + if ( diag == 'n' || diag == 'N' ) *blis_diag = BLIS_NONUNIT_DIAG; + else if ( diag == 'u' || diag == 'U' ) *blis_diag = BLIS_UNIT_DIAG; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_DIAG ); + *blis_diag = BLIS_NONUNIT_DIAG; + } +} // --- BLIS char to BLIS mappings ---------------------------------------------- diff --git a/frame/base/bli_part.c b/frame/base/bli_part.c index ce6af5b6fe..95587e4a71 100644 --- a/frame/base/bli_part.c +++ b/frame/base/bli_part.c @@ -126,22 +126,11 @@ void bli_acquire_mpart_mdim doff_t diag_off_inc; - // NOTE: Most of this function implicitly assumes moving forward. - // When moving backward, we have to relocate i. - if ( direct == BLIS_BWD ) - { - // Query the dimension in the partitioning direction. - dim_t m = bli_obj_length_after_trans( obj ); - - // Modify i to account for the fact that we are moving backwards. - i = m - i - b; - } - - // Call a special function for partitioning packed objects. (By only // catching those objects packed to panels, we omit cases where the // object is packed to row or column storage, as such objects can be - // partitioned through normally.) + // partitioned through normally.) Note that the function called below + // assumes forward partitioning. if ( bli_obj_is_panel_packed( obj ) ) { bli_packm_acquire_mpart_t2b( req_part, i, b, obj, sub_obj ); @@ -173,6 +162,15 @@ void bli_acquire_mpart_mdim if ( b > m - i ) b = m - i; + // NOTE: Most of this function implicitly assumes moving forward. + // When moving backward, we have to relocate i. + if ( direct == BLIS_BWD ) + { + // Modify i to account for the fact that we are moving backwards. + i = m - i - b; + } + + // Support SUBPART1B (behind SUBPART1) and SUBPART1A (ahead of SUBPART1), // to refer to subpartitions 0 and 2 when moving forward, and 2 and 0 when // moving backward. @@ -268,7 +266,7 @@ void bli_acquire_mpart_mdim // diagonal, then set the subpartition structure to "general"; otherwise // we let the subpartition inherit the storage structure of its immediate // parent. - if ( !bli_obj_root_is_general( sub_obj ) && + if ( !bli_obj_root_is_general( sub_obj ) && bli_obj_is_outside_diag( sub_obj ) ) { // NOTE: This comment may be out-of-date since we now distinguish @@ -276,10 +274,10 @@ void bli_acquire_mpart_mdim // Note that we cannot mark the subpartition object as general/dense // here since it makes sense to preserve the existing uplo information // a while longer so that the correct kernels are invoked. (Example: - // incremental packing/computing in herk produces subpartitions that + // incremental packing/computing in gemmt produces subpartitions that // appear general/dense, but their uplo fields are needed to be either // lower or upper, to determine which macro-kernel gets called in the - // herk_int() back-end.) + // gemmt_int() back-end.) // If the subpartition lies entirely in an "unstored" triangle of the // root matrix, then we need to tweak the subpartition. If the root @@ -352,22 +350,11 @@ void bli_acquire_mpart_ndim doff_t diag_off_inc; - // NOTE: Most of this function implicitly assumes moving forward. - // When moving backward, we have to relocate j. - if ( direct == BLIS_BWD ) - { - // Query the dimension in the partitioning direction. - dim_t n = bli_obj_width_after_trans( obj ); - - // Modify i to account for the fact that we are moving backwards. - j = n - j - b; - } - - // Call a special function for partitioning packed objects. (By only // catching those objects packed to panels, we omit cases where the // object is packed to row or column storage, as such objects can be - // partitioned through normally.) + // partitioned through normally.) Note that the function called below + // assumes forward partitioning. if ( bli_obj_is_panel_packed( obj ) ) { bli_packm_acquire_mpart_l2r( req_part, j, b, obj, sub_obj ); @@ -399,6 +386,15 @@ void bli_acquire_mpart_ndim if ( b > n - j ) b = n - j; + // NOTE: Most of this function implicitly assumes moving forward. + // When moving backward, we have to relocate j. + if ( direct == BLIS_BWD ) + { + // Modify j to account for the fact that we are moving backwards. + j = n - j - b; + } + + // Support SUBPART1B (behind SUBPART1) and SUBPART1A (ahead of SUBPART1), // to refer to subpartitions 0 and 2 when moving forward, and 2 and 0 when // moving backward. @@ -493,7 +489,7 @@ void bli_acquire_mpart_ndim // diagonal), and the subpartition does not intersect the root matrix's // diagonal, then we might need to modify some of the subpartition's // properties, depending on its structure type. - if ( !bli_obj_root_is_general( sub_obj ) && + if ( !bli_obj_root_is_general( sub_obj ) && bli_obj_is_outside_diag( sub_obj ) ) { // NOTE: This comment may be out-of-date since we now distinguish @@ -501,10 +497,10 @@ void bli_acquire_mpart_ndim // Note that we cannot mark the subpartition object as general/dense // here since it makes sense to preserve the existing uplo information // a while longer so that the correct kernels are invoked. (Example: - // incremental packing/computing in herk produces subpartitions that + // incremental packing/computing in gemmt produces subpartitions that // appear general/dense, but their uplo fields are needed to be either // lower or upper, to determine which macro-kernel gets called in the - // herk_int() back-end.) + // gemmt_int() back-end.) // If the subpartition lies entirely in an "unstored" triangle of the // root matrix, then we need to tweak the subpartition. If the root @@ -578,22 +574,11 @@ void bli_acquire_mpart_mndim doff_t diag_off_inc; - // NOTE: Most of this function implicitly assumes moving forward. - // When moving backward, we have to relocate ij. - if ( direct == BLIS_BWD ) - { - // Query the dimension of the object. - dim_t mn = bli_obj_length( obj ); - - // Modify ij to account for the fact that we are moving backwards. - ij = mn - ij - b; - } - - // Call a special function for partitioning packed objects. (By only // catching those objects packed to panels, we omit cases where the // object is packed to row or column storage, as such objects can be - // partitioned through normally.) + // partitioned through normally.) Note that the function called below + // assumes forward partitioning. if ( bli_obj_is_panel_packed( obj ) ) { bli_packm_acquire_mpart_tl2br( req_part, ij, b, obj, sub_obj ); @@ -626,6 +611,15 @@ void bli_acquire_mpart_mndim if ( b > min_m_n - ij ) b = min_m_n - ij; + // NOTE: Most of this function implicitly assumes moving forward. + // When moving backward, we have to relocate ij. + if ( direct == BLIS_BWD ) + { + // Modify ij to account for the fact that we are moving backwards. + ij = min_m_n - ij - b; + } + + // Compute offset increments and dimensions based on which // subpartition is being requested, assuming no transposition. @@ -748,7 +742,7 @@ void bli_acquire_mpart_mndim // diagonal, then set the subpartition structure to "general"; otherwise // we let the subpartition inherit the storage structure of its immediate // parent. - if ( !bli_obj_root_is_general( sub_obj ) && + if ( !bli_obj_root_is_general( sub_obj ) && req_part != BLIS_SUBPART00 && req_part != BLIS_SUBPART11 && req_part != BLIS_SUBPART22 ) @@ -768,10 +762,10 @@ void bli_acquire_mpart_mndim // Note that we cannot mark the subpartition object as general/dense // here since it makes sense to preserve the existing uplo information // a while longer so that the correct kernels are invoked. (Example: - // incremental packing/computing in herk produces subpartitions that + // incremental packing/computing in gemmt produces subpartitions that // appear general/dense, but their uplo fields are needed to be either // lower or upper, to determine which macro-kernel gets called in the - // herk_int() back-end.) + // gemmt_int() back-end.) // If the subpartition lies entirely in an "unstored" triangle of the // root matrix, then we need to tweak the subpartition. If the root diff --git a/frame/base/bli_membrk.c b/frame/base/bli_pba.c similarity index 77% rename from frame/base/bli_membrk.c rename to frame/base/bli_pba.c index 19b50a52f5..f8835e5de0 100644 --- a/frame/base/bli_membrk.c +++ b/frame/base/bli_pba.c @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,54 +36,61 @@ #include "blis.h" -static membrk_t global_membrk; +// Statically initialize the mutex within the packing block allocator object. +static pba_t pba = { .mutex = BLIS_PTHREAD_MUTEX_INITIALIZER }; // ----------------------------------------------------------------------------- -membrk_t* bli_membrk_query( void ) +pba_t* bli_pba_query( void ) { - return &global_membrk; + return &pba; } -void bli_membrk_init +void bli_pba_init ( cntx_t* restrict cntx ) { - membrk_t* restrict membrk = bli_membrk_query(); + pba_t* restrict pba = bli_pba_query(); - const siz_t align_size = BLIS_POOL_ADDR_ALIGN_SIZE; + const siz_t align_size = BLIS_POOL_ADDR_ALIGN_SIZE_GEN; malloc_ft malloc_fp = BLIS_MALLOC_POOL; free_ft free_fp = BLIS_FREE_POOL; - // These fields are used for general-purpose allocation. - bli_membrk_set_align_size( align_size, membrk ); - bli_membrk_set_malloc_fp( malloc_fp, membrk ); - bli_membrk_set_free_fp( free_fp, membrk ); + // These fields are used for general-purpose allocation (ie: buf_type + // equal to BLIS_BUFFER_FOR_GEN_USE) within bli_pba_acquire_m(). + bli_pba_set_align_size( align_size, pba ); + bli_pba_set_malloc_fp( malloc_fp, pba ); + bli_pba_set_free_fp( free_fp, pba ); + + // The mutex field of pba is initialized statically above. This + // keeps bli_pba_init() simpler and removes the possibility of + // something going wrong during mutex initialization. - bli_membrk_init_mutex( membrk ); #ifdef BLIS_ENABLE_PBA_POOLS - bli_membrk_init_pools( cntx, membrk ); + bli_pba_init_pools( cntx, pba ); #endif } -void bli_membrk_finalize +void bli_pba_finalize ( void ) { - membrk_t* restrict membrk = bli_membrk_query(); - - bli_membrk_set_malloc_fp( NULL, membrk ); - bli_membrk_set_free_fp( NULL, membrk ); + pba_t* restrict pba = bli_pba_query(); #ifdef BLIS_ENABLE_PBA_POOLS - bli_membrk_finalize_pools( membrk ); + bli_pba_finalize_pools( pba ); #endif - bli_membrk_finalize_mutex( membrk ); + + // The mutex field of pba is initialized statically above, and + // therefore never destroyed. + + bli_pba_set_malloc_fp( NULL, pba ); + bli_pba_set_free_fp( NULL, pba ); } -void bli_membrk_acquire_m +void bli_pba_acquire_m ( rntm_t* rntm, siz_t req_size, @@ -94,37 +101,38 @@ void bli_membrk_acquire_m pool_t* pool; pblk_t* pblk; dim_t pi; + err_t r_val; // If the internal memory pools for packing block allocator are disabled, // we spoof the buffer type as BLIS_BUFFER_FOR_GEN_USE to induce the - // immediate usage of bli_membrk_malloc(). + // immediate usage of bli_pba_malloc(). #ifndef BLIS_ENABLE_PBA_POOLS buf_type = BLIS_BUFFER_FOR_GEN_USE; #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_membrk_acquire_m(): bli_fmalloc_align(): size %ld\n", + printf( "bli_pba_acquire_m(): bli_fmalloc_align(): size %ld\n", ( long )req_size ); #endif #endif // Query the memory broker from the runtime. - membrk_t* membrk = bli_rntm_membrk( rntm ); + pba_t* pba = bli_rntm_pba( rntm ); if ( buf_type == BLIS_BUFFER_FOR_GEN_USE ) { - malloc_ft malloc_fp = bli_membrk_malloc_fp( membrk ); - siz_t align_size = bli_membrk_align_size( membrk ); + malloc_ft malloc_fp = bli_pba_malloc_fp( pba ); + siz_t align_size = bli_pba_align_size( pba ); // For general-use buffer requests, dynamically allocating memory // is assumed to be sufficient. - void* buf = bli_fmalloc_align( malloc_fp, req_size, align_size ); + void* buf = bli_fmalloc_align( malloc_fp, req_size, align_size, &r_val ); // Initialize the mem_t object with: // - the address of the memory block, // - the buffer type (a packbuf_t value), // - the size of the requested region, - // - the membrk_t from which the mem_t entry was acquired. + // - the pba_t from which the mem_t entry was acquired. // NOTE: We initialize the pool field to NULL since this block did not // come from a memory pool. bli_mem_set_buffer( buf, mem ); @@ -141,13 +149,13 @@ void bli_membrk_acquire_m // Map the requested packed buffer type to a zero-based index, which // we then use to select the corresponding memory pool. pi = bli_packbuf_index( buf_type ); - pool = bli_membrk_pool( pi, membrk ); + pool = bli_pba_pool( pi, pba ); // Extract the address of the pblk_t struct within the mem_t. pblk = bli_mem_pblk( mem ); - // Acquire the mutex associated with the membrk object. - bli_membrk_lock( membrk ); + // Acquire the mutex associated with the pba object. + bli_pba_lock( pba ); // BEGIN CRITICAL SECTION { @@ -165,8 +173,8 @@ void bli_membrk_acquire_m } // END CRITICAL SECTION - // Release the mutex associated with the membrk object. - bli_membrk_unlock( membrk ); + // Release the mutex associated with the pba object. + bli_pba_unlock( pba ); // Query the block_size from the pblk_t. This will be at least // req_size, perhaps larger. @@ -177,7 +185,7 @@ void bli_membrk_acquire_m // - the address of the memory pool to which it belongs, // - the size of the contiguous memory block (NOT the size of the // requested region), - // - the membrk_t from which the mem_t entry was acquired. + // - the pba_t from which the mem_t entry was acquired. // The actual (aligned) address is already stored in the mem_t // struct's pblk_t field. bli_mem_set_buf_type( buf_type, mem ); @@ -187,7 +195,7 @@ void bli_membrk_acquire_m } -void bli_membrk_release +void bli_pba_release ( rntm_t* rntm, mem_t* mem @@ -198,21 +206,21 @@ void bli_membrk_release pblk_t* pblk; // Query the memory broker from the runtime. - membrk_t* membrk = bli_rntm_membrk( rntm ); + pba_t* pba = bli_rntm_pba( rntm ); // Extract the buffer type so we know what kind of memory was allocated. buf_type = bli_mem_buf_type( mem ); #ifndef BLIS_ENABLE_PBA_POOLS #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_membrk_release(): bli_ffree_align(): size %ld\n", + printf( "bli_pba_release(): bli_ffree_align(): size %ld\n", ( long )bli_mem_size( mem ) ); #endif #endif if ( buf_type == BLIS_BUFFER_FOR_GEN_USE ) { - free_ft free_fp = bli_membrk_free_fp( membrk ); + free_ft free_fp = bli_pba_free_fp( pba ); void* buf = bli_mem_buffer( mem ); // For general-use buffers, we dynamically allocate memory, and so @@ -228,8 +236,8 @@ void bli_membrk_release // Extract the address of the pblk_t struct within the mem_t struct. pblk = bli_mem_pblk( mem ); - // Acquire the mutex associated with the membrk object. - bli_membrk_lock( membrk ); + // Acquire the mutex associated with the pba object. + bli_pba_lock( pba ); // BEGIN CRITICAL SECTION { @@ -240,15 +248,15 @@ void bli_membrk_release } // END CRITICAL SECTION - // Release the mutex associated with the membrk object. - bli_membrk_unlock( membrk ); + // Release the mutex associated with the pba object. + bli_pba_unlock( pba ); } // Clear the mem_t object so that it appears unallocated. This clears: // - the pblk_t struct's fields (ie: the buffer addresses) // - the pool field // - the size field - // - the membrk field + // - the pba field // NOTE: We do not clear the buf_type field since there is no // "uninitialized" value for packbuf_t. bli_mem_clear( mem ); @@ -256,35 +264,27 @@ void bli_membrk_release #if 0 -void bli_membrk_acquire_v +void bli_pba_acquire_v ( - membrk_t* membrk, - siz_t req_size, - mem_t* mem + pba_t* pba, + siz_t req_size, + mem_t* mem ) { - bli_membrk_acquire_m( membrk, - req_size, - BLIS_BUFFER_FOR_GEN_USE, - mem ); + bli_pba_acquire_m + ( + pba, + req_size, + BLIS_BUFFER_FOR_GEN_USE, + mem + ); } #endif -void bli_membrk_rntm_set_membrk - ( - rntm_t* rntm - ) -{ - membrk_t* membrk = bli_membrk_query(); - - bli_rntm_set_membrk( membrk, rntm ); -} - - -siz_t bli_membrk_pool_size +siz_t bli_pba_pool_size ( - membrk_t* membrk, + pba_t* pba, packbuf_t buf_type ) { @@ -304,7 +304,7 @@ siz_t bli_membrk_pool_size // Acquire the pointer to the pool corresponding to the buf_type // provided. pool_index = bli_packbuf_index( buf_type ); - pool = bli_membrk_pool( pool_index, membrk ); + pool = bli_pba_pool( pool_index, pba ); // Compute the pool "size" as the product of the block size // and the number of blocks in the pool. @@ -317,10 +317,10 @@ siz_t bli_membrk_pool_size // ----------------------------------------------------------------------------- -void bli_membrk_init_pools +void bli_pba_init_pools ( - cntx_t* cntx, - membrk_t* membrk + cntx_t* cntx, + pba_t* pba ) { // Map each of the packbuf_t values to an index starting at zero. @@ -329,9 +329,9 @@ void bli_membrk_init_pools const dim_t index_c = bli_packbuf_index( BLIS_BUFFER_FOR_C_PANEL ); // Alias the pool addresses to convenient identifiers. - pool_t* pool_a = bli_membrk_pool( index_a, membrk ); - pool_t* pool_b = bli_membrk_pool( index_b, membrk ); - pool_t* pool_c = bli_membrk_pool( index_c, membrk ); + pool_t* pool_a = bli_pba_pool( index_a, pba ); + pool_t* pool_b = bli_pba_pool( index_b, pba ); + pool_t* pool_c = bli_pba_pool( index_c, pba ); // Start with empty pools. const dim_t num_blocks_a = 0; @@ -348,31 +348,38 @@ void bli_membrk_init_pools const dim_t block_ptrs_len_b = 80; const dim_t block_ptrs_len_c = 0; - // Use the address alignment size designated (at configure-time) for pools. - const siz_t align_size = BLIS_POOL_ADDR_ALIGN_SIZE; + // Use the address alignment sizes designated (at configure-time) for pools. + const siz_t align_size_a = BLIS_POOL_ADDR_ALIGN_SIZE_A; + const siz_t align_size_b = BLIS_POOL_ADDR_ALIGN_SIZE_B; + const siz_t align_size_c = BLIS_POOL_ADDR_ALIGN_SIZE_C; + + // Use the offsets from the above alignments. + const siz_t offset_size_a = BLIS_POOL_ADDR_OFFSET_SIZE_A; + const siz_t offset_size_b = BLIS_POOL_ADDR_OFFSET_SIZE_B; + const siz_t offset_size_c = BLIS_POOL_ADDR_OFFSET_SIZE_C; // Use the malloc() and free() designated (at configure-time) for pools. malloc_ft malloc_fp = BLIS_MALLOC_POOL; free_ft free_fp = BLIS_FREE_POOL; // Determine the block size for each memory pool. - bli_membrk_compute_pool_block_sizes( &block_size_a, - &block_size_b, - &block_size_c, - cntx ); + bli_pba_compute_pool_block_sizes( &block_size_a, + &block_size_b, + &block_size_c, + cntx ); // Initialize the memory pools for A, B, and C. - bli_pool_init( num_blocks_a, block_ptrs_len_a, block_size_a, align_size, - malloc_fp, free_fp, pool_a ); - bli_pool_init( num_blocks_b, block_ptrs_len_b, block_size_b, align_size, - malloc_fp, free_fp, pool_b ); - bli_pool_init( num_blocks_c, block_ptrs_len_c, block_size_c, align_size, - malloc_fp, free_fp, pool_c ); + bli_pool_init( num_blocks_a, block_ptrs_len_a, block_size_a, align_size_a, + offset_size_a, malloc_fp, free_fp, pool_a ); + bli_pool_init( num_blocks_b, block_ptrs_len_b, block_size_b, align_size_b, + offset_size_b, malloc_fp, free_fp, pool_b ); + bli_pool_init( num_blocks_c, block_ptrs_len_c, block_size_c, align_size_c, + offset_size_c, malloc_fp, free_fp, pool_c ); } -void bli_membrk_finalize_pools +void bli_pba_finalize_pools ( - membrk_t* membrk + pba_t* pba ) { // Map each of the packbuf_t values to an index starting at zero. @@ -381,9 +388,9 @@ void bli_membrk_finalize_pools dim_t index_c = bli_packbuf_index( BLIS_BUFFER_FOR_C_PANEL ); // Alias the pool addresses to convenient identifiers. - pool_t* pool_a = bli_membrk_pool( index_a, membrk ); - pool_t* pool_b = bli_membrk_pool( index_b, membrk ); - pool_t* pool_c = bli_membrk_pool( index_c, membrk ); + pool_t* pool_a = bli_pba_pool( index_a, pba ); + pool_t* pool_b = bli_pba_pool( index_b, pba ); + pool_t* pool_c = bli_pba_pool( index_c, pba ); // Finalize the memory pools for A, B, and C. bli_pool_finalize( pool_a ); @@ -393,7 +400,7 @@ void bli_membrk_finalize_pools // ----------------------------------------------------------------------------- -void bli_membrk_compute_pool_block_sizes +void bli_pba_compute_pool_block_sizes ( siz_t* bs_a, siz_t* bs_b, @@ -421,11 +428,11 @@ void bli_membrk_compute_pool_block_sizes // Avoid considering induced methods for real datatypes. if ( bli_is_real( dt ) && im != BLIS_NAT ) continue; - bli_membrk_compute_pool_block_sizes_dt( dt, - &bs_dt_a, - &bs_dt_b, - &bs_dt_c, - cntx ); + bli_pba_compute_pool_block_sizes_dt( dt, + &bs_dt_a, + &bs_dt_b, + &bs_dt_c, + cntx ); bs_cand_a = bli_max( bs_dt_a, bs_cand_a ); bs_cand_b = bli_max( bs_dt_b, bs_cand_b ); @@ -440,7 +447,7 @@ void bli_membrk_compute_pool_block_sizes // ----------------------------------------------------------------------------- -void bli_membrk_compute_pool_block_sizes_dt +void bli_pba_compute_pool_block_sizes_dt ( num_t dt, siz_t* bs_a, diff --git a/frame/base/bli_membrk.h b/frame/base/bli_pba.h similarity index 57% rename from frame/base/bli_membrk.h rename to frame/base/bli_pba.h index 4d00eae63d..6431607ec9 100644 --- a/frame/base/bli_membrk.h +++ b/frame/base/bli_pba.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,83 +37,100 @@ #ifndef BLIS_MEMBRK_H #define BLIS_MEMBRK_H -// membrk init +// Packing block allocator (formerly memory broker) -static void bli_membrk_init_mutex( membrk_t* membrk ) +/* +typedef struct pba_s { - bli_pthread_mutex_init( &(membrk->mutex), NULL ); -} + pool_t pools[3]; + bli_pthread_mutex_t mutex; + + // These fields are used for general-purpose allocation. + siz_t align_size; + malloc_ft malloc_fp; + free_ft free_fp; + +} pba_t; +*/ -static void bli_membrk_finalize_mutex( membrk_t* membrk ) -{ - bli_pthread_mutex_destroy( &(membrk->mutex) ); -} -// membrk query +// pba init -static pool_t* bli_membrk_pool( dim_t pool_index, membrk_t* membrk ) +//BLIS_INLINE void bli_pba_init_mutex( pba_t* pba ) +//{ +// bli_pthread_mutex_init( &(pba->mutex), NULL ); +//} + +//BLIS_INLINE void bli_pba_finalize_mutex( pba_t* pba ) +//{ +// bli_pthread_mutex_destroy( &(pba->mutex) ); +//} + +// pba query + +BLIS_INLINE pool_t* bli_pba_pool( dim_t pool_index, pba_t* pba ) { - return &(membrk->pools[ pool_index ]); + return &(pba->pools[ pool_index ]); } -static siz_t bli_membrk_align_size( membrk_t* membrk ) +BLIS_INLINE siz_t bli_pba_align_size( pba_t* pba ) { - return membrk->align_size; + return pba->align_size; } -static malloc_ft bli_membrk_malloc_fp( membrk_t* membrk ) +BLIS_INLINE malloc_ft bli_pba_malloc_fp( pba_t* pba ) { - return membrk->malloc_fp; + return pba->malloc_fp; } -static free_ft bli_membrk_free_fp( membrk_t* membrk ) +BLIS_INLINE free_ft bli_pba_free_fp( pba_t* pba ) { - return membrk->free_fp; + return pba->free_fp; } -// membrk modification +// pba modification -static void bli_membrk_set_align_size( siz_t align_size, membrk_t* membrk ) +BLIS_INLINE void bli_pba_set_align_size( siz_t align_size, pba_t* pba ) { - membrk->align_size = align_size; + pba->align_size = align_size; } -static void bli_membrk_set_malloc_fp( malloc_ft malloc_fp, membrk_t* membrk ) +BLIS_INLINE void bli_pba_set_malloc_fp( malloc_ft malloc_fp, pba_t* pba ) { - membrk->malloc_fp = malloc_fp; + pba->malloc_fp = malloc_fp; } -static void bli_membrk_set_free_fp( free_ft free_fp, membrk_t* membrk ) +BLIS_INLINE void bli_pba_set_free_fp( free_ft free_fp, pba_t* pba ) { - membrk->free_fp = free_fp; + pba->free_fp = free_fp; } -// membrk action +// pba action -static void bli_membrk_lock( membrk_t* membrk ) +BLIS_INLINE void bli_pba_lock( pba_t* pba ) { - bli_pthread_mutex_lock( &(membrk->mutex) ); + bli_pthread_mutex_lock( &(pba->mutex) ); } -static void bli_membrk_unlock( membrk_t* membrk ) +BLIS_INLINE void bli_pba_unlock( pba_t* pba ) { - bli_pthread_mutex_unlock( &(membrk->mutex) ); + bli_pthread_mutex_unlock( &(pba->mutex) ); } // ----------------------------------------------------------------------------- -membrk_t* bli_membrk_query( void ); +BLIS_EXPORT_BLIS pba_t* bli_pba_query( void ); -void bli_membrk_init +void bli_pba_init ( cntx_t* cntx ); -void bli_membrk_finalize +void bli_pba_finalize ( void ); -void bli_membrk_acquire_m +void bli_pba_acquire_m ( rntm_t* rntm, siz_t req_size, @@ -121,43 +138,48 @@ void bli_membrk_acquire_m mem_t* mem ); -void bli_membrk_release +void bli_pba_release ( rntm_t* rntm, mem_t* mem ); -void bli_membrk_rntm_set_membrk +BLIS_INLINE void bli_pba_rntm_set_pba ( rntm_t* rntm - ); + ) +{ + pba_t* pba = bli_pba_query(); + + bli_rntm_set_pba( pba, rntm ); +} -siz_t bli_membrk_pool_size +siz_t bli_pba_pool_size ( - membrk_t* membrk, + pba_t* pba, packbuf_t buf_type ); // ---------------------------------------------------------------------------- -void bli_membrk_init_pools +void bli_pba_init_pools ( - cntx_t* cntx, - membrk_t* membrk + cntx_t* cntx, + pba_t* pba ); -void bli_membrk_finalize_pools +void bli_pba_finalize_pools ( - membrk_t* membrk + pba_t* pba ); -void bli_membrk_compute_pool_block_sizes +void bli_pba_compute_pool_block_sizes ( siz_t* bs_a, siz_t* bs_b, siz_t* bs_c, cntx_t* cntx ); -void bli_membrk_compute_pool_block_sizes_dt +void bli_pba_compute_pool_block_sizes_dt ( num_t dt, siz_t* bs_a, diff --git a/frame/base/bli_pool.c b/frame/base/bli_pool.c index 1821e13268..112ab68e80 100644 --- a/frame/base/bli_pool.c +++ b/frame/base/bli_pool.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -43,14 +43,22 @@ void bli_pool_init siz_t block_ptrs_len, siz_t block_size, siz_t align_size, + siz_t offset_size, malloc_ft malloc_fp, free_ft free_fp, pool_t* restrict pool ) { + err_t r_val; + // Make sure that block_ptrs_len is at least num_blocks. block_ptrs_len = bli_max( block_ptrs_len, num_blocks ); + // Handle the case where block_ptrs_len is zero, we explicitly set it to 1, + // to avoid any malloc() with zero size, whose behavior is not fixed, and + // also to prevent from falling into any further memory corruption bug. + block_ptrs_len = ( block_ptrs_len == 0 ) ? 1 : block_ptrs_len; + #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_pool_init(): allocating block_ptrs (length %d): ", ( int )block_ptrs_len ); @@ -61,14 +69,14 @@ void bli_pool_init // well as pool blocks? If so, don't forget to s/bli_free_intl/free_fp/g. pblk_t* restrict block_ptrs = - bli_malloc_intl( block_ptrs_len * sizeof( pblk_t ) ); + bli_malloc_intl( block_ptrs_len * sizeof( pblk_t ), &r_val ); // Allocate and initialize each entry in the block_ptrs array. for ( dim_t i = 0; i < num_blocks; ++i ) { #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_pool_init(): allocating block %d of size %d (align %d).\n", - ( int )i, ( int )block_size, ( int )align_size ); + printf( "bli_pool_init(): allocating block %d of size %d (align %d, offset %d).\n", + ( int )i, ( int )block_size, ( int )align_size, ( int )offset_size ); fflush( stdout ); #endif @@ -76,6 +84,7 @@ void bli_pool_init ( block_size, align_size, + offset_size, malloc_fp, &(block_ptrs[i]) ); @@ -99,6 +108,7 @@ void bli_pool_init bli_pool_set_num_blocks( num_blocks, pool ); bli_pool_set_block_size( block_size, pool ); bli_pool_set_align_size( align_size, pool ); + bli_pool_set_offset_size( offset_size, pool ); bli_pool_set_malloc_fp( malloc_fp, pool ); bli_pool_set_free_fp( free_fp, pool ); } @@ -119,6 +129,12 @@ void bli_pool_finalize // Query the total number of blocks currently allocated. const siz_t num_blocks = bli_pool_num_blocks( pool ); + // NOTE: This sanity check has been disabled because bli_pool_reinit() + // is currently implemented in terms of bli_pool_finalize() followed by + // bli_pool_init(). If that _reinit() takes place when some blocks are + // checked out, then we would expect top_index != 0, and therefore this + // check is not universally appropriate. +#if 0 // Query the top_index of the pool. const siz_t top_index = bli_pool_top_index( pool ); @@ -130,17 +146,22 @@ void bli_pool_finalize printf( "bli_pool_finalize(): Implication: not all blocks were checked back in!\n" ); bli_abort(); } +#endif // Query the free() function pointer for the pool. free_ft free_fp = bli_pool_free_fp( pool ); #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_pool_finalize(): freeing %d blocks of size %d (align %d).\n", + printf( "bli_pool_finalize(): freeing %d blocks of size %d (align %d, offset %d).\n", ( int )num_blocks, ( int )bli_pool_block_size( pool ), - ( int )bli_pool_align_size( pool ) ); + ( int )bli_pool_align_size( pool ), + ( int )bli_pool_offset_size( pool ) ); fflush( stdout ); #endif + // Query the offset size of the pool. + const siz_t offset_size = bli_pool_offset_size( pool ); + // Free the individual blocks currently in the pool. for ( dim_t i = 0; i < num_blocks; ++i ) { @@ -148,7 +169,7 @@ void bli_pool_finalize printf( "bli_pool_finalize(): block %d: ", ( int )i ); #endif - bli_pool_free_block( free_fp, &(block_ptrs[i]) ); + bli_pool_free_block( offset_size, free_fp, &(block_ptrs[i]) ); } #ifdef BLIS_ENABLE_MEM_TRACING @@ -169,6 +190,7 @@ void bli_pool_finalize bli_pool_set_top_index( 0, pool ); bli_pool_set_block_size( 0, pool ); bli_pool_set_align_size( 0, pool ); + bli_pool_set_offset_size( 0, pool ); #endif } @@ -178,6 +200,7 @@ void bli_pool_reinit siz_t block_ptrs_len_new, siz_t block_size_new, siz_t align_size_new, + siz_t offset_size_new, pool_t* restrict pool ) { @@ -202,6 +225,7 @@ void bli_pool_reinit block_ptrs_len_new, block_size_new, align_size_new, + offset_size_new, malloc_fp, free_fp, pool @@ -223,6 +247,7 @@ void bli_pool_checkout_block const siz_t num_blocks_new = bli_pool_num_blocks( pool ); const siz_t block_ptrs_len_new = bli_pool_block_ptrs_len( pool ); const siz_t align_size_new = bli_pool_align_size( pool ); + const siz_t offset_size_new = bli_pool_offset_size( pool ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_pool_checkout_block(): old block size %d < req size %d; " @@ -237,6 +262,7 @@ void bli_pool_checkout_block block_ptrs_len_new, req_size, align_size_new, + offset_size_new, pool ); } @@ -293,10 +319,13 @@ void bli_pool_checkin_block // has since been reinitialized to a different (larger) block size. if ( bli_pblk_block_size( block ) != bli_pool_block_size( pool ) ) { + // Query the offset size of the pool. + const siz_t offset_size = bli_pool_offset_size( pool ); + // Query the free() function pointer for the pool. free_ft free_fp = bli_pool_free_fp( pool ); - bli_pool_free_block( free_fp, block ); + bli_pool_free_block( offset_size, free_fp, block ); return; } @@ -308,9 +337,10 @@ void bli_pool_checkin_block #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_pool_checkin_block(): checking in block %d of size %d " - "(align %d).\n", + "(align %d, offset %d).\n", ( int )top_index - 1, ( int )bli_pool_block_size( pool ), - ( int )bli_pool_align_size( pool ) ); + ( int )bli_pool_align_size( pool ), + ( int )bli_pool_offset_size( pool ) ); fflush( stdout ); #endif @@ -327,6 +357,8 @@ void bli_pool_grow pool_t* restrict pool ) { + err_t r_val; + // If the requested increase is zero, return early. if ( num_blocks_add == 0 ) return; @@ -346,7 +378,15 @@ void bli_pool_grow { // To prevent this from happening often, we double the current // length of the block_ptrs array. - const siz_t block_ptrs_len_new = 2 * block_ptrs_len_cur; + // Sanity: make sure that the block_ptrs_len_new will be at least + // num_blocks_new, in case doubling the block_ptrs_len_cur is not enough. + // Example 1: + // - block_ptrs_len_cur == num_blocks_cur == 0 and num_blocks_add = 1 + // - So doubling: 2 * block_ptrs_len_cur = 0, whereas 1 is expected + // Example 2: + // - block_ptrs_len_cur == num_blocks_cur == 10 and num_blocks_add = 30 + // - So doubling: 2 * block_ptrs_len_cur = 20, whereas 40 is expected + const siz_t block_ptrs_len_new = bli_max( (2 * block_ptrs_len_cur), num_blocks_new ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_pool_grow(): growing block_ptrs_len (%d -> %d): ", @@ -361,7 +401,7 @@ void bli_pool_grow // well as pool blocks? If so, don't forget to s/bli_free_intl/free_fp/g. pblk_t* restrict block_ptrs_new = - bli_malloc_intl( block_ptrs_len_new * sizeof( pblk_t ) ); + bli_malloc_intl( block_ptrs_len_new * sizeof( pblk_t ), &r_val ); // Query the top_index of the pool. const siz_t top_index = bli_pool_top_index( pool ); @@ -396,8 +436,9 @@ void bli_pool_grow pblk_t* restrict block_ptrs = bli_pool_block_ptrs( pool ); // Query the block size and alignment size of the pool. - const siz_t block_size = bli_pool_block_size( pool ); - const siz_t align_size = bli_pool_align_size( pool ); + const siz_t block_size = bli_pool_block_size( pool ); + const siz_t align_size = bli_pool_align_size( pool ); + const siz_t offset_size = bli_pool_offset_size( pool ); // Query the malloc() function pointer for the pool. malloc_ft malloc_fp = bli_pool_malloc_fp( pool ); @@ -415,6 +456,7 @@ void bli_pool_grow ( block_size, align_size, + offset_size, malloc_fp, &(block_ptrs[i]) ); @@ -456,13 +498,16 @@ void bli_pool_shrink // Compute the new total number of blocks. const siz_t num_blocks_new = num_blocks - num_blocks_sub; + // Query the offset size of the pool. + const siz_t offset_size = bli_pool_offset_size( pool ); + // Query the free() function pointer for the pool. free_ft free_fp = bli_pool_free_fp( pool ); // Free the individual blocks. for ( dim_t i = num_blocks_new; i < num_blocks; ++i ) { - bli_pool_free_block( free_fp, &(block_ptrs[i]) ); + bli_pool_free_block( offset_size, free_fp, &(block_ptrs[i]) ); } // Update the pool_t struct. @@ -477,22 +522,27 @@ void bli_pool_alloc_block ( siz_t block_size, siz_t align_size, + siz_t offset_size, malloc_ft malloc_fp, pblk_t* restrict block ) { + err_t r_val; + #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_pool_alloc_block(): calling fmalloc_align(): size %d (align %d)\n", - ( int )block_size, ( int )align_size ); + printf( "bli_pool_alloc_block(): calling fmalloc_align(): size %d (align %d, offset %d)\n", + ( int )block_size, ( int )align_size, ( int )offset_size ); fflush( stdout ); #endif // Allocate the block via the bli_fmalloc_align() wrapper, which performs // alignment logic and opaquely saves the original pointer so that it can - // be recovered when it's time to free the block. + // be recovered when it's time to free the block. Note that we have to + // add offset_size to the number of bytes requested since we will skip + // that many bytes at the beginning of the allocated memory. void* restrict buf = - bli_fmalloc_align( malloc_fp, block_size, align_size ); + bli_fmalloc_align( malloc_fp, block_size + offset_size, align_size, &r_val ); #if 0 // NOTE: This code is disabled because it is not needed, since @@ -517,6 +567,9 @@ void bli_pool_alloc_block } #endif + // Advance the pointer by offset_size bytes. + buf = ( void* )( ( char* )buf + offset_size ); + // Save the results in the pblk_t structure. bli_pblk_set_buf( buf, block ); bli_pblk_set_block_size( block_size, block ); @@ -524,6 +577,7 @@ void bli_pool_alloc_block void bli_pool_free_block ( + siz_t offset_size, free_ft free_fp, pblk_t* restrict block ) @@ -538,6 +592,10 @@ void bli_pool_free_block // bli_fmalloc_align() when the block was allocated. void* restrict buf = bli_pblk_buf( block ); + // Undo the pointer advancement by offset_size bytes performed previously + // by bli_pool_alloc_block(). + buf = ( void* )( ( char* )buf - offset_size ); + // Free the block via the bli_ffree_align() wrapper, which recovers the // original pointer that was returned by the pool's malloc() function when // the block was allocated. @@ -555,7 +613,7 @@ void bli_pool_print siz_t num_blocks = bli_pool_num_blocks( pool ); siz_t block_size = bli_pool_block_size( pool ); siz_t align_size = bli_pool_align_size( pool ); - dim_t i; + siz_t offset_size = bli_pool_offset_size( pool ); printf( "pool struct ---------------\n" ); printf( " block_ptrs: %p\n", block_ptrs ); @@ -564,8 +622,10 @@ void bli_pool_print printf( " num_blocks: %d\n", ( int )num_blocks ); printf( " block_size: %d\n", ( int )block_size ); printf( " align_size: %d\n", ( int )align_size ); + printf( " offset_size: %d\n", ( int )offset_size ); printf( " pblks sys align\n" ); - for ( i = 0; i < num_blocks; ++i ) + + for ( dim_t i = 0; i < num_blocks; ++i ) { printf( " %d: %p\n", ( int )i, bli_pblk_buf( &block_ptrs[i] ) ); } diff --git a/frame/base/bli_pool.h b/frame/base/bli_pool.h index 0d39fd7d3d..b4bb23feca 100644 --- a/frame/base/bli_pool.h +++ b/frame/base/bli_pool.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -70,29 +70,43 @@ typedef struct // Pool block query -static void* bli_pblk_buf( pblk_t* pblk ) +BLIS_INLINE void* bli_pblk_buf( pblk_t* pblk ) { return pblk->buf; } -static siz_t bli_pblk_block_size( pblk_t* pblk ) +BLIS_INLINE siz_t bli_pblk_block_size( pblk_t* pblk ) { return pblk->block_size; } // Pool block modification -static void bli_pblk_set_buf( void* buf, pblk_t* pblk ) +BLIS_INLINE void bli_pblk_set_buf( void* buf, pblk_t* pblk ) { pblk->buf = buf; } -static void bli_pblk_set_block_size( siz_t block_size, pblk_t* pblk ) +BLIS_INLINE void bli_pblk_set_block_size( siz_t block_size, pblk_t* pblk ) { pblk->block_size = block_size; } -static void bli_pblk_clear( pblk_t* pblk ) +// +// -- pool block initialization ------------------------------------------------ +// + +// NOTE: This initializer macro must be updated whenever fields are added or +// removed from the pblk_t type definition. An alternative to the initializer is +// calling bli_pblk_clear() at runtime. + +#define BLIS_PBLK_INITIALIZER \ + { \ + .buf = NULL, \ + .block_size = 0, \ + } \ + +BLIS_INLINE void bli_pblk_clear( pblk_t* pblk ) { bli_pblk_set_buf( NULL, pblk ); bli_pblk_set_block_size( 0, pblk ); @@ -101,90 +115,100 @@ static void bli_pblk_clear( pblk_t* pblk ) // Pool entry query -static void* bli_pool_block_ptrs( pool_t* pool ) +BLIS_INLINE void* bli_pool_block_ptrs( pool_t* pool ) { return pool->block_ptrs; } -static siz_t bli_pool_block_ptrs_len( pool_t* pool ) +BLIS_INLINE siz_t bli_pool_block_ptrs_len( pool_t* pool ) { return pool->block_ptrs_len; } -static siz_t bli_pool_num_blocks( pool_t* pool ) +BLIS_INLINE siz_t bli_pool_num_blocks( pool_t* pool ) { return pool->num_blocks; } -static siz_t bli_pool_block_size( pool_t* pool ) +BLIS_INLINE siz_t bli_pool_block_size( pool_t* pool ) { return pool->block_size; } -static siz_t bli_pool_align_size( pool_t* pool ) +BLIS_INLINE siz_t bli_pool_align_size( pool_t* pool ) { return pool->align_size; } -static malloc_ft bli_pool_malloc_fp( pool_t* pool ) +BLIS_INLINE siz_t bli_pool_offset_size( pool_t* pool ) +{ + return pool->offset_size; +} + +BLIS_INLINE malloc_ft bli_pool_malloc_fp( pool_t* pool ) { return pool->malloc_fp; } -static free_ft bli_pool_free_fp( pool_t* pool ) +BLIS_INLINE free_ft bli_pool_free_fp( pool_t* pool ) { return pool->free_fp; } -static siz_t bli_pool_top_index( pool_t* pool ) +BLIS_INLINE siz_t bli_pool_top_index( pool_t* pool ) { return pool->top_index; } -static bool_t bli_pool_is_exhausted( pool_t* pool ) +BLIS_INLINE bool bli_pool_is_exhausted( pool_t* pool ) { - return ( bool_t ) + return ( bool ) ( bli_pool_top_index( pool ) == bli_pool_num_blocks( pool ) ); } // Pool entry modification -static void bli_pool_set_block_ptrs( void* block_ptrs, pool_t* pool ) \ +BLIS_INLINE void bli_pool_set_block_ptrs( void* block_ptrs, pool_t* pool ) \ { pool->block_ptrs = block_ptrs; } -static void bli_pool_set_block_ptrs_len( siz_t block_ptrs_len, pool_t* pool ) \ +BLIS_INLINE void bli_pool_set_block_ptrs_len( siz_t block_ptrs_len, pool_t* pool ) \ { pool->block_ptrs_len = block_ptrs_len; } -static void bli_pool_set_num_blocks( siz_t num_blocks, pool_t* pool ) \ +BLIS_INLINE void bli_pool_set_num_blocks( siz_t num_blocks, pool_t* pool ) \ { pool->num_blocks = num_blocks; } -static void bli_pool_set_block_size( siz_t block_size, pool_t* pool ) \ +BLIS_INLINE void bli_pool_set_block_size( siz_t block_size, pool_t* pool ) \ { pool->block_size = block_size; } -static void bli_pool_set_align_size( siz_t align_size, pool_t* pool ) \ +BLIS_INLINE void bli_pool_set_align_size( siz_t align_size, pool_t* pool ) \ { pool->align_size = align_size; } -static void bli_pool_set_malloc_fp( malloc_ft malloc_fp, pool_t* pool ) \ +BLIS_INLINE void bli_pool_set_offset_size( siz_t offset_size, pool_t* pool ) \ +{ + pool->offset_size = offset_size; +} + +BLIS_INLINE void bli_pool_set_malloc_fp( malloc_ft malloc_fp, pool_t* pool ) \ { pool->malloc_fp = malloc_fp; } -static void bli_pool_set_free_fp( free_ft free_fp, pool_t* pool ) \ +BLIS_INLINE void bli_pool_set_free_fp( free_ft free_fp, pool_t* pool ) \ { pool->free_fp = free_fp; } -static void bli_pool_set_top_index( siz_t top_index, pool_t* pool ) \ +BLIS_INLINE void bli_pool_set_top_index( siz_t top_index, pool_t* pool ) \ { pool->top_index = top_index; } @@ -197,6 +221,7 @@ void bli_pool_init siz_t block_ptrs_len, siz_t block_size, siz_t align_size, + siz_t offset_size, malloc_ft malloc_fp, free_ft free_fp, pool_t* restrict pool @@ -211,6 +236,7 @@ void bli_pool_reinit siz_t block_ptrs_len_new, siz_t block_size_new, siz_t align_size_new, + siz_t offset_size_new, pool_t* restrict pool ); @@ -241,11 +267,13 @@ void bli_pool_alloc_block ( siz_t block_size, siz_t align_size, + siz_t offset_size, malloc_ft malloc_fp, pblk_t* restrict block ); void bli_pool_free_block ( + siz_t offset_size, free_ft free_fp, pblk_t* restrict block ); diff --git a/frame/base/bli_prune.c b/frame/base/bli_prune.c index 080f66f264..ebe5c23653 100644 --- a/frame/base/bli_prune.c +++ b/frame/base/bli_prune.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/base/bli_query.c b/frame/base/bli_query.c index a81d1e63c4..c62a30cccd 100644 --- a/frame/base/bli_query.c +++ b/frame/base/bli_query.c @@ -34,13 +34,13 @@ #include "blis.h" -bool_t bli_obj_equals( obj_t* a, - obj_t* b ) +bool bli_obj_equals( obj_t* a, obj_t* b ) { - bool_t r_val = FALSE; - num_t dt_a; - num_t dt_b; - num_t dt; +#if 0 + bool r_val = FALSE; + num_t dt_a; + num_t dt_b; + num_t dt; // The function is not yet implemented for vectors and matrices. if ( !bli_obj_is_1x1( a ) || @@ -81,15 +81,26 @@ bool_t bli_obj_equals( obj_t* a, } return r_val; +#else + bool r_val; + + if ( bli_obj_is_1x1( a ) && bli_obj_is_1x1( b ) ) + bli_eqsc( a, b, &r_val ); + else if ( bli_obj_is_vector( a ) && bli_obj_is_vector( b ) ) + bli_eqv( a, b, &r_val ); + else + bli_eqm( a, b, &r_val ); + + return r_val; +#endif } -bool_t bli_obj_imag_equals( obj_t* a, - obj_t* b ) +bool bli_obj_imag_equals( obj_t* a, obj_t* b ) { #if 0 - bool_t r_val = FALSE; - num_t dt_a; - num_t dt_b; + bool r_val = FALSE; + num_t dt_a; + num_t dt_b; dt_a = bli_obj_dt( a ); dt_b = bli_obj_dt( b ); @@ -130,7 +141,7 @@ bool_t bli_obj_imag_equals( obj_t* a, } } #endif - bool_t r_val = FALSE; + bool r_val = FALSE; // The function is not yet implemented for vectors and matrices. if ( !bli_obj_is_1x1( a ) || @@ -154,9 +165,9 @@ bool_t bli_obj_imag_equals( obj_t* a, return r_val; } -bool_t bli_obj_imag_is_zero( obj_t* a ) +bool bli_obj_imag_is_zero( obj_t* a ) { - bool_t r_val = TRUE; + bool r_val = TRUE; // The function is not yet implemented for vectors and matrices. if ( !bli_obj_is_1x1( a ) ) diff --git a/frame/base/bli_query.h b/frame/base/bli_query.h index 2bb5b3f6b9..65246050b5 100644 --- a/frame/base/bli_query.h +++ b/frame/base/bli_query.h @@ -32,8 +32,8 @@ */ -BLIS_EXPORT_BLIS bool_t bli_obj_equals( obj_t* a, obj_t* b ); +BLIS_EXPORT_BLIS bool bli_obj_equals( obj_t* a, obj_t* b ); -BLIS_EXPORT_BLIS bool_t bli_obj_imag_equals( obj_t* a, obj_t* b ); +BLIS_EXPORT_BLIS bool bli_obj_imag_equals( obj_t* a, obj_t* b ); -BLIS_EXPORT_BLIS bool_t bli_obj_imag_is_zero( obj_t* a ); +BLIS_EXPORT_BLIS bool bli_obj_imag_is_zero( obj_t* a ); diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 34d7413249..a6ded35b32 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -34,6 +34,29 @@ #include "blis.h" +// The global rntm_t structure, which holds the global thread settings +// along with a few other key parameters. +rntm_t global_rntm; + +// A mutex to allow synchronous access to global_rntm. +bli_pthread_mutex_t global_rntm_mutex = BLIS_PTHREAD_MUTEX_INITIALIZER; + +// ---------------------------------------------------------------------------- + +void bli_rntm_init_from_global( rntm_t* rntm ) +{ + // We must ensure that global_rntm has been initialized. + bli_init_once(); + + // Acquire the mutex protecting global_rntm. + bli_pthread_mutex_lock( &global_rntm_mutex ); + + *rntm = global_rntm; + + // Release the mutex protecting global_rntm. + bli_pthread_mutex_unlock( &global_rntm_mutex ); +} + // ----------------------------------------------------------------------------- void bli_rntm_set_ways_for_op @@ -146,14 +169,18 @@ void bli_rntm_set_ways_from_rntm dim_t jr = bli_rntm_jr_ways( rntm ); dim_t ir = bli_rntm_ir_ways( rntm ); + bool auto_factor = FALSE; + #ifdef BLIS_ENABLE_MULTITHREADING - bool_t nt_set = FALSE; - bool_t ways_set = FALSE; + bool nt_set = FALSE; + bool ways_set = FALSE; // If the rntm was fed in as a copy of the global runtime via - // bli_thread_init_rntm(), we know that either the num_threads - // field will be set and all of the ways unset, or vice versa. + // bli_rntm_init_from_global(), we know that either: + // - the num_threads field is -1 and all of the ways are -1; + // - the num_threads field is -1 and all of the ways are set; + // - the num_threads field is set and all of the ways are -1. // However, we can't be sure that a user-provided rntm_t isn't // initialized uncleanly. So here we have to enforce some rules // to get the rntm_t into a predictable state. @@ -161,6 +188,9 @@ void bli_rntm_set_ways_from_rntm // First, we establish whether or not the number of threads is set. if ( nt > 0 ) nt_set = TRUE; + // Take this opportunity to set the auto_factor field. + if ( nt_set ) auto_factor = TRUE; + // Next, we establish whether or not any of the ways of parallelism // for each loop were set. If any of the ways are set (positive), we // then we assume the user wanted to use those positive values and @@ -190,15 +220,28 @@ void bli_rntm_set_ways_from_rntm } else if ( ways_set == FALSE && nt_set == TRUE ) { - // If the ways were not set but the number of threas was set, then + // If the ways were not set but the number of thread was set, then // we attempt to automatically generate a thread factorization that - // will work given the problem size. Thus, here we only set the - // ways and leave the number of threads unchanged. + // will work given the problem size. + +#ifdef BLIS_DISABLE_AUTO_PRIME_NUM_THREADS + // If use of prime numbers is disallowed for automatic thread + // factorizations, we first check if the number of threads requested + // is prime. If it is prime, and it exceeds a minimum threshold, then + // we reduce the number of threads by one so that the number is not + // prime. This will allow for automatic thread factorizations to span + // two dimensions (loops), which tends to be more efficient. + if ( bli_is_prime( nt ) && BLIS_NT_MAX_PRIME < nt ) nt -= 1; +#endif pc = 1; - bli_partition_2x2( nt, m*BLIS_THREAD_RATIO_M, - n*BLIS_THREAD_RATIO_N, &ic, &jc ); + //printf( "m n = %d %d BLIS_THREAD_RATIO_M _N = %d %d\n", (int)m, (int)n, (int)BLIS_THREAD_RATIO_M, (int)BLIS_THREAD_RATIO_N ); + + bli_thread_partition_2x2( nt, m*BLIS_THREAD_RATIO_M, + n*BLIS_THREAD_RATIO_N, &ic, &jc ); + + //printf( "jc ic = %d %d\n", (int)jc, (int)ic ); for ( ir = BLIS_THREAD_MAX_IR ; ir > 1 ; ir-- ) { @@ -230,6 +273,137 @@ void bli_rntm_set_ways_from_rntm #endif // Save the results back in the runtime object. + bli_rntm_set_auto_factor_only( auto_factor, rntm ); + bli_rntm_set_num_threads_only( nt, rntm ); + bli_rntm_set_ways_only( jc, pc, ic, jr, ir, rntm ); +} + +void bli_rntm_set_ways_from_rntm_sup + ( + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm + ) +{ + dim_t nt = bli_rntm_num_threads( rntm ); + + dim_t jc = bli_rntm_jc_ways( rntm ); + dim_t pc = bli_rntm_pc_ways( rntm ); + dim_t ic = bli_rntm_ic_ways( rntm ); + dim_t jr = bli_rntm_jr_ways( rntm ); + dim_t ir = bli_rntm_ir_ways( rntm ); + + bool auto_factor = FALSE; + +#ifdef BLIS_ENABLE_MULTITHREADING + + bool nt_set = FALSE; + bool ways_set = FALSE; + + // If the rntm was fed in as a copy of the global runtime via + // bli_rntm_init_from_global(), we know that either: + // - the num_threads field is -1 and all of the ways are -1; + // - the num_threads field is -1 and all of the ways are set; + // - the num_threads field is set and all of the ways are -1. + // However, we can't be sure that a user-provided rntm_t isn't + // initialized uncleanly. So here we have to enforce some rules + // to get the rntm_t into a predictable state. + + // First, we establish whether or not the number of threads is set. + if ( nt > 0 ) nt_set = TRUE; + + // Take this opportunity to set the auto_factor field. + if ( nt_set ) auto_factor = TRUE; + + // Next, we establish whether or not any of the ways of parallelism + // for each loop were set. If any of the ways are set (positive), we + // then we assume the user wanted to use those positive values and + // default the non-positive values to 1. + if ( jc > 0 || pc > 0 || ic > 0 || jr > 0 || ir > 0 ) + { + ways_set = TRUE; + + if ( jc < 1 ) jc = 1; + if ( pc < 1 ) pc = 1; + if ( ic < 1 ) ic = 1; + if ( jr < 1 ) jr = 1; + if ( ir < 1 ) ir = 1; + } + + // Now we use the values of nt_set and ways_set to determine how to + // interpret the original values we found in the rntm_t object. + + if ( ways_set == TRUE ) + { + // If the ways were set, then we use the values that were given + // and interpreted above (we set any non-positive value to 1). + // The only thing left to do is calculate the correct number of + // threads. + + nt = jc * pc * ic * jr * ir; + } + else if ( ways_set == FALSE && nt_set == TRUE ) + { + // If the ways were not set but the number of thread was set, then + // we attempt to automatically generate a thread factorization that + // will work given the problem size. + +#ifdef BLIS_DISABLE_AUTO_PRIME_NUM_THREADS + // If use of prime numbers is disallowed for automatic thread + // factorizations, we first check if the number of threads requested + // is prime. If it is prime, and it exceeds a minimum threshold, then + // we reduce the number of threads by one so that the number is not + // prime. This will allow for automatic thread factorizations to span + // two dimensions (loops), which tends to be more efficient. + if ( bli_is_prime( nt ) && BLIS_NT_MAX_PRIME < nt ) nt -= 1; +#endif + + pc = 1; + + //bli_thread_partition_2x2( nt, m*BLIS_THREAD_SUP_RATIO_M, + // n*BLIS_THREAD_SUP_RATIO_N, &ic, &jc ); + bli_thread_partition_2x2( nt, m, + n, &ic, &jc ); + +//printf( "bli_rntm_set_ways_from_rntm_sup(): jc = %d ic = %d\n", (int)jc, (int)ic ); +#if 0 + for ( ir = BLIS_THREAD_SUP_MAX_IR ; ir > 1 ; ir-- ) + { + if ( ic % ir == 0 ) { ic /= ir; break; } + } + + for ( jr = BLIS_THREAD_SUP_MAX_JR ; jr > 1 ; jr-- ) + { + if ( jc % jr == 0 ) { jc /= jr; break; } + } +#else + ir = 1; + jr = 1; + +#endif + } + else // if ( ways_set == FALSE && nt_set == FALSE ) + { + // If neither the ways nor the number of threads were set, then + // the rntm was not meaningfully changed since initialization, + // and thus we'll default to single-threaded execution. + + nt = 1; + jc = pc = ic = jr = ir = 1; + } + +#else + + // When multithreading is disabled, always set the rntm_t ways + // values to 1. + nt = 1; + jc = pc = ic = jr = ir = 1; + +#endif + + // Save the results back in the runtime object. + bli_rntm_set_auto_factor_only( auto_factor, rntm ); bli_rntm_set_num_threads_only( nt, rntm ); bli_rntm_set_ways_only( jc, pc, ic, jr, ir, rntm ); } @@ -239,6 +413,8 @@ void bli_rntm_print rntm_t* rntm ) { + dim_t af = bli_rntm_auto_factor( rntm ); + dim_t nt = bli_rntm_num_threads( rntm ); dim_t jc = bli_rntm_jc_ways( rntm ); @@ -248,7 +424,72 @@ void bli_rntm_print dim_t ir = bli_rntm_ir_ways( rntm ); printf( "rntm contents nt jc pc ic jr ir\n" ); - printf( " %4d%4d%4d%4d%4d%4d\n", (int)nt, (int)jc, (int)pc, + printf( "autofac? %1d | %4d%4d%4d%4d%4d%4d\n", (int)af, + (int)nt, (int)jc, (int)pc, (int)ic, (int)jr, (int)ir ); } +// ----------------------------------------------------------------------------- + +dim_t bli_rntm_calc_num_threads_in + ( + bszid_t* restrict bszid_cur, + rntm_t* restrict rntm + ) +{ + /* // bp algorithm: + bszid_t bszids[7] = { BLIS_NC, // level 0: 5th loop + BLIS_KC, // level 1: 4th loop + BLIS_NO_PART, // level 2: pack B + BLIS_MC, // level 3: 3rd loop + BLIS_NO_PART, // level 4: pack A + BLIS_NR, // level 5: 2nd loop + BLIS_MR, // level 6: 1st loop + BLIS_KR // level 7: ukr loop + + ... // pb algorithm: + BLIS_NR, // level 5: 2nd loop + BLIS_MR, // level 6: 1st loop + BLIS_KR // level 7: ukr loop + }; */ + dim_t n_threads_in = 1; + + // Starting with the current element of the bszids array (pointed + // to by bszid_cur), multiply all of the corresponding ways of + // parallelism. + for ( ; *bszid_cur != BLIS_KR; bszid_cur++ ) + { + const bszid_t bszid = *bszid_cur; + + //if ( bszid == BLIS_KR ) break; + + // We assume bszid is in {NC,KC,MC,NR,MR,KR} if it is not + // BLIS_NO_PART. + if ( bszid != BLIS_NO_PART ) + { + const dim_t cur_way = bli_rntm_ways_for( bszid, rntm ); + + n_threads_in *= cur_way; + } + } + + return n_threads_in; +} + +#if 0 + for ( ; *bszid_cur != BLIS_KR; bszid_cur++ ) + { + const bszid_t bszid = *bszid_cur; + dim_t cur_way = 1; + + // We assume bszid is in {NC,KC,MC,NR,MR,KR} if it is not + // BLIS_NO_PART. + if ( bszid != BLIS_NO_PART ) + cur_way = bli_rntm_ways_for( bszid, rntm ); + else + cur_way = 1; + + n_threads_in *= cur_way; + } +#endif + diff --git a/frame/base/bli_rntm.h b/frame/base/bli_rntm.h index f33c25e365..249a698051 100644 --- a/frame/base/bli_rntm.h +++ b/frame/base/bli_rntm.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -43,8 +43,17 @@ /* typedef struct rntm_s { + bool auto_factor; + dim_t num_threads; dim_t* thrloop; + bool pack_a; + bool pack_b; + bool l3_sup; + + pool_t* sba_pool; + pba_t* pba; + } rntm_t; */ @@ -52,109 +61,135 @@ typedef struct rntm_s // -- rntm_t query (public API) ------------------------------------------------ // -static dim_t bli_rntm_num_threads( rntm_t* rntm ) +BLIS_INLINE bool bli_rntm_auto_factor( rntm_t* rntm ) +{ + return rntm->auto_factor; +} + +BLIS_INLINE dim_t bli_rntm_num_threads( rntm_t* rntm ) { return rntm->num_threads; } -static dim_t bli_rntm_ways_for( bszid_t bszid, rntm_t* rntm ) +BLIS_INLINE dim_t bli_rntm_ways_for( bszid_t bszid, rntm_t* rntm ) { return rntm->thrloop[ bszid ]; } -static dim_t bli_rntm_jc_ways( rntm_t* rntm ) +BLIS_INLINE dim_t bli_rntm_jc_ways( rntm_t* rntm ) { return bli_rntm_ways_for( BLIS_NC, rntm ); } -static dim_t bli_rntm_pc_ways( rntm_t* rntm ) +BLIS_INLINE dim_t bli_rntm_pc_ways( rntm_t* rntm ) { return bli_rntm_ways_for( BLIS_KC, rntm ); } -static dim_t bli_rntm_ic_ways( rntm_t* rntm ) +BLIS_INLINE dim_t bli_rntm_ic_ways( rntm_t* rntm ) { return bli_rntm_ways_for( BLIS_MC, rntm ); } -static dim_t bli_rntm_jr_ways( rntm_t* rntm ) +BLIS_INLINE dim_t bli_rntm_jr_ways( rntm_t* rntm ) { return bli_rntm_ways_for( BLIS_NR, rntm ); } -static dim_t bli_rntm_ir_ways( rntm_t* rntm ) +BLIS_INLINE dim_t bli_rntm_ir_ways( rntm_t* rntm ) { return bli_rntm_ways_for( BLIS_MR, rntm ); } -static dim_t bli_rntm_pr_ways( rntm_t* rntm ) +BLIS_INLINE dim_t bli_rntm_pr_ways( rntm_t* rntm ) { return bli_rntm_ways_for( BLIS_KR, rntm ); } +BLIS_INLINE bool bli_rntm_pack_a( rntm_t* rntm ) +{ + return ( bool )( rntm->pack_a ); +} +BLIS_INLINE bool bli_rntm_pack_b( rntm_t* rntm ) +{ + return ( bool )( rntm->pack_b ); +} + +BLIS_INLINE bool bli_rntm_l3_sup( rntm_t* rntm ) +{ + return rntm->l3_sup; +} + // // -- rntm_t query (internal use only) ----------------------------------------- // -static pool_t* bli_rntm_sba_pool( rntm_t* rntm ) +BLIS_INLINE pool_t* bli_rntm_sba_pool( rntm_t* rntm ) { return rntm->sba_pool; } -static membrk_t* bli_rntm_membrk( rntm_t* rntm ) +BLIS_INLINE pba_t* bli_rntm_pba( rntm_t* rntm ) { - return rntm->membrk; + return rntm->pba; } -static dim_t bli_rntm_equals( rntm_t* rntm1, rntm_t* rntm2 ) +#if 0 +BLIS_INLINE dim_t bli_rntm_equals( rntm_t* rntm1, rntm_t* rntm2 ) { - const bool_t nt = bli_rntm_num_threads( rntm1 ) == bli_rntm_num_threads( rntm2 ); - const bool_t jc = bli_rntm_jc_ways( rntm1 ) == bli_rntm_jc_ways( rntm2 ); - const bool_t pc = bli_rntm_pc_ways( rntm1 ) == bli_rntm_pc_ways( rntm2 ); - const bool_t ic = bli_rntm_ic_ways( rntm1 ) == bli_rntm_ic_ways( rntm2 ); - const bool_t jr = bli_rntm_jr_ways( rntm1 ) == bli_rntm_jr_ways( rntm2 ); - const bool_t ir = bli_rntm_ir_ways( rntm1 ) == bli_rntm_ir_ways( rntm2 ); - const bool_t pr = bli_rntm_pr_ways( rntm1 ) == bli_rntm_pr_ways( rntm2 ); + const bool nt = bli_rntm_num_threads( rntm1 ) == bli_rntm_num_threads( rntm2 ); + const bool jc = bli_rntm_jc_ways( rntm1 ) == bli_rntm_jc_ways( rntm2 ); + const bool pc = bli_rntm_pc_ways( rntm1 ) == bli_rntm_pc_ways( rntm2 ); + const bool ic = bli_rntm_ic_ways( rntm1 ) == bli_rntm_ic_ways( rntm2 ); + const bool jr = bli_rntm_jr_ways( rntm1 ) == bli_rntm_jr_ways( rntm2 ); + const bool ir = bli_rntm_ir_ways( rntm1 ) == bli_rntm_ir_ways( rntm2 ); + const bool pr = bli_rntm_pr_ways( rntm1 ) == bli_rntm_pr_ways( rntm2 ); if ( nt && jc && pc && ic && jr && ir && pr ) return TRUE; else return FALSE; } +#endif // // -- rntm_t modification (internal use only) ---------------------------------- // -static void bli_rntm_set_num_threads_only( dim_t nt, rntm_t* rntm ) +BLIS_INLINE void bli_rntm_set_auto_factor_only( bool auto_factor, rntm_t* rntm ) +{ + rntm->auto_factor = auto_factor; +} + +BLIS_INLINE void bli_rntm_set_num_threads_only( dim_t nt, rntm_t* rntm ) { rntm->num_threads = nt; } -static void bli_rntm_set_ways_for_only( bszid_t loop, dim_t n_ways, rntm_t* rntm ) +BLIS_INLINE void bli_rntm_set_ways_for_only( bszid_t loop, dim_t n_ways, rntm_t* rntm ) { rntm->thrloop[ loop ] = n_ways; } -static void bli_rntm_set_jc_ways_only( dim_t ways, rntm_t* rntm ) +BLIS_INLINE void bli_rntm_set_jc_ways_only( dim_t ways, rntm_t* rntm ) { bli_rntm_set_ways_for_only( BLIS_NC, ways, rntm ); } -static void bli_rntm_set_pc_ways_only( dim_t ways, rntm_t* rntm ) +BLIS_INLINE void bli_rntm_set_pc_ways_only( dim_t ways, rntm_t* rntm ) { bli_rntm_set_ways_for_only( BLIS_KC, ways, rntm ); } -static void bli_rntm_set_ic_ways_only( dim_t ways, rntm_t* rntm ) +BLIS_INLINE void bli_rntm_set_ic_ways_only( dim_t ways, rntm_t* rntm ) { bli_rntm_set_ways_for_only( BLIS_MC, ways, rntm ); } -static void bli_rntm_set_jr_ways_only( dim_t ways, rntm_t* rntm ) +BLIS_INLINE void bli_rntm_set_jr_ways_only( dim_t ways, rntm_t* rntm ) { bli_rntm_set_ways_for_only( BLIS_NR, ways, rntm ); } -static void bli_rntm_set_ir_ways_only( dim_t ways, rntm_t* rntm ) +BLIS_INLINE void bli_rntm_set_ir_ways_only( dim_t ways, rntm_t* rntm ) { bli_rntm_set_ways_for_only( BLIS_MR, ways, rntm ); } -static void bli_rntm_set_pr_ways_only( dim_t ways, rntm_t* rntm ) +BLIS_INLINE void bli_rntm_set_pr_ways_only( dim_t ways, rntm_t* rntm ) { bli_rntm_set_ways_for_only( BLIS_KR, ways, rntm ); } -static void bli_rntm_set_ways_only( dim_t jc, dim_t pc, dim_t ic, dim_t jr, dim_t ir, rntm_t* rntm ) +BLIS_INLINE void bli_rntm_set_ways_only( dim_t jc, dim_t pc, dim_t ic, dim_t jr, dim_t ir, rntm_t* rntm ) { // Record the number of ways of parallelism per loop. bli_rntm_set_jc_ways_only( jc, rntm ); @@ -165,34 +200,38 @@ static void bli_rntm_set_ways_only( dim_t jc, dim_t pc, dim_t ic, dim_t jr, dim_ bli_rntm_set_pr_ways_only( 1, rntm ); } -static void bli_rntm_set_sba_pool( pool_t* sba_pool, rntm_t* rntm ) +BLIS_INLINE void bli_rntm_set_sba_pool( pool_t* sba_pool, rntm_t* rntm ) { rntm->sba_pool = sba_pool; } -static void bli_rntm_set_membrk( membrk_t* membrk, rntm_t* rntm ) +BLIS_INLINE void bli_rntm_set_pba( pba_t* pba, rntm_t* rntm ) { - rntm->membrk = membrk; + rntm->pba = pba; } -static void bli_rntm_clear_num_threads_only( rntm_t* rntm ) +BLIS_INLINE void bli_rntm_clear_num_threads_only( rntm_t* rntm ) { bli_rntm_set_num_threads_only( -1, rntm ); } -static void bli_rntm_clear_ways_only( rntm_t* rntm ) +BLIS_INLINE void bli_rntm_clear_ways_only( rntm_t* rntm ) { bli_rntm_set_ways_only( -1, -1, -1, -1, -1, rntm ); } -static void bli_rntm_clear_sba_pool( rntm_t* rntm ) +BLIS_INLINE void bli_rntm_clear_sba_pool( rntm_t* rntm ) { bli_rntm_set_sba_pool( NULL, rntm ); } +BLIS_INLINE void bli_rntm_clear_pba( rntm_t* rntm ) +{ + bli_rntm_set_pba( NULL, rntm ); +} // // -- rntm_t modification (public API) ----------------------------------------- // -static void bli_rntm_set_num_threads( dim_t nt, rntm_t* rntm ) +BLIS_INLINE void bli_rntm_set_num_threads( dim_t nt, rntm_t* rntm ) { // Record the total number of threads to use. bli_rntm_set_num_threads_only( nt, rntm ); @@ -201,7 +240,7 @@ static void bli_rntm_set_num_threads( dim_t nt, rntm_t* rntm ) bli_rntm_clear_ways_only( rntm ); } -static void bli_rntm_set_ways( dim_t jc, dim_t pc, dim_t ic, dim_t jr, dim_t ir, rntm_t* rntm ) +BLIS_INLINE void bli_rntm_set_ways( dim_t jc, dim_t pc, dim_t ic, dim_t jr, dim_t ir, rntm_t* rntm ) { // Record the number of ways of parallelism per loop. bli_rntm_set_jc_ways_only( jc, rntm ); @@ -215,6 +254,48 @@ static void bli_rntm_set_ways( dim_t jc, dim_t pc, dim_t ic, dim_t jr, dim_t ir, bli_rntm_clear_num_threads_only( rntm ); } +BLIS_INLINE void bli_rntm_set_pack_a( bool pack_a, rntm_t* rntm ) +{ + // Set the bool indicating whether matrix A should be packed. + rntm->pack_a = pack_a; +} +BLIS_INLINE void bli_rntm_set_pack_b( bool pack_b, rntm_t* rntm ) +{ + // Set the bool indicating whether matrix B should be packed. + rntm->pack_b = pack_b; +} + +BLIS_INLINE void bli_rntm_set_l3_sup( bool l3_sup, rntm_t* rntm ) +{ + // Set the bool indicating whether level-3 sup handling is enabled. + rntm->l3_sup = l3_sup; +} +BLIS_INLINE void bli_rntm_enable_l3_sup( rntm_t* rntm ) +{ + bli_rntm_set_l3_sup( TRUE, rntm ); +} +BLIS_INLINE void bli_rntm_disable_l3_sup( rntm_t* rntm ) +{ + bli_rntm_set_l3_sup( FALSE, rntm ); +} + +// +// -- rntm_t modification (internal use only) ---------------------------------- +// + +BLIS_INLINE void bli_rntm_clear_pack_a( rntm_t* rntm ) +{ + bli_rntm_set_pack_a( FALSE, rntm ); +} +BLIS_INLINE void bli_rntm_clear_pack_b( rntm_t* rntm ) +{ + bli_rntm_set_pack_b( FALSE, rntm ); +} +BLIS_INLINE void bli_rntm_clear_l3_sup( rntm_t* rntm ) +{ + bli_rntm_set_l3_sup( TRUE, rntm ); +} + // // -- rntm_t initialization ---------------------------------------------------- // @@ -223,23 +304,57 @@ static void bli_rntm_set_ways( dim_t jc, dim_t pc, dim_t ic, dim_t jr, dim_t ir, // of the public "set" accessors, each of which guarantees that the rntm_t // will be in a good state upon return. -#define BLIS_RNTM_INITIALIZER { .num_threads = -1, \ - .thrloop = { -1, -1, -1, -1, -1, -1 }, \ - .sba_pool = NULL } \ - -static void bli_rntm_init( rntm_t* rntm ) +#define BLIS_RNTM_INITIALIZER \ + { \ + .auto_factor = TRUE, \ + .num_threads = -1, \ + .thrloop = { -1, -1, -1, -1, -1, -1 }, \ + .pack_a = FALSE, \ + .pack_b = FALSE, \ + .l3_sup = TRUE, \ + .sba_pool = NULL, \ + .pba = NULL, \ + } \ + +BLIS_INLINE void bli_rntm_init( rntm_t* rntm ) { + bli_rntm_set_auto_factor_only( TRUE, rntm ); + bli_rntm_clear_num_threads_only( rntm ); bli_rntm_clear_ways_only( rntm ); + bli_rntm_clear_pack_a( rntm ); + bli_rntm_clear_pack_b( rntm ); + bli_rntm_clear_l3_sup( rntm ); bli_rntm_clear_sba_pool( rntm ); + bli_rntm_clear_pba( rntm ); +} + +// -- rntm_t total thread calculation ------------------------------------------ + +BLIS_INLINE dim_t bli_rntm_calc_num_threads + ( + rntm_t* restrict rntm + ) +{ + dim_t n_threads; + + n_threads = bli_rntm_ways_for( BLIS_NC, rntm ); + n_threads *= bli_rntm_ways_for( BLIS_KC, rntm ); + n_threads *= bli_rntm_ways_for( BLIS_MC, rntm ); + n_threads *= bli_rntm_ways_for( BLIS_NR, rntm ); + n_threads *= bli_rntm_ways_for( BLIS_MR, rntm ); + + return n_threads; } // ----------------------------------------------------------------------------- // Function prototypes -void bli_rntm_set_ways_for_op +BLIS_EXPORT_BLIS void bli_rntm_init_from_global( rntm_t* rntm ); + +BLIS_EXPORT_BLIS void bli_rntm_set_ways_for_op ( opid_t l3_op, side_t side, @@ -257,10 +372,24 @@ void bli_rntm_set_ways_from_rntm rntm_t* rntm ); +void bli_rntm_set_ways_from_rntm_sup + ( + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm + ); + void bli_rntm_print ( rntm_t* rntm ); +dim_t bli_rntm_calc_num_threads_in + ( + bszid_t* restrict bszid_cur, + rntm_t* restrict rntm + ); + #endif diff --git a/frame/base/bli_sba.c b/frame/base/bli_sba.c index 1e2d5753ff..5b6ff6a0f0 100644 --- a/frame/base/bli_sba.c +++ b/frame/base/bli_sba.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,8 +34,9 @@ #include "blis.h" -// The small block allocator: an apool_t of array_t of pool_t. -static apool_t sba; +// Statically initialize the mutex within the small block allocator. +// Note that the sba is an apool_t of array_t of pool_t. +static apool_t sba = { .mutex = BLIS_PTHREAD_MUTEX_INITIALIZER }; apool_t* bli_sba_query( void ) { @@ -61,11 +62,12 @@ void* bli_sba_acquire ) { void* block; + err_t r_val; #ifdef BLIS_ENABLE_SBA_POOLS if ( rntm == NULL ) { - block = bli_malloc_intl( req_size ); + block = bli_malloc_intl( req_size, &r_val ); } else { @@ -74,28 +76,43 @@ void* bli_sba_acquire // Query the small block pool from the rntm. pool_t* restrict pool = bli_rntm_sba_pool( rntm ); - // Query the block_size of the pool_t so that we can request the exact - // size present. - const siz_t block_size = bli_pool_block_size( pool ); - - // Sanity check: Make sure the requested size is no larger than the - // block_size field of the pool. - if ( block_size < req_size ) + // We don't expect NULL sba_pool pointers in the normal course of BLIS + // operation. However, there are rare instances where it is convenient + // to support use of bli_sba_acquire() without having to pass in a valid + // sba pool data structure. The case that inspired this branch was the + // gemm_ukr and related test modules in the BLIS testsuite. (There, it + // is convenient to not have to checkout an array_t from the sba, and it + // does no harm since the malloc() happens outside of the region that + // would be timed.) + if ( pool == NULL ) { - printf( "bli_sba_acquire(): ** pool block_size is %d but req_size is %d.\n", - ( int )block_size, ( int )req_size ); - bli_abort(); + block = bli_malloc_intl( req_size, &r_val ); + } + else + { + // Query the block_size of the pool_t so that we can request the exact + // size present. + const siz_t block_size = bli_pool_block_size( pool ); + + // Sanity check: Make sure the requested size is no larger than the + // block_size field of the pool. + if ( block_size < req_size ) + { + printf( "bli_sba_acquire(): ** pool block_size is %d but req_size is %d.\n", + ( int )block_size, ( int )req_size ); + bli_abort(); + } + + // Check out a block using the block_size queried above. + bli_pool_checkout_block( block_size, &pblk, pool ); + + // The block address is stored within the pblk_t. + block = bli_pblk_buf( &pblk ); } - - // Check out a block using the block_size queried above. - bli_pool_checkout_block( block_size, &pblk, pool ); - - // The block address is stored within the pblk_t. - block = bli_pblk_buf( &pblk ); } #else - block = bli_malloc_intl( req_size ); + block = bli_malloc_intl( req_size, &r_val ); #endif @@ -121,21 +138,28 @@ void bli_sba_release // Query the small block pool from the rntm. pool_t* restrict pool = bli_rntm_sba_pool( rntm ); - // Query the block_size field from the pool. This is not super-important - // for this particular application of the pool_t (that is, the "leaf" - // component of the sba), but it seems like good housekeeping to maintain - // the block_size field of the pblk_t in case its ever needed/read. - const siz_t block_size = bli_pool_block_size( pool ); - - // Embed the block's memory address into a pblk_t, along with the - // block_size queried from the pool. - bli_pblk_set_buf( block, &pblk ); - bli_pblk_set_block_size( block_size, &pblk ); - - // Check the pblk_t back into the pool_t. (It's okay that the pblk_t is - // a local variable since its contents are copied into the pool's internal - // data structure--an array of pblk_t.) - bli_pool_checkin_block( &pblk, pool ); + if ( pool == NULL ) + { + bli_free_intl( block ); + } + else + { + // Query the block_size field from the pool. This is not super-important + // for this particular application of the pool_t (that is, the "leaf" + // component of the sba), but it seems like good housekeeping to maintain + // the block_size field of the pblk_t in case its ever needed/read. + const siz_t block_size = bli_pool_block_size( pool ); + + // Embed the block's memory address into a pblk_t, along with the + // block_size queried from the pool. + bli_pblk_set_buf( block, &pblk ); + bli_pblk_set_block_size( block_size, &pblk ); + + // Check the pblk_t back into the pool_t. (It's okay that the pblk_t is + // a local variable since its contents are copied into the pool's internal + // data structure--an array of pblk_t.) + bli_pool_checkin_block( &pblk, pool ); + } } #else diff --git a/frame/base/bli_sba.h b/frame/base/bli_sba.h index cf10834e31..f5e36d759a 100644 --- a/frame/base/bli_sba.h +++ b/frame/base/bli_sba.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/base/bli_setgetij.c b/frame/base/bli_setgetijm.c similarity index 87% rename from frame/base/bli_setgetij.c rename to frame/base/bli_setgetijm.c index 744e24c27e..78ff58a29c 100644 --- a/frame/base/bli_setgetij.c +++ b/frame/base/bli_setgetijm.c @@ -59,9 +59,9 @@ err_t bli_setijm dim_t cs = bli_obj_col_stride( b ); num_t dt = bli_obj_dt( b ); - // Return error if i or j is beyond bounds of matrix/vector. - if ( m <= i ) return BLIS_FAILURE; - if ( n <= j ) return BLIS_FAILURE; + // Return error if i or j is beyond bounds of the matrix/vector. + if ( i < 0 || m <= i ) return BLIS_FAILURE; + if ( j < 0 || n <= j ) return BLIS_FAILURE; // Don't modify scalar constants. if ( dt == BLIS_CONSTANT ) return BLIS_FAILURE; @@ -133,35 +133,15 @@ err_t bli_getijm dim_t cs = bli_obj_col_stride( b ); num_t dt = bli_obj_dt( b ); - // Return error if i or j is beyond bounds of matrix/vector. - if ( m <= i ) return BLIS_FAILURE; - if ( n <= j ) return BLIS_FAILURE; - - void* b_p; - -#if 0 - // Handle scalar constants separately. - if ( dt == BLIS_CONSTANT ) - { - if ( i == 0 && j == 0 ) - { - dt = BLIS_DCOMPLEX; - b_p = bli_obj_buffer_for_const( dt, b ) - } - else return BLIS_FAILURE; - } - else - { - // Query the pointer to the buffer at the adjusted offsets. - b_p = bli_obj_buffer_at_off( b ); - } -#else + // Return error if i or j is beyond bounds of the matrix/vector. + if ( i < 0 || m <= i ) return BLIS_FAILURE; + if ( j < 0 || n <= j ) return BLIS_FAILURE; + // Disallow access into scalar constants. if ( dt == BLIS_CONSTANT ) return BLIS_FAILURE; // Query the pointer to the buffer at the adjusted offsets. - b_p = bli_obj_buffer_at_off( b ); -#endif + void* b_p = bli_obj_buffer_at_off( b ); // Index into the function pointer array. getijm_fp f = ftypes_getijm[ dt ]; diff --git a/frame/base/bli_setgetij.h b/frame/base/bli_setgetijm.h similarity index 100% rename from frame/base/bli_setgetij.h rename to frame/base/bli_setgetijm.h diff --git a/frame/base/bli_setgetijv.c b/frame/base/bli_setgetijv.c new file mode 100644 index 0000000000..610f6f271c --- /dev/null +++ b/frame/base/bli_setgetijv.c @@ -0,0 +1,168 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +typedef void (*setijv_fp) + ( + double ar, + double ai, + dim_t i, + void* restrict x, inc_t incx + ); +static setijv_fp GENARRAY(ftypes_setijv,setijv); + +err_t bli_setijv + ( + double ar, + double ai, + dim_t i, + obj_t* x + ) +{ + dim_t n = bli_obj_vector_dim( x ); + dim_t incx = bli_obj_vector_inc( x ); + num_t dt = bli_obj_dt( x ); + + // Return error if i is beyond bounds of the vector. + if ( i < 0 || n <= i ) return BLIS_FAILURE; + + // Don't modify scalar constants. + if ( dt == BLIS_CONSTANT ) return BLIS_FAILURE; + + // Query the pointer to the buffer at the adjusted offsets. + void* x_p = bli_obj_buffer_at_off( x ); + + // Index into the function pointer array. + setijv_fp f = ftypes_setijv[ dt ]; + + // Invoke the type-specific function. + f + ( + ar, + ai, + i, + x_p, incx + ); + + return BLIS_SUCCESS; +} + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + double ar, \ + double ai, \ + dim_t i, \ + void* restrict x, inc_t incx \ + ) \ +{ \ + ctype* restrict x_cast = ( ctype* )x; \ +\ + ctype* restrict x_i = x_cast + (i )*incx; \ +\ + PASTEMAC2(z,ch,sets)( ar, ai, *x_i ); \ +} + +INSERT_GENTFUNC_BASIC0( setijv ) + +// ----------------------------------------------------------------------------- + +typedef void (*getijv_fp) + ( + dim_t i, + void* restrict x, inc_t incx, + double* ar, + double* ai + ); +static getijv_fp GENARRAY(ftypes_getijv,getijv); + +err_t bli_getijv + ( + dim_t i, + obj_t* x, + double* ar, + double* ai + ) +{ + dim_t n = bli_obj_vector_dim( x ); + dim_t incx = bli_obj_vector_inc( x ); + num_t dt = bli_obj_dt( x ); + + // Return error if i is beyond bounds of the vector. + if ( i < 0 || n <= i ) return BLIS_FAILURE; + + // Disallow access into scalar constants. + if ( dt == BLIS_CONSTANT ) return BLIS_FAILURE; + + // Query the pointer to the buffer at the adjusted offsets. + void* x_p = bli_obj_buffer_at_off( x ); + + // Index into the function pointer array. + getijv_fp f = ftypes_getijv[ dt ]; + + // Invoke the type-specific function. + f + ( + i, + x_p, incx, + ar, + ai + ); + + return BLIS_SUCCESS; +} + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + dim_t i, \ + void* restrict x, inc_t incx, \ + double* ar, \ + double* ai \ + ) \ +{ \ + ctype* restrict x_cast = ( ctype* )x; \ +\ + ctype* restrict x_i = x_cast + (i )*incx; \ +\ + PASTEMAC2(ch,z,gets)( *x_i, *ar, *ai ); \ +} + +INSERT_GENTFUNC_BASIC0( getijv ) + diff --git a/frame/base/bli_setgetijv.h b/frame/base/bli_setgetijv.h new file mode 100644 index 0000000000..703fe41aae --- /dev/null +++ b/frame/base/bli_setgetijv.h @@ -0,0 +1,78 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +BLIS_EXPORT_BLIS err_t bli_setijv + ( + double ar, + double ai, + dim_t i, + obj_t* x + ); + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ + ( \ + double ar, \ + double ai, \ + dim_t i, \ + void* restrict x, inc_t incx \ + ); + +INSERT_GENTPROT_BASIC0( setijv ) + +// ----------------------------------------------------------------------------- + +BLIS_EXPORT_BLIS err_t bli_getijv + ( + dim_t i, + obj_t* x, + double* ar, + double* ai + ); + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ + ( \ + dim_t i, \ + void* restrict b, inc_t incx, \ + double* ar, \ + double* ai \ + ); + +INSERT_GENTPROT_BASIC0( getijv ) + diff --git a/frame/base/bli_winsys.h b/frame/base/bli_winsys.h index 0ad7c408c7..0c71114ad0 100644 --- a/frame/base/bli_winsys.h +++ b/frame/base/bli_winsys.h @@ -33,5 +33,5 @@ */ //int bli_setenv( const char *name, const char *value, int overwrite ); -void bli_sleep( unsigned int secs ); +BLIS_EXPORT_BLIS void bli_sleep( unsigned int secs ); diff --git a/frame/base/noopt/bli_dlamch.c b/frame/base/noopt/bli_dlamch.c index 53a6609653..b8be23b382 100644 --- a/frame/base/noopt/bli_dlamch.c +++ b/frame/base/noopt/bli_dlamch.c @@ -1,12 +1,14 @@ -/* dlamch.f -- translated by f2c (version 19991025). - You must link the resulting object file with the libraries: - -lf2c -lm (in that order) -*/ +#include "blis.h" + +#include +#include +#include #ifdef __cplusplus extern "C" { #endif -#include "blis.h" + +#ifdef BLIS_ENABLE_LEGACY_LAMCH double bli_pow_di( bla_double* a, bla_integer* n ); @@ -1027,6 +1029,59 @@ bla_double bli_dlamc3(bla_double *a, bla_double *b) } /* bli_dlamc5_ */ -#ifdef __cplusplus +#else + +bla_double bli_dlamch(bla_character *cmach, ftnlen cmach_len) +{ +/* = 'E' or 'e', DLAMCH := eps */ +/* = 'S' or 's , DLAMCH := sfmin */ +/* = 'B' or 'b', DLAMCH := base */ +/* = 'P' or 'p', DLAMCH := eps*base */ +/* = 'N' or 'n', DLAMCH := t */ +/* = 'R' or 'r', DLAMCH := rnd */ +/* = 'M' or 'm', DLAMCH := emin */ +/* = 'U' or 'u', DLAMCH := rmin */ +/* = 'L' or 'l', DLAMCH := emax */ +/* = 'O' or 'o', DLAMCH := rmax */ + +/* where */ + +/* eps = relative machine precision */ +/* sfmin = safe minimum, such that 1/sfmin does not overflow */ +/* base = base of the machine */ +/* prec = eps*base */ +/* t = number of (base) digits in the mantissa */ +/* rnd = 1.0 when rounding occurs in addition, 0.0 otherwise */ +/* emin = minimum exponent before (gradual) underflow */ +/* rmin = underflow threshold - base**(emin-1) */ +/* emax = largest exponent before overflow */ +/* rmax = overflow threshold - (base**emax)*(1-eps) */ + + double safe_min = DBL_MIN; + double small = 1.0f / DBL_MAX; + + if ( small >= safe_min ) + safe_min = small * ( 1.0 + DBL_EPSILON ); + + switch ( toupper( *cmach ) ) + { + case 'E': return DBL_EPSILON; + case 'S': return safe_min; + case 'B': return FLT_RADIX; + case 'P': return FLT_RADIX*DBL_EPSILON; + case 'N': return DBL_MANT_DIG; + case 'R': return FLT_ROUNDS == FE_TONEAREST ? 1.0 : 0.0; + case 'M': return DBL_MIN_EXP; + case 'U': return DBL_MIN; + case 'L': return DBL_MAX_EXP; + case 'O': return DBL_MAX; } + + return 0.0; +} + +#endif + +#ifdef __cplusplus +} #endif diff --git a/frame/base/noopt/bli_slamch.c b/frame/base/noopt/bli_slamch.c index 3f0b72cd8c..ec7cf85975 100644 --- a/frame/base/noopt/bli_slamch.c +++ b/frame/base/noopt/bli_slamch.c @@ -1,12 +1,14 @@ -/* slamch.f -- translated by f2c (version 19991025). - You must link the resulting object file with the libraries: - -lf2c -lm (in that order) -*/ +#include "blis.h" + +#include +#include +#include #ifdef __cplusplus extern "C" { #endif -#include "blis.h" + +#ifdef BLIS_ENABLE_LEGACY_LAMCH double bli_pow_ri( bla_real* a, bla_integer* n ); @@ -1022,6 +1024,59 @@ bla_real bli_slamc3(bla_real *a, bla_real *b) } /* bli_slamc5_ */ -#ifdef __cplusplus +#else + +bla_real bli_slamch(bla_character *cmach, ftnlen cmach_len) +{ +/* = 'E' or 'e', SLAMCH := eps */ +/* = 'S' or 's , SLAMCH := sfmin */ +/* = 'B' or 'b', SLAMCH := base */ +/* = 'P' or 'p', SLAMCH := eps*base */ +/* = 'N' or 'n', SLAMCH := t */ +/* = 'R' or 'r', SLAMCH := rnd */ +/* = 'M' or 'm', SLAMCH := emin */ +/* = 'U' or 'u', SLAMCH := rmin */ +/* = 'L' or 'l', SLAMCH := emax */ +/* = 'O' or 'o', SLAMCH := rmax */ + +/* where */ + +/* eps = relative machine precision */ +/* sfmin = safe minimum, such that 1/sfmin does not overflow */ +/* base = base of the machine */ +/* prec = eps*base */ +/* t = number of (base) digits in the mantissa */ +/* rnd = 1.0 when rounding occurs in addition, 0.0 otherwise */ +/* emin = minimum exponent before (gradual) underflow */ +/* rmin = underflow threshold - base**(emin-1) */ +/* emax = largest exponent before overflow */ +/* rmax = overflow threshold - (base**emax)*(1-eps) */ + + float safe_min = FLT_MIN; + float small = 1.0f / FLT_MAX; + + if ( small >= safe_min ) + safe_min = small * ( 1.0f + FLT_EPSILON ); + + switch ( toupper( *cmach ) ) + { + case 'E': return FLT_EPSILON; + case 'S': return safe_min; + case 'B': return FLT_RADIX; + case 'P': return FLT_RADIX*FLT_EPSILON; + case 'N': return FLT_MANT_DIG; + case 'R': return FLT_ROUNDS == FE_TONEAREST ? 1.0f : 0.0f; + case 'M': return FLT_MIN_EXP; + case 'U': return FLT_MIN; + case 'L': return FLT_MAX_EXP; + case 'O': return FLT_MAX; } + + return 0.0f; +} + +#endif + +#ifdef __cplusplus +} #endif diff --git a/frame/compat/amd/bla_copy_amd.c b/frame/compat/amd/bla_copy_amd.c new file mode 100644 index 0000000000..6780b555e6 --- /dev/null +++ b/frame/compat/amd/bla_copy_amd.c @@ -0,0 +1,147 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname, isuf ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + /* Initialize BLIS. */ \ + /*bli_init_auto()*/; \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + /* NOTE: While we skip explicit initialization for real domain instances + since we call the microkernel directly, the complex domain instances + still need initialization so that they can query valid contexts from + gks. However, the expert API will self-initialize before attempting + to query a context, so the complex domain cases should work fine. */ \ + PASTEMAC2(ch,blisname,isuf) \ + ( \ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + /*bli_finalize_auto();*/ \ +} + +#ifdef BLIS_ENABLE_BLAS +//INSERT_GENTFUNC_BLAS( copy, copyv ) +GENTFUNC( float, s, copy, copyv, _zen_int ) +GENTFUNC( double, d, copy, copyv, _zen_int ) +#endif + + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname, isuf ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + /* Initialize BLIS. */ \ + /*bli_init_auto()*/; \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + /* NOTE: While we skip explicit initialization for real domain instances + since we call the microkernel directly, the complex domain instances + still need initialization so that they can query valid contexts from + gks. However, the expert API will self-initialize before attempting + to query a context, so the complex domain cases should work fine. */ \ + PASTEMAC2(ch,blisname,isuf) \ + ( \ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + /*bli_finalize_auto();*/ \ +} + +#ifdef BLIS_ENABLE_BLAS +//INSERT_GENTFUNC_BLAS( copy, copyv ) +GENTFUNC( scomplex, c, copy, copyv, _ex ) +GENTFUNC( dcomplex, z, copy, copyv, _ex ) +#endif + diff --git a/frame/compat/amd/bla_gemv_amd.c b/frame/compat/amd/bla_gemv_amd.c new file mode 100644 index 0000000000..398d1bf2c2 --- /dev/null +++ b/frame/compat/amd/bla_gemv_amd.c @@ -0,0 +1,172 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + trans_t blis_transa; \ + dim_t m0, n0; \ + dim_t m_y, n_x; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + /* Initialize BLIS. */ \ + /*bli_init_auto();*/ \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + m, \ + n, \ + lda, \ + incx, \ + incy \ + ); \ +\ + /* BLAS handles cases where y has no elements as well as those where x has + no elements. In the case of the former, it cannot do any work since + the output vector is empty; but in the latter case, BLAS has peculiar + semantics. When x has no elements (and transa(A) has no columns), BLAS + returns immediately without performing any computation even if the + number of elements of y (and rows of transa(A)) is non-zero, in which + case any sane interpretations of gemv would have the the operation + reduce to y := beta * y. Here, we emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be contemplated if it weren't for the fact + that some BLAS unit tests actually check for this behavior. Also, it + should be emphasized that BLIS, when called natively, does NOT exhibit + this quirky behavior; it will scale y by beta as one would expect. */ \ + if ( *m == 0 || *n == 0 ) \ + { \ + /* Finalize BLIS. */ \ + /*bli_finalize_auto();*/ \ +\ + return; \ + } \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ +\ + /* Convert/typecast negative values of m and n to zero. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ \ + bli_set_dims_with_trans( blis_transa, m0, n0, &m_y, &n_x ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n_x, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( m_y, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* If alpha is zero, scale y by beta and return early. */ \ + if ( PASTEMAC(ch,eq0)( *alpha ) ) \ + { \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m_y, \ + ( ftype* )beta, \ + ( ftype* )y0, incy0, \ + NULL, \ + NULL \ + ); \ + return; \ + } \ +\ + /* Set the row and column strides of A. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ +\ + /* Declare a void function pointer for the current operation. */ \ + PASTECH2(ch,blisname,_unb_ft) f; \ +\ + /* Choose the underlying implementation. */ \ + if ( bli_does_notrans( blis_transa ) ) f = PASTEMAC(ch,gemv_unf_var2); \ + else /* if ( bli_does_trans( blis_transa ) ) */ f = PASTEMAC(ch,gemv_unf_var1); \ +\ + /* Obtain a valid context from the gks. This is needed because these + implementations of ?gemv_() skip calling gemv_ex() and instead + call the unblocked fused variants directly. */ \ + cntx_t* cntx = bli_gks_query_cntx(); \ +\ + /* Invoke the variant chosen above, which loops over a level-1v or + level-1f kernel to implement the current operation. */ \ + f \ + ( \ + blis_transa, \ + BLIS_NO_CONJUGATE, \ + m0, \ + n0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + x0, incx0, \ + (ftype*)beta, \ + y0, incy0, \ + cntx \ + ); \ +\ + /* Finalize BLIS. */ \ + /*bli_finalize_auto();*/ \ +} + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTFUNC_BLAS( gemv, gemv ) +#endif + diff --git a/frame/compat/bla_dot.c b/frame/compat/bla_dot.c index dbab039d14..f5396b1902 100644 --- a/frame/compat/bla_dot.c +++ b/frame/compat/bla_dot.c @@ -34,6 +34,7 @@ #include "blis.h" +#ifdef BLIS_ENABLE_BLAS // // Define BLAS-to-BLIS interfaces. @@ -85,8 +86,67 @@ ftype PASTEF772(ch,blasname,chc) \ return rho; \ } -#ifdef BLIS_ENABLE_BLAS -INSERT_GENTFUNCDOT_BLAS( dot, dotv ) +INSERT_GENTFUNCDOTR_BLAS( dot, dotv ) + +#ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + +INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) + +#else // #ifdef BLIS_ENABLE_COMPLEX_RETURN_INTEL + +// For the "intel" complex return type, use a hidden preceding parameter to +// return the result rather than an actual return value. +#undef GENTFUNCDOT +#define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ +\ +void PASTEF772(ch,blasname,chc) \ + ( \ + ftype* rhop, \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + ) \ +{ \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ + ftype rho; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_conjx, \ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + &rho, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + *rhop = rho; \ +} + +INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) + +#endif // -- "Black sheep" dot product function definitions -- @@ -101,10 +161,16 @@ float PASTEF77(sd,sdot) const float* y, const f77_int* incy ) { - float r = ( float )PASTEF77(d,sdot)( n, - x, incx, - y, incy ); - return r + *sb; + return ( float ) + ( + ( double )(*sb) + + PASTEF77(d,sdot) + ( + n, + x, incx, + y, incy + ) + ); } // Input vectors stored in single precision, computed in double precision, diff --git a/frame/compat/bla_dot.h b/frame/compat/bla_dot.h index 373e1a7b72..87d7733214 100644 --- a/frame/compat/bla_dot.h +++ b/frame/compat/bla_dot.h @@ -32,6 +32,7 @@ */ +#ifdef BLIS_ENABLE_BLAS // // Prototype BLAS-to-BLIS interfaces. @@ -46,8 +47,30 @@ BLIS_EXPORT_BLAS ftype PASTEF772(ch,blasname,chc) \ const ftype* y, const f77_int* incy \ ); -#ifdef BLIS_ENABLE_BLAS -INSERT_GENTPROTDOT_BLAS( dot ) +INSERT_GENTPROTDOTR_BLAS( dot ) + +#ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + +INSERT_GENTPROTDOTC_BLAS( dot ) + +#else + +// For the "intel" complex return type, we use a hidden parameter (passed by +// address) to return the result. +#undef GENTPROTDOT +#define GENTPROTDOT( ftype, ch, chc, blasname ) \ +\ +BLIS_EXPORT_BLAS void PASTEF772(ch,blasname,chc) \ + ( \ + ftype* rhop, \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + ); + +INSERT_GENTPROTDOTC_BLAS( dot ) + +#endif // -- "Black sheep" dot product function prototypes -- @@ -66,4 +89,5 @@ BLIS_EXPORT_BLAS double PASTEF77(d,sdot) const float* x, const f77_int* incx, const float* y, const f77_int* incy ); + #endif diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 1effececa1..e71d4e2fcc 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ @@ -58,9 +62,6 @@ void PASTEF77(ch,blasname) \ trans_t blis_transa; \ trans_t blis_transb; \ dim_t m0, n0, k0; \ - inc_t rs_a, cs_a; \ - inc_t rs_b, cs_b; \ - inc_t rs_c, cs_c; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -84,18 +85,18 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ \ - /* Convert/typecast negative values of m, n, and k to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *n, n0 ); \ bli_convert_blas_dim1( *k, k0 ); \ \ /* Set the row and column strides of the matrix operands. */ \ - rs_a = 1; \ - cs_a = *lda; \ - rs_b = 1; \ - cs_b = *ldb; \ - rs_c = 1; \ - cs_c = *ldc; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ @@ -118,6 +119,147 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + /* Handle special cases of m == 1 or n == 1 via gemv. */ \ + if ( n0 == 1 ) \ + { \ + dim_t m0t, k0t; \ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0t, &k0t ); \ +\ + PASTEMAC2(ch,gemv,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transa, \ + bli_extract_conj( blis_transb ), \ + m0t, k0t, \ + ( ftype* )alpha, \ + ( ftype* )a, rs_a, cs_a, \ + ( ftype* )b, ( bli_does_notrans( blis_transb ) ? rs_b : cs_b ), \ + ( ftype* )beta, \ + c, rs_c, \ + NULL, \ + NULL \ + ); \ + return; \ + } \ + else if ( m0 == 1 ) \ + { \ + dim_t n0t, k0t; \ + bli_set_dims_with_trans( blis_transb, n0, k0, &n0t, &k0t ); \ +\ + PASTEMAC2(ch,gemv,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transb, \ + bli_extract_conj( blis_transa ), \ + n0t, k0t, \ + ( ftype* )alpha, \ + ( ftype* )b, cs_b, rs_b, \ + ( ftype* )a, ( bli_does_notrans( blis_transa ) ? cs_a : rs_a ), \ + ( ftype* )beta, \ + c, cs_c, \ + NULL, \ + NULL \ + ); \ + return; \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( blis_transb, &bo ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( gemm, gemm ) #endif diff --git a/frame/compat/bla_gemv.c b/frame/compat/bla_gemv.c index 85c65dde41..8d730edd9c 100644 --- a/frame/compat/bla_gemv.c +++ b/frame/compat/bla_gemv.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -60,7 +61,6 @@ void PASTEF77(ch,blasname) \ ftype* y0; \ inc_t incx0; \ inc_t incy0; \ - inc_t rs_a, cs_a; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -89,16 +89,19 @@ void PASTEF77(ch,blasname) \ if necessary.*/ \ bli_set_dims_with_trans( blis_transa, m0, n0, &m_y, &n_x ); \ \ - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to + /* BLAS handles cases where y has no elements as well as those where x has + no elements. In the case of the former, it cannot do any work since + the output vector is empty; but in the latter case, BLAS has peculiar + semantics. When x has no elements (and transa(A) has no columns), BLAS + returns immediately without performing any computation even if the + number of elements of y (and rows of transa(A)) is non-zero, in which + case any sane interpretations of gemv would have the the operation + reduce to y := beta * y. Here, we emulate the BLAS exactly so as to provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ \ + compatibility would not be contemplated if it weren't for the fact + that some BLAS unit tests actually check for this behavior. Also, it + should be emphasized that BLIS, when called natively, does NOT exhibit + this quirky behavior; it will scale y by beta as one would expect. */ \ if ( m_y > 0 && n_x == 0 ) \ { \ /* Finalize BLIS. */ \ @@ -113,8 +116,8 @@ void PASTEF77(ch,blasname) \ bli_convert_blas_incv( m_y, (ftype*)y, *incy, y0, incy0 ); \ \ /* Set the row and column strides of A. */ \ - rs_a = 1; \ - cs_a = *lda; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ diff --git a/frame/compat/bla_ger.c b/frame/compat/bla_ger.c index db4f76f184..b558bfd942 100644 --- a/frame/compat/bla_ger.c +++ b/frame/compat/bla_ger.c @@ -56,7 +56,6 @@ void PASTEF772(ch,blasname,chc) \ ftype* y0; \ inc_t incx0; \ inc_t incy0; \ - inc_t rs_a, cs_a; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -84,8 +83,8 @@ void PASTEF772(ch,blasname,chc) \ bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ \ /* Set the row and column strides of A. */ \ - rs_a = 1; \ - cs_a = *lda; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ diff --git a/frame/compat/bla_hemm.c b/frame/compat/bla_hemm.c index 88e9c8b557..9a4484a091 100644 --- a/frame/compat/bla_hemm.c +++ b/frame/compat/bla_hemm.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ @@ -57,9 +61,6 @@ void PASTEF77(ch,blasname) \ side_t blis_side; \ uplo_t blis_uploa; \ dim_t m0, n0; \ - inc_t rs_a, cs_a; \ - inc_t rs_b, cs_b; \ - inc_t rs_c, cs_c; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -82,17 +83,17 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ \ - /* Convert/typecast negative values of m and n to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *n, n0 ); \ \ /* Set the row and column strides of the matrix operands. */ \ - rs_a = 1; \ - cs_a = *lda; \ - rs_b = 1; \ - cs_b = *ldb; \ - rs_c = 1; \ - cs_c = *ldc; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ @@ -116,6 +117,110 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNCCO +#define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + side_t blis_side; \ + uplo_t blis_uploa; \ + dim_t m0, n0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + m, \ + n, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + const conj_t conja = BLIS_NO_CONJUGATE; \ + const trans_t transb = BLIS_NO_TRANSPOSE; \ + const struc_t struca = BLIS_HERMITIAN; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); \ + bli_set_dims_with_trans( transb, m0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, mn0_a, mn0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( blis_uploa, &ao ); \ + bli_obj_set_conj( conja, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( struca, &ao ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + blis_side, \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNCCO_BLAS( hemm, hemm ) #endif diff --git a/frame/compat/bla_hemv.c b/frame/compat/bla_hemv.c index 9444682784..d036c10e3a 100644 --- a/frame/compat/bla_hemv.c +++ b/frame/compat/bla_hemv.c @@ -58,7 +58,6 @@ void PASTEF77(ch,blasname) \ ftype* y0; \ inc_t incx0; \ inc_t incy0; \ - inc_t rs_a, cs_a; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -87,8 +86,8 @@ void PASTEF77(ch,blasname) \ bli_convert_blas_incv( m0, (ftype*)y, *incy, y0, incy0 ); \ \ /* Set the row and column strides of A. */ \ - rs_a = 1; \ - cs_a = *lda; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ diff --git a/frame/compat/bla_her.c b/frame/compat/bla_her.c index ade3cbdda0..512081d89a 100644 --- a/frame/compat/bla_her.c +++ b/frame/compat/bla_her.c @@ -54,7 +54,6 @@ void PASTEF77(ch,blasname) \ dim_t m0; \ ftype* x0; \ inc_t incx0; \ - inc_t rs_a, cs_a; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -81,8 +80,8 @@ void PASTEF77(ch,blasname) \ bli_convert_blas_incv( m0, (ftype*)x, *incx, x0, incx0 ); \ \ /* Set the row and column strides of A. */ \ - rs_a = 1; \ - cs_a = *lda; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ diff --git a/frame/compat/bla_her2.c b/frame/compat/bla_her2.c index e3ed4ce31b..7d99a6378c 100644 --- a/frame/compat/bla_her2.c +++ b/frame/compat/bla_her2.c @@ -57,7 +57,6 @@ void PASTEF77(ch,blasname) \ ftype* y0; \ inc_t incx0; \ inc_t incy0; \ - inc_t rs_a, cs_a; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -86,8 +85,8 @@ void PASTEF77(ch,blasname) \ bli_convert_blas_incv( m0, (ftype*)y, *incy, y0, incy0 ); \ \ /* Set the row and column strides of A. */ \ - rs_a = 1; \ - cs_a = *lda; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ diff --git a/frame/compat/bla_her2k.c b/frame/compat/bla_her2k.c index 0bbe98e1c2..2a058dc021 100644 --- a/frame/compat/bla_her2k.c +++ b/frame/compat/bla_her2k.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ @@ -57,9 +61,6 @@ void PASTEF77(ch,blasname) \ uplo_t blis_uploc; \ trans_t blis_transa; \ dim_t m0, k0; \ - inc_t rs_a, cs_a; \ - inc_t rs_b, cs_b; \ - inc_t rs_c, cs_c; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -82,7 +83,7 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ \ - /* Convert/typecast negative values of m and k to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *k, k0 ); \ \ @@ -104,12 +105,12 @@ void PASTEF77(ch,blasname) \ } \ \ /* Set the row and column strides of the matrix operands. */ \ - rs_a = 1; \ - cs_a = *lda; \ - rs_b = 1; \ - cs_b = *ldb; \ - rs_c = 1; \ - cs_c = *ldc; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ @@ -132,6 +133,126 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNCCO +#define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype_r* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + uplo_t blis_uploc; \ + trans_t blis_transa; \ + dim_t m0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + uploc, \ + transa, \ + m, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* We emulate the BLAS early return behavior with the following + conditional, which returns if one of the following is true: + - matrix C is empty + - the rank-2k product is empty (either because alpha is zero or k + is zero) AND matrix C is not scaled. */ \ + if ( m0 == 0 || \ + ( ( PASTEMAC(ch,eq0)( *alpha ) || k0 == 0 ) \ + && PASTEMAC(chr,eq1)( *beta ) \ + ) \ + ) \ + { \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + return; \ + } \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt_r = PASTEMAC(chr,type); \ + const num_t dt = PASTEMAC(ch,type); \ +\ + const trans_t transb = blis_transa; \ + const struc_t strucc = BLIS_HERMITIAN; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( transb, m0, k0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype* )alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt_r, (ftype_r*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, m0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( blis_uploc, &co ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( strucc, &co ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNCCO_BLAS( her2k, her2k ) #endif diff --git a/frame/compat/bla_herk.c b/frame/compat/bla_herk.c index 88185de0ba..8236e20329 100644 --- a/frame/compat/bla_herk.c +++ b/frame/compat/bla_herk.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ @@ -56,8 +60,6 @@ void PASTEF77(ch,blasname) \ uplo_t blis_uploc; \ trans_t blis_transa; \ dim_t m0, k0; \ - inc_t rs_a, cs_a; \ - inc_t rs_c, cs_c; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -79,7 +81,7 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ \ - /* Convert/typecast negative values of m and k to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *k, k0 ); \ \ @@ -101,10 +103,10 @@ void PASTEF77(ch,blasname) \ } \ \ /* Set the row and column strides of the matrix operands. */ \ - rs_a = 1; \ - cs_a = *lda; \ - rs_c = 1; \ - cs_c = *ldc; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ @@ -125,6 +127,115 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNCCO +#define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype_r* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype_r* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + uplo_t blis_uploc; \ + trans_t blis_transa; \ + dim_t m0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + uploc, \ + transa, \ + m, \ + k, \ + lda, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* We emulate the BLAS early return behavior with the following + conditional, which returns if one of the following is true: + - matrix C is empty + - the rank-k product is empty (either because alpha is zero or k + is zero) AND matrix C is not scaled. */ \ + if ( m0 == 0 || \ + ( ( PASTEMAC(chr,eq0)( *alpha ) || k0 == 0 ) \ + && PASTEMAC(chr,eq1)( *beta ) \ + ) \ + ) \ + { \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + return; \ + } \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt_r = PASTEMAC(chr,type); \ + const num_t dt = PASTEMAC(ch,type); \ +\ + const struc_t strucc = BLIS_HERMITIAN; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ +\ + bli_obj_init_finish_1x1( dt_r, (ftype_r*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt_r, (ftype_r*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0, m0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( blis_uploc, &co ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ +\ + bli_obj_set_struc( strucc, &co ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNCCO_BLAS( herk, herk ) #endif diff --git a/frame/compat/bla_symm.c b/frame/compat/bla_symm.c index 02d3a3b278..098beb4727 100644 --- a/frame/compat/bla_symm.c +++ b/frame/compat/bla_symm.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ @@ -57,9 +61,6 @@ void PASTEF77(ch,blasname) \ side_t blis_side; \ uplo_t blis_uploa; \ dim_t m0, n0; \ - inc_t rs_a, cs_a; \ - inc_t rs_b, cs_b; \ - inc_t rs_c, cs_c; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -82,17 +83,17 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ \ - /* Convert/typecast negative values of m and n to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *n, n0 ); \ \ /* Set the row and column strides of the matrix operands. */ \ - rs_a = 1; \ - cs_a = *lda; \ - rs_b = 1; \ - cs_b = *ldb; \ - rs_c = 1; \ - cs_c = *ldc; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ @@ -116,6 +117,110 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + side_t blis_side; \ + uplo_t blis_uploa; \ + dim_t m0, n0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + m, \ + n, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + const conj_t conja = BLIS_NO_CONJUGATE; \ + const trans_t transb = BLIS_NO_TRANSPOSE; \ + const struc_t struca = BLIS_SYMMETRIC; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); \ + bli_set_dims_with_trans( transb, m0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, mn0_a, mn0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( blis_uploa, &ao ); \ + bli_obj_set_conj( conja, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( struca, &ao ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + blis_side, \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( symm, symm ) #endif diff --git a/frame/compat/bla_symv.c b/frame/compat/bla_symv.c index 79076194c7..c5b5ebda37 100644 --- a/frame/compat/bla_symv.c +++ b/frame/compat/bla_symv.c @@ -58,7 +58,6 @@ void PASTEF77(ch,blasname) \ ftype* y0; \ inc_t incx0; \ inc_t incy0; \ - inc_t rs_a, cs_a; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -87,8 +86,8 @@ void PASTEF77(ch,blasname) \ bli_convert_blas_incv( m0, (ftype*)y, *incy, y0, incy0 ); \ \ /* Set the row and column strides of A. */ \ - rs_a = 1; \ - cs_a = *lda; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ diff --git a/frame/compat/bla_syr.c b/frame/compat/bla_syr.c index 0ed4aebb1a..6732a75cf2 100644 --- a/frame/compat/bla_syr.c +++ b/frame/compat/bla_syr.c @@ -54,7 +54,6 @@ void PASTEF77(ch,blasname) \ dim_t m0; \ ftype* x0; \ inc_t incx0; \ - inc_t rs_a, cs_a; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -81,8 +80,8 @@ void PASTEF77(ch,blasname) \ bli_convert_blas_incv( m0, (ftype*)x, *incx, x0, incx0 ); \ \ /* Set the row and column strides of A. */ \ - rs_a = 1; \ - cs_a = *lda; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ diff --git a/frame/compat/bla_syr2.c b/frame/compat/bla_syr2.c index dbae670278..7050c04883 100644 --- a/frame/compat/bla_syr2.c +++ b/frame/compat/bla_syr2.c @@ -57,7 +57,6 @@ void PASTEF77(ch,blasname) \ ftype* y0; \ inc_t incx0; \ inc_t incy0; \ - inc_t rs_a, cs_a; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -87,8 +86,8 @@ void PASTEF77(ch,blasname) \ bli_convert_blas_incv( m0, (ftype*)y, *incy, y0, incy0 ); \ \ /* Set the row and column strides of A. */ \ - rs_a = 1; \ - cs_a = *lda; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ diff --git a/frame/compat/bla_syr2k.c b/frame/compat/bla_syr2k.c index 7e611b1d6d..2b26171b63 100644 --- a/frame/compat/bla_syr2k.c +++ b/frame/compat/bla_syr2k.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ @@ -57,9 +61,6 @@ void PASTEF77(ch,blasname) \ uplo_t blis_uploc; \ trans_t blis_transa; \ dim_t m0, k0; \ - inc_t rs_a, cs_a; \ - inc_t rs_b, cs_b; \ - inc_t rs_c, cs_c; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -91,17 +92,17 @@ void PASTEF77(ch,blasname) \ blis_transa = BLIS_TRANSPOSE; \ } \ \ - /* Convert/typecast negative values of m and k to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *k, k0 ); \ \ /* Set the row and column strides of the matrix operands. */ \ - rs_a = 1; \ - cs_a = *lda; \ - rs_b = 1; \ - cs_b = *ldb; \ - rs_c = 1; \ - cs_c = *ldc; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ @@ -124,6 +125,117 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + uplo_t blis_uploc; \ + trans_t blis_transa; \ + dim_t m0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + uploc, \ + transa, \ + m, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ +\ + /* The real domain ssyr2k and dsyr2k in netlib BLAS treat a trans value + of 'C' (conjugate-transpose) as 'T' (transpose only). So, we have + to go out of our way a little to support this behavior. */ \ + if ( bli_is_real( PASTEMAC(ch,type) ) && \ + bli_is_conjtrans( blis_transa ) ) \ + { \ + blis_transa = BLIS_TRANSPOSE; \ + } \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + const trans_t transb = blis_transa; \ + const struc_t strucc = BLIS_SYMMETRIC; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( transb, m0, k0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, m0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( blis_uploc, &co ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( strucc, &co ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( syr2k, syr2k ) #endif diff --git a/frame/compat/bla_syrk.c b/frame/compat/bla_syrk.c index 9c08dd06bf..4f3f153676 100644 --- a/frame/compat/bla_syrk.c +++ b/frame/compat/bla_syrk.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ @@ -56,8 +60,6 @@ void PASTEF77(ch,blasname) \ uplo_t blis_uploc; \ trans_t blis_transa; \ dim_t m0, k0; \ - inc_t rs_a, cs_a; \ - inc_t rs_c, cs_c; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -88,15 +90,15 @@ void PASTEF77(ch,blasname) \ blis_transa = BLIS_TRANSPOSE; \ } \ \ - /* Convert/typecast negative values of m and k to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *k, k0 ); \ \ /* Set the row and column strides of the matrix operands. */ \ - rs_a = 1; \ - cs_a = *lda; \ - rs_c = 1; \ - cs_c = *ldc; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ @@ -117,6 +119,106 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + uplo_t blis_uploc; \ + trans_t blis_transa; \ + dim_t m0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + uploc, \ + transa, \ + m, \ + k, \ + lda, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ +\ + /* The real domain ssyrk and dsyrk in netlib BLAS treat a trans value + of 'C' (conjugate-transpose) as 'T' (transpose only). So, we have + to go out of our way a little to support this behavior. */ \ + if ( bli_is_real( PASTEMAC(ch,type) ) && \ + bli_is_conjtrans( blis_transa ) ) \ + { \ + blis_transa = BLIS_TRANSPOSE; \ + } \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + const struc_t strucc = BLIS_SYMMETRIC; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0, m0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( blis_uploc, &co ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ +\ + bli_obj_set_struc( strucc, &co ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( syrk, syrk ) #endif diff --git a/frame/compat/bla_trmm.c b/frame/compat/bla_trmm.c index 116d2b8c4c..b77a60dd6a 100644 --- a/frame/compat/bla_trmm.c +++ b/frame/compat/bla_trmm.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ @@ -59,8 +63,6 @@ void PASTEF77(ch,blasname) \ trans_t blis_transa; \ diag_t blis_diaga; \ dim_t m0, n0; \ - inc_t rs_a, cs_a; \ - inc_t rs_b, cs_b; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -86,15 +88,15 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ \ - /* Convert/typecast negative values of m and n to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *n, n0 ); \ \ /* Set the row and column strides of the matrix operands. */ \ - rs_a = 1; \ - cs_a = *lda; \ - rs_b = 1; \ - cs_b = *ldb; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ @@ -116,6 +118,103 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + side_t blis_side; \ + uplo_t blis_uploa; \ + trans_t blis_transa; \ + diag_t blis_diaga; \ + dim_t m0, n0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + transa, \ + diaga, \ + m, \ + n, \ + lda, \ + ldb \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + const struc_t struca = BLIS_TRIANGULAR; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn0_a; \ +\ + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ +\ + bli_obj_init_finish( dt, mn0_a, mn0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)b, rs_b, cs_b, &bo ); \ +\ + bli_obj_set_uplo( blis_uploa, &ao ); \ + bli_obj_set_diag( blis_diaga, &ao ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ +\ + bli_obj_set_struc( struca, &ao ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + blis_side, \ + &alphao, \ + &ao, \ + &bo, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( trmm, trmm ) #endif diff --git a/frame/compat/bla_trmv.c b/frame/compat/bla_trmv.c index ffb31b12f9..2821d4bfad 100644 --- a/frame/compat/bla_trmv.c +++ b/frame/compat/bla_trmv.c @@ -57,7 +57,6 @@ void PASTEF77(ch,blasname) \ dim_t m0; \ ftype* x0; \ inc_t incx0; \ - inc_t rs_a, cs_a; \ ftype* one_p; \ \ /* Initialize BLIS. */ \ @@ -89,8 +88,8 @@ void PASTEF77(ch,blasname) \ bli_convert_blas_incv( m0, (ftype*)x, *incx, x0, incx0 ); \ \ /* Set the row and column strides of A. */ \ - rs_a = 1; \ - cs_a = *lda; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ \ /* Acquire a pointer to the global scalar constant BLIS_ONE. */ \ one_p = PASTEMAC(ch,1); \ diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index 70597cc93c..9af008090a 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ @@ -59,8 +63,6 @@ void PASTEF77(ch,blasname) \ trans_t blis_transa; \ diag_t blis_diaga; \ dim_t m0, n0; \ - inc_t rs_a, cs_a; \ - inc_t rs_b, cs_b; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -86,15 +88,15 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ \ - /* Convert/typecast negative values of m and n to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *n, n0 ); \ \ /* Set the row and column strides of the matrix operands. */ \ - rs_a = 1; \ - cs_a = *lda; \ - rs_b = 1; \ - cs_b = *ldb; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ @@ -116,6 +118,103 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + side_t blis_side; \ + uplo_t blis_uploa; \ + trans_t blis_transa; \ + diag_t blis_diaga; \ + dim_t m0, n0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + transa, \ + diaga, \ + m, \ + n, \ + lda, \ + ldb \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + const struc_t struca = BLIS_TRIANGULAR; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn0_a; \ +\ + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ +\ + bli_obj_init_finish( dt, mn0_a, mn0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)b, rs_b, cs_b, &bo ); \ +\ + bli_obj_set_uplo( blis_uploa, &ao ); \ + bli_obj_set_diag( blis_diaga, &ao ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ +\ + bli_obj_set_struc( struca, &ao ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + blis_side, \ + &alphao, \ + &ao, \ + &bo, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( trsm, trsm ) #endif diff --git a/frame/compat/bla_trsv.c b/frame/compat/bla_trsv.c index 445059720c..91132934e5 100644 --- a/frame/compat/bla_trsv.c +++ b/frame/compat/bla_trsv.c @@ -57,7 +57,6 @@ void PASTEF77(ch,blasname) \ dim_t m0; \ ftype* x0; \ inc_t incx0; \ - inc_t rs_a, cs_a; \ ftype* one_p; \ \ /* Initialize BLIS. */ \ @@ -89,8 +88,8 @@ void PASTEF77(ch,blasname) \ bli_convert_blas_incv( m0, (ftype*)x, *incx, x0, incx0 ); \ \ /* Set the row and column strides of A. */ \ - rs_a = 1; \ - cs_a = *lda; \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ \ /* Acquire a pointer to the global scalar constant BLIS_ONE. */ \ one_p = PASTEMAC(ch,1); \ diff --git a/frame/compat/bli_blas.h b/frame/compat/bli_blas.h index f2c3f94957..c88a2e3c39 100644 --- a/frame/compat/bli_blas.h +++ b/frame/compat/bli_blas.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,10 +41,31 @@ #endif #endif // BLIS_ENABLE_CBLAS +// By default, if the BLAS compatibility layer is enabled, we define +// (include) all of the BLAS prototypes. However, if the user is +// #including "blis.h" and also #including another header that also +// declares the BLAS functions, then we provide an opportunity to +// #undefine the BLIS_ENABLE_BLAS_DEFS macro (see below). +#ifdef BLIS_ENABLE_BLAS +#define BLIS_ENABLE_BLAS_DEFS +#else +#undef BLIS_ENABLE_BLAS_DEFS +#endif + // Skip prototyping all of the BLAS if the BLAS test drivers are being // compiled. -#ifndef BLIS_VIA_BLASTEST -#ifdef BLIS_ENABLE_BLAS +#ifdef BLIS_VIA_BLASTEST +#undef BLIS_ENABLE_BLAS_DEFS +#endif + +// Skip prototyping all of the BLAS if the environment has defined the +// macro BLIS_DISABLE_BLAS_DEFS. +#ifdef BLIS_DISABLE_BLAS_DEFS +#undef BLIS_ENABLE_BLAS_DEFS +#endif + +// Begin including all BLAS prototypes. +#ifdef BLIS_ENABLE_BLAS_DEFS // -- System headers needed by BLAS compatibility layer -- @@ -78,6 +100,7 @@ #include "bla_lsame.h" #include "bla_xerbla.h" +#include "bla_xerbla_array.h" // -- Level-0 BLAS prototypes -- @@ -174,10 +197,31 @@ #include "bla_trmm_check.h" #include "bla_trsm_check.h" + +// -- BLAS extension prototypes -- + +// unique to BLIS + +#include "bla_axpby.h" + +// level-3 + +#include "bla_gemmt.h" +#include "bla_gemmt_check.h" + +// batch + +#include "bla_gemm_batch.h" + +// 3m + +#include "bla_gemm3m.h" +#include "bla_gemm3m_check.h" + + // -- Fortran-compatible APIs to BLIS functions -- #include "b77_thread.h" #endif // BLIS_ENABLE_BLAS -#endif // BLIS_VIA_BLASTEST diff --git a/frame/compat/cblas/f77_sub/f77_dot_sub.c b/frame/compat/cblas/f77_sub/f77_dot_sub.c index 6c06133f1d..0ca80464d3 100644 --- a/frame/compat/cblas/f77_sub/f77_dot_sub.c +++ b/frame/compat/cblas/f77_sub/f77_dot_sub.c @@ -35,6 +35,7 @@ #include "blis.h" #include "f77_dot_sub.h" +#ifdef BLIS_ENABLE_CBLAS // // Define CBLAS subrotine wrapper interfaces. @@ -58,9 +59,42 @@ void PASTEF773(ch,blasname,chc,sub) \ ); \ } -#ifdef BLIS_ENABLE_CBLAS -INSERT_GENTFUNCDOT_BLAS( dot, NULL ) +INSERT_GENTFUNCDOTR_BLAS( dot, NULL ) + +#ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + +INSERT_GENTFUNCDOTC_BLAS( dot, NULL ) +#else + +// +// Define CBLAS subrotine wrapper interfaces for complex types. +// For the "intel" complex return type, pass a hidden first parameter +// (by address). +// +#undef GENTFUNCDOT +#define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ +\ +void PASTEF773(ch,blasname,chc,sub) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy, \ + ftype* rval \ + ) \ +{ \ + PASTEF772(ch,blasname,chc) \ + ( \ + rval, \ + n, \ + x, incx, \ + y, incy \ + ); \ +} + +INSERT_GENTFUNCDOTC_BLAS( dot, NULL ) + +#endif // -- "Black sheep" dot product function definitions -- @@ -75,7 +109,7 @@ void PASTEF772(sds,dot,sub) float* rval ) { - *rval = *sb + PASTEF77(sds,dot) + *rval = PASTEF77(sds,dot) ( n, sb, diff --git a/frame/compat/cblas/src/cblas.h b/frame/compat/cblas/src/cblas.h index 85778c8a48..22399ac8d4 100644 --- a/frame/compat/cblas/src/cblas.h +++ b/frame/compat/cblas/src/cblas.h @@ -1,3 +1,4 @@ + #ifndef CBLAS_H #define CBLAS_H #include @@ -575,6 +576,98 @@ void BLIS_EXPORT_BLAS cblas_zher2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, void BLIS_EXPORT_BLAS cblas_xerbla(f77_int p, const char *rout, const char *form, ...); + +/* + * =========================================================================== + * BLAS Extension prototypes + * =========================================================================== + */ + +// -- APIs to operations unique to BLIS -- + +void BLIS_EXPORT_BLAS cblas_saxpby(f77_int N, float alpha, const float *X, + f77_int incX, float beta, float *Y, f77_int incY); +void BLIS_EXPORT_BLAS cblas_daxpby(f77_int N, double alpha, const double *X, + f77_int incX, double beta, double *Y, f77_int incY); +void BLIS_EXPORT_BLAS cblas_caxpby(f77_int N, const void *alpha, + const void *X, f77_int incX, const void* beta, + void *Y, f77_int incY); +void BLIS_EXPORT_BLAS cblas_zaxpby(f77_int N, const void *alpha, + const void *X, f77_int incX, const void *beta, + void *Y, f77_int incY); + +// -- APIs to level-3-like operations -- + +void BLIS_EXPORT_BLAS cblas_sgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, + f77_int N, f77_int K, float alpha, const float *A, + f77_int lda, const float *B, f77_int ldb, + float beta, float *C, f77_int ldc); +void BLIS_EXPORT_BLAS cblas_dgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, + f77_int N, f77_int K, double alpha, const double *A, + f77_int lda, const double *B, f77_int ldb, + double beta, double *C, f77_int ldc); +void BLIS_EXPORT_BLAS cblas_cgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, + f77_int N, f77_int K, const void *alpha, const void *A, + f77_int lda, const void *B, f77_int ldb, + const void *beta, void *C, f77_int ldc); +void BLIS_EXPORT_BLAS cblas_zgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, + f77_int N, f77_int K, const void *alpha, const void *A, + f77_int lda, const void *B, f77_int ldb, + const void *beta, void *C, f77_int ldc); + +// -- Batch APIs -- + +void BLIS_EXPORT_BLAS cblas_sgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const float *alpha_array, const float **A, + f77_int *lda_array, const float **B, f77_int *ldb_array, + const float *beta_array, float **C, f77_int *ldc_array, + f77_int group_count, f77_int *group_size); +void BLIS_EXPORT_BLAS cblas_dgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const double *alpha_array, + const double **A,f77_int *lda_array, + const double **B, f77_int *ldb_array, + const double *beta_array, double **C, f77_int *ldc_array, + f77_int group_count, f77_int *group_size); +void BLIS_EXPORT_BLAS cblas_cgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const void *alpha_array, const void **A, + f77_int *lda_array, const void **B, f77_int *ldb_array, + const void *beta_array, void **C, f77_int *ldc_array, + f77_int group_count, f77_int *group_size); +void BLIS_EXPORT_BLAS cblas_zgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const void *alpha_array, const void **A, + f77_int *lda_array, const void **B, f77_int *ldb_array, + const void *beta_array, void **C, f77_int *ldc_array, + f77_int group_count, f77_int *group_size); + +// -- 3m APIs -- + +void BLIS_EXPORT_BLAS cblas_cgemm3m(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N, + f77_int K, const void *alpha, const void *A, + f77_int lda, const void *B, f77_int ldb, + const void *beta, void *C, f77_int ldc); +void BLIS_EXPORT_BLAS cblas_zgemm3m(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N, + f77_int K, const void *alpha, const void *A, + f77_int lda, const void *B, f77_int ldb, + const void *beta, void *C, f77_int ldc); + #ifdef __cplusplus } #endif diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h index fcdd946df5..acb354aaf4 100644 --- a/frame/compat/cblas/src/cblas_f77.h +++ b/frame/compat/cblas/src/cblas_f77.h @@ -1,12 +1,46 @@ /* - * cblas_f77.h - * Written by Keita Teranishi - * - * Updated by Jeff Horner - * Merged cblas_f77.h and cblas_fortran_header.h - * - * (Heavily hacked down from the original) - */ + cblas_f77.h + Written by Keita Teranishi + + Updated by Jeff Horner + Merged cblas_f77.h and cblas_fortran_header.h + + (Heavily hacked down from the original) +*/ + +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ #ifndef CBLAS_F77_H #define CBLAS_F77_H @@ -163,5 +197,26 @@ #define F77_zsyr2k zsyr2k_ #define F77_ztrmm ztrmm_ #define F77_ztrsm ztrsm_ +/* +* BLAS extensions +*/ +#define F77_saxpby saxpby_ +#define F77_daxpby daxpby_ +#define F77_caxpby caxpby_ +#define F77_zaxpby zaxpby_ + +#define F77_sgemmt sgemmt_ +#define F77_dgemmt dgemmt_ +#define F77_cgemmt cgemmt_ +#define F77_zgemmt zgemmt_ + +#define F77_sgemm_batch sgemm_batch_ +#define F77_dgemm_batch dgemm_batch_ +#define F77_cgemm_batch cgemm_batch_ +#define F77_zgemm_batch zgemm_batch_ + +#define F77_cgemm3m cgemm3m_ +#define F77_zgemm3m zgemm3m_ + #endif /* CBLAS_F77_H */ diff --git a/frame/compat/cblas/src/cblas_sgemm.c b/frame/compat/cblas/src/cblas_sgemm.c index 89d0f07a88..bf40b9c0d9 100644 --- a/frame/compat/cblas/src/cblas_sgemm.c +++ b/frame/compat/cblas/src/cblas_sgemm.c @@ -7,6 +7,8 @@ * Written by Keita Teranishi * 4/8/1998 * + * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * */ #include "cblas.h" @@ -17,12 +19,12 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, f77_int lda, const float *B, f77_int ldb, float beta, float *C, f77_int ldc) { - char TA, TB; + char TA, TB; #ifdef F77_CHAR F77_CHAR F77_TA, F77_TB; #else - #define F77_TA &TA - #define F77_TB &TB + #define F77_TA &TA + #define F77_TB &TB #endif #ifdef F77_INT @@ -36,7 +38,7 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, #define F77_ldb ldb #define F77_ldc ldc #endif - + extern int CBLAS_CallFromC; extern int RowMajorStrg; RowMajorStrg = 0; @@ -46,9 +48,9 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, if(TransA == CblasTrans) TA='T'; else if ( TransA == CblasConjTrans ) TA='C'; else if ( TransA == CblasNoTrans ) TA='N'; - else + else { - cblas_xerbla(2, "cblas_sgemm", + cblas_xerbla(2, "cblas_sgemm", "Illegal TransA setting, %d\n", TransA); CBLAS_CallFromC = 0; RowMajorStrg = 0; @@ -58,9 +60,9 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, if(TransB == CblasTrans) TB='T'; else if ( TransB == CblasConjTrans ) TB='C'; else if ( TransB == CblasNoTrans ) TB='N'; - else + else { - cblas_xerbla(3, "cblas_sgemm", + cblas_xerbla(3, "cblas_sgemm", "Illegal TransB setting, %d\n", TransB); CBLAS_CallFromC = 0; RowMajorStrg = 0; @@ -79,9 +81,9 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, if(TransA == CblasTrans) TB='T'; else if ( TransA == CblasConjTrans ) TB='C'; else if ( TransA == CblasNoTrans ) TB='N'; - else + else { - cblas_xerbla(2, "cblas_sgemm", + cblas_xerbla(2, "cblas_sgemm", "Illegal TransA setting, %d\n", TransA); CBLAS_CallFromC = 0; RowMajorStrg = 0; @@ -90,10 +92,10 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, if(TransB == CblasTrans) TA='T'; else if ( TransB == CblasConjTrans ) TA='C'; else if ( TransB == CblasNoTrans ) TA='N'; - else + else { - cblas_xerbla(2, "cblas_sgemm", - "Illegal TransA setting, %d\n", TransA); + cblas_xerbla(2, "cblas_sgemm", + "Illegal TransB setting, %d\n", TransB); CBLAS_CallFromC = 0; RowMajorStrg = 0; return; @@ -104,7 +106,7 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, #endif F77_sgemm(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, &alpha, B, &F77_ldb, A, &F77_lda, &beta, C, &F77_ldc); - } else + } else cblas_xerbla(1, "cblas_sgemm", "Illegal Order setting, %d\n", Order); CBLAS_CallFromC = 0; diff --git a/frame/compat/cblas/src/cblas_zgemm.c b/frame/compat/cblas/src/cblas_zgemm.c index e50de22054..8e08c20312 100644 --- a/frame/compat/cblas/src/cblas_zgemm.c +++ b/frame/compat/cblas/src/cblas_zgemm.c @@ -104,7 +104,7 @@ void cblas_zgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, F77_zgemm(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, (dcomplex*)alpha, (dcomplex*)B, &F77_ldb, (dcomplex*)A, &F77_lda, (dcomplex*)beta, (dcomplex*)C, &F77_ldc); } - else cblas_xerbla(1, "cblas_zgemm", "Illegal Order setting, %d\n", Order); + else cblas_xerbla(1, "cblas_zgemm", "Illegal Order setting, %d\n", Order); CBLAS_CallFromC = 0; RowMajorStrg = 0; return; diff --git a/frame/compat/cblas/src/extra/cblas_caxpby.c b/frame/compat/cblas/src/extra/cblas_caxpby.c new file mode 100644 index 0000000000..e8400d91be --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_caxpby.c @@ -0,0 +1,27 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * cblas_caxpby.c + * + * The program is a C interface to caxpby. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_caxpby( f77_int N, const void *alpha, + const void *X, f77_int incX, + const void *beta, + void *Y, f77_int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_caxpby( &F77_N, (scomplex*)alpha, (scomplex*)X, &F77_incX, (scomplex*)beta, (scomplex*)Y, &F77_incY); +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_cgemm3m.c b/frame/compat/cblas/src/extra/cblas_cgemm3m.c new file mode 100644 index 0000000000..514e525450 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_cgemm3m.c @@ -0,0 +1,115 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_cgemm3m.c + * + * This program is a C interface to cgemm3m. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cgemm3m(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N, + f77_int K, const void *alpha, const void *A, + f77_int lda, const void *B, f77_int ldb, + const void *beta, void *C, f77_int ldc) +{ + char TA, TB; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_TB; +#else + #define F77_TA &TA + #define F77_TB &TB +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + + if( Order == CblasColMajor ) + { + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(2, "cblas_cgemm3m", "Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_cgemm3m", "Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_cgemm3m(F77_TA, F77_TB, &F77_M, &F77_N, &F77_K, (scomplex*)alpha, (scomplex*)A, + &F77_lda, (scomplex*)B, &F77_ldb, (scomplex*)beta, (scomplex*)C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(2, "cblas_cgemm3m", "Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(2, "cblas_cgemm3m", "Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + + F77_cgemm3m(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, (scomplex*)alpha, (scomplex*)B, + &F77_ldb, (scomplex*)A, &F77_lda, (scomplex*)beta, (scomplex*)C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_cgemm3m", "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_cgemm_batch.c b/frame/compat/cblas/src/extra/cblas_cgemm_batch.c new file mode 100644 index 0000000000..18dd0bad58 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_cgemm_batch.c @@ -0,0 +1,168 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_cgemm_batch.c + * This program is a C interface to cgemm_batch. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const void *alpha_array, + const void **A_array, f77_int *lda_array, + const void **B_array, f77_int *ldb_array, + const void *beta_array, + void **C_array, f77_int *ldc_array, + f77_int group_count, f77_int *group_size) +{ + char TA[group_count], TB[group_count]; +#ifdef F77_CHAR + F77_CHAR F77_TA[group_count], F77_TB[group_count]; +#else + #define F77_TA TA + #define F77_TB TB +#endif + +#ifdef F77_INT + F77_INT F77_GRP_COUNT = group_count; + F77_INT F77_M[F77_GRP_COUNT], F77_N[F77_GRP_COUNT], F77_K[F77_GRP_COUNT]; + F77_INT F77_lda[F77_GRP_COUNT], F77_ldb[F77_GRP_COUNT], F77_ldc[F77_GRP_COUNT]; + F77_INT F77_GRP_SIZE[F77_GRP_COUNT]; +#else + #define F77_GRP_COUNT group_count + #define F77_M M_array + #define F77_N N_array + #define F77_K K_array + #define F77_lda lda_array + #define F77_ldb ldb_array + #define F77_ldc ldc_array + #define F77_GRP_SIZE group_size +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + dim_t i; + if( Order == CblasColMajor ) + { + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TA[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_cgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB_array[i] == CblasTrans) TB[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(3, "cblas_cgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA[i] = C2F_CHAR(TA+i); + F77_TB[i] = C2F_CHAR(TB+i); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE[i] = group_size[i]; +#endif + } + + F77_cgemm_batch(F77_TA, F77_TB, + F77_M, F77_N, F77_K, + (const scomplex*)alpha_array, + (const scomplex**)A_array, F77_lda, + (const scomplex**)B_array, F77_ldb, + (const scomplex*)beta_array, + (scomplex**)C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } + else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + dim_t i; + + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TB[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(2, "cblas_cgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB_array[i] == CblasTrans) TA[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_cgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE = group_size[i]; +#endif + } + + F77_cgemm_batch(F77_TA, F77_TB, + F77_N, F77_M, F77_K, + (const scomplex*)alpha_array, + (const scomplex**)B_array, F77_ldb, + (const scomplex**)A_array, F77_lda, + (const scomplex*)beta_array, + (scomplex**)C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } else + cblas_xerbla(1, "cblas_cgemm_batch", + "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_cgemmt.c b/frame/compat/cblas/src/extra/cblas_cgemmt.c new file mode 100644 index 0000000000..79d18f0418 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_cgemmt.c @@ -0,0 +1,166 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + cblas_cgemmt.c + Based off of cblas_cgemm.c. +*/ + +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int K, + const void *alpha, const void *A, + f77_int lda, const void *B, f77_int ldb, + const void *beta, void *C, f77_int ldc) +{ + char UL, TA, TB; +#ifdef F77_CHAR + F77_CHAR F77_UL, F77_TA, F77_TB; +#else + #define F77_UL &UL + #define F77_TA &TA + #define F77_TB &TB +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_cgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(3, "cblas_cgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(4, "cblas_cgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_cgemmt(F77_UL, F77_TA, F77_TB, &F77_M, &F77_K, (scomplex*)alpha, (scomplex*)A, + &F77_lda, (scomplex*)B, &F77_ldb, (scomplex*)beta, (scomplex*)C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(2, "cblas_cgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_cgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_cgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_cgemmt(F77_UL, F77_TA, F77_TB, &F77_M, &F77_K, (scomplex*)alpha, (scomplex*)B, + &F77_ldb, (scomplex*)A, &F77_lda, (scomplex*)beta, (scomplex*)C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_cgemmt", "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_daxpby.c b/frame/compat/cblas/src/extra/cblas_daxpby.c new file mode 100644 index 0000000000..8fbea4d5a2 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_daxpby.c @@ -0,0 +1,26 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * cblas_daxpby.c + * + * The program is a C interface to daxpby. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_daxpby( f77_int N, double alpha, + const double *X, f77_int incX, + double beta, + double *Y, f77_int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_daxpby( &F77_N, &alpha, X, &F77_incX, &beta, Y, &F77_incY); +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_dgemm_batch.c b/frame/compat/cblas/src/extra/cblas_dgemm_batch.c new file mode 100644 index 0000000000..a2bed3b1a3 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_dgemm_batch.c @@ -0,0 +1,168 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_dgemm_batch.c + * This program is a C interface to dgemm_batch. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const double *alpha_array, + const double **A_array, f77_int *lda_array, + const double **B_array, f77_int *ldb_array, + const double *beta_array, + double **C_array, f77_int *ldc_array, + f77_int group_count, f77_int *group_size) +{ + char TA[group_count], TB[group_count]; +#ifdef F77_CHAR + F77_CHAR F77_TA[group_count], F77_TB[group_count]; +#else + #define F77_TA TA + #define F77_TB TB +#endif + +#ifdef F77_INT + F77_INT F77_GRP_COUNT = group_count; + F77_INT F77_M[F77_GRP_COUNT], F77_N[F77_GRP_COUNT], F77_K[F77_GRP_COUNT]; + F77_INT F77_lda[F77_GRP_COUNT], F77_ldb[F77_GRP_COUNT], F77_ldc[F77_GRP_COUNT]; + F77_INT F77_GRP_SIZE[F77_GRP_COUNT]; +#else + #define F77_GRP_COUNT group_count + #define F77_M M_array + #define F77_N N_array + #define F77_K K_array + #define F77_lda lda_array + #define F77_ldb ldb_array + #define F77_ldc ldc_array + #define F77_GRP_SIZE group_size +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + dim_t i; + if( Order == CblasColMajor ) + { + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TA[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_dgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB_array[i] == CblasTrans) TB[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(3, "cblas_dgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA[i] = C2F_CHAR(TA+i); + F77_TB[i] = C2F_CHAR(TB+i); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE[i] = group_size[i]; +#endif + } + + F77_dgemm_batch(F77_TA, F77_TB, + F77_M, F77_N, F77_K, + alpha_array, + A_array, F77_lda, + B_array, F77_ldb, + beta_array, + C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } + else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + dim_t i; + + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TB[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(2, "cblas_dgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB_array[i] == CblasTrans) TA[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_dgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE = group_size[i]; +#endif + } + + F77_dgemm_batch(F77_TA, F77_TB, + F77_N, F77_M, F77_K, + alpha_array, + B_array, F77_ldb, + A_array, F77_lda, + beta_array, + C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } else + cblas_xerbla(1, "cblas_dgemm_batch", + "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_dgemmt.c b/frame/compat/cblas/src/extra/cblas_dgemmt.c new file mode 100644 index 0000000000..8677e02b78 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_dgemmt.c @@ -0,0 +1,166 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + cblas_dgemmt.c + Based off of cblas_dgemm.c. +*/ + +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int K, + double alpha, const double *A, + f77_int lda, const double *B, f77_int ldb, + double beta, double *C, f77_int ldc) +{ + char UL, TA, TB; +#ifdef F77_CHAR + F77_CHAR F77_UL, F77_TA, F77_TB; +#else + #define F77_UL &UL + #define F77_TA &TA + #define F77_TB &TB +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_dgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(3, "cblas_dgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(4, "cblas_dgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_dgemmt(F77_UL, F77_TA, F77_TB, &F77_M, &F77_K, &alpha, A, + &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(2, "cblas_dgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_dgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_dgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_dgemmt(F77_UL, F77_TA, F77_TB, &F77_M, &F77_K, &alpha, B, + &F77_ldb, A, &F77_lda, &beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_dgemmt", "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_saxpby.c b/frame/compat/cblas/src/extra/cblas_saxpby.c new file mode 100644 index 0000000000..6852821230 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_saxpby.c @@ -0,0 +1,28 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * cblas_saxpby.c + * + * The program is a C interface to saxpby. + * It calls the fortran wrapper before calling saxpby. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_saxpby( f77_int N, float alpha, + const float *X, f77_int incX, + float beta, + float *Y, f77_int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_saxpby( &F77_N, &alpha, X, &F77_incX, &beta, Y, &F77_incY); +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_sgemm_batch.c b/frame/compat/cblas/src/extra/cblas_sgemm_batch.c new file mode 100644 index 0000000000..3e8517db28 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_sgemm_batch.c @@ -0,0 +1,168 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_sgemm_batch.c + * This program is a C interface to sgemm_batch. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_sgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const float *alpha_array, + const float **A_array, f77_int *lda_array, + const float **B_array, f77_int *ldb_array, + const float *beta_array, + float **C_array, f77_int *ldc_array, + f77_int group_count, f77_int *group_size) +{ + char TA[group_count], TB[group_count]; +#ifdef F77_CHAR + F77_CHAR F77_TA[group_count], F77_TB[group_count]; +#else + #define F77_TA TA + #define F77_TB TB +#endif + +#ifdef F77_INT + F77_INT F77_GRP_COUNT = group_count; + F77_INT F77_M[F77_GRP_COUNT], F77_N[F77_GRP_COUNT], F77_K[F77_GRP_COUNT]; + F77_INT F77_lda[F77_GRP_COUNT], F77_ldb[F77_GRP_COUNT], F77_ldc[F77_GRP_COUNT]; + F77_INT F77_GRP_SIZE[F77_GRP_COUNT]; +#else + #define F77_GRP_COUNT group_count + #define F77_M M_array + #define F77_N N_array + #define F77_K K_array + #define F77_lda lda_array + #define F77_ldb ldb_array + #define F77_ldc ldc_array + #define F77_GRP_SIZE group_size +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + dim_t i; + if( Order == CblasColMajor ) + { + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TA[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_sgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB_array[i] == CblasTrans) TB[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(3, "cblas_sgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA[i] = C2F_CHAR(TA+i); + F77_TB[i] = C2F_CHAR(TB+i); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE[i] = group_size[i]; +#endif + } + + F77_sgemm_batch(F77_TA, F77_TB, + F77_M, F77_N, F77_K, + alpha_array, + A_array, F77_lda, + B_array, F77_ldb, + beta_array, + C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } + else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + dim_t i; + + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TB[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(2, "cblas_sgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB_array[i] == CblasTrans) TA[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_sgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE = group_size[i]; +#endif + } + + F77_sgemm_batch(F77_TA, F77_TB, + F77_N, F77_M, F77_K, + alpha_array, + B_array, F77_ldb, + A_array, F77_lda, + beta_array, + C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } else + cblas_xerbla(1, "cblas_sgemm_batch", + "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_sgemmt.c b/frame/compat/cblas/src/extra/cblas_sgemmt.c new file mode 100644 index 0000000000..abe5ae857d --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_sgemmt.c @@ -0,0 +1,166 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + cblas_sgemmt.c + Based off of cblas_sgemm.c. +*/ + +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_sgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int K, + float alpha, const float *A, + f77_int lda, const float *B, f77_int ldb, + float beta, float *C, f77_int ldc) +{ + char UL, TA, TB; +#ifdef F77_CHAR + F77_CHAR F77_UL, F77_TA, F77_TB; +#else + #define F77_UL &UL + #define F77_TA &TA + #define F77_TB &TB +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_sgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(3, "cblas_sgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(4, "cblas_sgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_sgemmt(F77_UL, F77_TA, F77_TB, &F77_M, &F77_K, &alpha, A, + &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(2, "cblas_sgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_sgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_sgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_sgemmt(F77_UL, F77_TA, F77_TB, &F77_M, &F77_K, &alpha, B, + &F77_ldb, A, &F77_lda, &beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_sgemmt", "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_zaxpby.c b/frame/compat/cblas/src/extra/cblas_zaxpby.c new file mode 100644 index 0000000000..483607ec9b --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_zaxpby.c @@ -0,0 +1,27 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * cblas_zaxpby.c + * + * The program is a C interface to zaxpby. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zaxpby( f77_int N, const void *alpha, + const void *X, f77_int incX, + const void *beta, + void *Y, f77_int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_zaxpby( &F77_N, (dcomplex*)alpha, (dcomplex*)X, &F77_incX, (dcomplex*)beta, (dcomplex*)Y, &F77_incY); +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_zgemm3m.c b/frame/compat/cblas/src/extra/cblas_zgemm3m.c new file mode 100644 index 0000000000..8be4278b42 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_zgemm3m.c @@ -0,0 +1,113 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_zgemm3m.c + * + * This program is a C interface to zgemm3m. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zgemm3m(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N, + f77_int K, const void *alpha, const void *A, + f77_int lda, const void *B, f77_int ldb, + const void *beta, void *C, f77_int ldc) +{ + char TA, TB; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_TB; +#else + #define F77_TA &TA + #define F77_TB &TB +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + + if( Order == CblasColMajor ) + { + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(2, "cblas_zgemm3m", "Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_zgemm3m", "Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + F77_zgemm3m(F77_TA, F77_TB, &F77_M, &F77_N, &F77_K, (dcomplex*)alpha, (dcomplex*)A, + &F77_lda, (dcomplex*)B, &F77_ldb, (dcomplex*)beta, (dcomplex*)C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(2, "cblas_zgemm3m", "Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(2, "cblas_zgemm3m", "Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_zgemm3m(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, (dcomplex*)alpha, (dcomplex*)B, + &F77_ldb, (dcomplex*)A, &F77_lda, (dcomplex*)beta, (dcomplex*)C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_zgemm3m", "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_zgemm_batch.c b/frame/compat/cblas/src/extra/cblas_zgemm_batch.c new file mode 100644 index 0000000000..2d188a9f00 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_zgemm_batch.c @@ -0,0 +1,168 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_zgemm_batch.c + * This program is a C interface to zgemm_batch. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const void *alpha_array, + const void **A_array, f77_int *lda_array, + const void **B_array, f77_int *ldb_array, + const void *beta_array, + void **C_array, f77_int *ldc_array, + f77_int group_count, f77_int *group_size) +{ + char TA[group_count], TB[group_count]; +#ifdef F77_CHAR + F77_CHAR F77_TA[group_count], F77_TB[group_count]; +#else + #define F77_TA TA + #define F77_TB TB +#endif + +#ifdef F77_INT + F77_INT F77_GRP_COUNT = group_count; + F77_INT F77_M[F77_GRP_COUNT], F77_N[F77_GRP_COUNT], F77_K[F77_GRP_COUNT]; + F77_INT F77_lda[F77_GRP_COUNT], F77_ldb[F77_GRP_COUNT], F77_ldc[F77_GRP_COUNT]; + F77_INT F77_GRP_SIZE[F77_GRP_COUNT]; +#else + #define F77_GRP_COUNT group_count + #define F77_M M_array + #define F77_N N_array + #define F77_K K_array + #define F77_lda lda_array + #define F77_ldb ldb_array + #define F77_ldc ldc_array + #define F77_GRP_SIZE group_size +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + dim_t i; + if( Order == CblasColMajor ) + { + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TA[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_zgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB_array[i] == CblasTrans) TB[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(3, "cblas_zgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA[i] = C2F_CHAR(TA+i); + F77_TB[i] = C2F_CHAR(TB+i); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE[i] = group_size[i]; +#endif + } + + F77_zgemm_batch(F77_TA, F77_TB, + F77_M, F77_N, F77_K, + (const dcomplex*)alpha_array, + (const dcomplex**)A_array, F77_lda, + (const dcomplex**)B_array, F77_ldb, + (const dcomplex*)beta_array, + (dcomplex**)C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } + else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + dim_t i; + + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TB[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(2, "cblas_zgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB_array[i] == CblasTrans) TA[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_zgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE = group_size[i]; +#endif + } + + F77_zgemm_batch(F77_TA, F77_TB, + F77_N, F77_M, F77_K, + (const dcomplex*)alpha_array, + (const dcomplex**)B_array, F77_ldb, + (const dcomplex**)A_array, F77_lda, + (const dcomplex*)beta_array, + (dcomplex**)C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } else + cblas_xerbla(1, "cblas_zgemm_batch", + "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_zgemmt.c b/frame/compat/cblas/src/extra/cblas_zgemmt.c new file mode 100644 index 0000000000..d3d1fa96a7 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_zgemmt.c @@ -0,0 +1,166 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + cblas_zgemmt.c + Based off of cblas_zgemm.c. +*/ + +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int K, + const void *alpha, const void *A, + f77_int lda, const void *B, f77_int ldb, + const void *beta, void *C, f77_int ldc) +{ + char UL, TA, TB; +#ifdef F77_CHAR + F77_CHAR F77_UL, F77_TA, F77_TB; +#else + #define F77_UL &UL + #define F77_TA &TA + #define F77_TB &TB +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_zgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(3, "cblas_zgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(4, "cblas_zgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_zgemmt(F77_UL, F77_TA, F77_TB, &F77_M, &F77_K, (dcomplex*)alpha, (dcomplex*)A, + &F77_lda, (dcomplex*)B, &F77_ldb, (dcomplex*)beta, (dcomplex*)C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(2, "cblas_zgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_zgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_zgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_zgemmt(F77_UL, F77_TA, F77_TB, &F77_M, &F77_K, (dcomplex*)alpha, (dcomplex*)B, + &F77_ldb, (dcomplex*)A, &F77_lda, (dcomplex*)beta, (dcomplex*)C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_zgemmt", "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; +} +#endif diff --git a/frame/compat/check/bla_gemm3m_check.h b/frame/compat/check/bla_gemm3m_check.h new file mode 100644 index 0000000000..f565b5d29a --- /dev/null +++ b/frame/compat/check/bla_gemm3m_check.h @@ -0,0 +1,89 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef BLIS_ENABLE_BLAS + +#define bla_gemm3m_check( dt_str, op_str, transa, transb, m, n, k, lda, ldb, ldc ) \ +{ \ + f77_int info = 0; \ + f77_int nota, notb; \ + f77_int conja, conjb; \ + f77_int ta, tb; \ + f77_int nrowa, nrowb; \ +\ + nota = PASTEF770(lsame)( transa, "N", (ftnlen)1, (ftnlen)1 ); \ + notb = PASTEF770(lsame)( transb, "N", (ftnlen)1, (ftnlen)1 ); \ + conja = PASTEF770(lsame)( transa, "C", (ftnlen)1, (ftnlen)1 ); \ + conjb = PASTEF770(lsame)( transb, "C", (ftnlen)1, (ftnlen)1 ); \ + ta = PASTEF770(lsame)( transa, "T", (ftnlen)1, (ftnlen)1 ); \ + tb = PASTEF770(lsame)( transb, "T", (ftnlen)1, (ftnlen)1 ); \ +\ + if ( nota ) { nrowa = *m; } \ + else { nrowa = *k; } \ + if ( notb ) { nrowb = *k; } \ + else { nrowb = *n; } \ +\ + if ( !nota && !conja && !ta ) \ + info = 1; \ + else if ( !notb && !conjb && !tb ) \ + info = 2; \ + else if ( *m < 0 ) \ + info = 3; \ + else if ( *n < 0 ) \ + info = 4; \ + else if ( *k < 0 ) \ + info = 5; \ + else if ( *lda < bli_max( 1, nrowa ) ) \ + info = 8; \ + else if ( *ldb < bli_max( 1, nrowb ) ) \ + info = 10; \ + else if ( *ldc < bli_max( 1, *m ) ) \ + info = 13; \ +\ + if ( info != 0 ) \ + { \ + char func_str[ BLIS_MAX_BLAS_FUNC_STR_LENGTH ]; \ +\ + sprintf( func_str, "%s%-5s", dt_str, op_str ); \ +\ + bli_string_mkupper( func_str ); \ +\ + PASTEF770(xerbla)( func_str, &info, (ftnlen)6 ); \ +\ + return; \ + } \ +} + +#endif diff --git a/frame/compat/check/bla_gemmt_check.h b/frame/compat/check/bla_gemmt_check.h new file mode 100644 index 0000000000..93908e07db --- /dev/null +++ b/frame/compat/check/bla_gemmt_check.h @@ -0,0 +1,92 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef BLIS_ENABLE_BLAS + +#define bla_gemmt_check( dt_str, op_str, uploc, transa, transb, m, k, lda, ldb, ldc ) \ +{ \ + f77_int info = 0; \ + f77_int nota, notb; \ + f77_int conja, conjb; \ + f77_int ta, tb; \ + f77_int lower, upper; \ + f77_int nrowa, nrowb; \ +\ + nota = PASTEF770(lsame)( transa, "N", (ftnlen)1, (ftnlen)1 ); \ + notb = PASTEF770(lsame)( transb, "N", (ftnlen)1, (ftnlen)1 ); \ + conja = PASTEF770(lsame)( transa, "C", (ftnlen)1, (ftnlen)1 ); \ + conjb = PASTEF770(lsame)( transb, "C", (ftnlen)1, (ftnlen)1 ); \ + ta = PASTEF770(lsame)( transa, "T", (ftnlen)1, (ftnlen)1 ); \ + tb = PASTEF770(lsame)( transb, "T", (ftnlen)1, (ftnlen)1 ); \ +\ + lower = PASTEF770(lsame)( uploc, "L", (ftnlen)1, (ftnlen)1 ); \ + upper = PASTEF770(lsame)( uploc, "U", (ftnlen)1, (ftnlen)1 ); \ +\ + if ( nota ) { nrowa = *m; } \ + else { nrowa = *k; } \ + if ( notb ) { nrowb = *k; } \ + else { nrowb = *m; } \ +\ + if ( !lower && !upper ) \ + info = 1; \ + else if ( !nota && !conja && !ta ) \ + info = 2; \ + else if ( !notb && !conjb && !tb ) \ + info = 3; \ + else if ( *m < 0 ) \ + info = 4; \ + else if ( *k < 0 ) \ + info = 5; \ + else if ( *lda < bli_max( 1, nrowa ) ) \ + info = 8; \ + else if ( *ldb < bli_max( 1, nrowb ) ) \ + info = 10; \ + else if ( *ldc < bli_max( 1, *m ) ) \ + info = 13; \ +\ + if ( info != 0 ) \ + { \ + char func_str[ BLIS_MAX_BLAS_FUNC_STR_LENGTH ]; \ +\ + sprintf( func_str, "%s%-5s", dt_str, op_str ); \ +\ + bli_string_mkupper( func_str ); \ +\ + PASTEF770(xerbla)( func_str, &info, (ftnlen)6 ); \ +\ + return; \ + } \ +} + +#endif diff --git a/frame/compat/extra/bla_axpby.c b/frame/compat/extra/bla_axpby.c new file mode 100644 index 0000000000..d96d75d74c --- /dev/null +++ b/frame/compat/extra/bla_axpby.c @@ -0,0 +1,89 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n0, \ + (ftype*)alpha, \ + x0, incx0, \ + (ftype*)beta, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTFUNC_BLAS( axpby, axpbyv ) +#endif diff --git a/frame/compat/extra/bla_axpby.h b/frame/compat/extra/bla_axpby.h new file mode 100644 index 0000000000..ab2952be98 --- /dev/null +++ b/frame/compat/extra/bla_axpby.h @@ -0,0 +1,54 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype BLAS-to-BLIS interfaces. +// +#undef GENTPROT +#define GENTPROT( ftype, ch, blasname ) \ +\ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ); + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTPROT_BLAS( axpby ) +#endif + diff --git a/frame/compat/extra/bla_gemm3m.c b/frame/compat/extra/bla_gemm3m.c new file mode 100644 index 0000000000..4533375f01 --- /dev/null +++ b/frame/compat/extra/bla_gemm3m.c @@ -0,0 +1,259 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// + +#ifdef BLIS_BLAS3_CALLS_TAPI + +#undef GENTFUNCCO +#define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blisname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + /* As a placeholder, invoke 1m since BLIS does no longer contains an + official 3m implementation. Note that we do this by inlining an + abbreviated version of bli_gemm_ex() so that we can bypass + consideration of sup, which doesn't make sense in this context. */ \ + { \ + cntx_t* cntx = bli_gks_query_ind_cntx( BLIS_1M, dt ); \ +\ + rntm_t rntm_l; \ + rntm_t* rntm = &rntm_l; \ + bli_rntm_init_from_global( rntm ); \ +\ + /* Note that we MUST disable sup handling since it could redirect + execution for some problem sizes to a non-3m implementation. */ \ + bli_rntm_disable_l3_sup( rntm ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transa, \ + blis_transb, \ + m0, \ + n0, \ + k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, cs_b, \ + (ftype*)beta, \ + (ftype*)c, rs_c, cs_c, \ + cntx, \ + rntm \ + ); \ + } \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#else + +#undef GENTFUNCCO +#define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blisname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( blis_transb, &bo ); \ +\ + /* As a placeholder, invoke 1m since BLIS does no longer contains an + official 3m implementation. Note that we do this by inlining an + abbreviated version of bli_gemm_ex() so that we can bypass + consideration of sup, which doesn't make sense in this context. */ \ + { \ + cntx_t* cntx = bli_gks_query_ind_cntx( BLIS_1M, dt ); \ +\ + rntm_t rntm_l; \ + rntm_t* rntm = &rntm_l; \ + bli_rntm_init_from_global( &rntm_l ); \ +\ + /* This is probably not needed given that we performed BLAS-style + parameter checking above, but bli_gemm_check() is normally called + in the normal course of bli_gemm_ex(). */ \ + if ( bli_error_checking_is_enabled() ) \ + bli_gemm_check( &alphao, &ao, &bo, &betao, &co, cntx ); \ +\ + PASTEMAC(blisname,_front) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm, \ + NULL \ + ); \ + } \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTFUNCCO_BLAS( gemm3m, gemm ) +#endif + diff --git a/frame/compat/extra/bla_gemm3m.h b/frame/compat/extra/bla_gemm3m.h new file mode 100644 index 0000000000..86b7277c88 --- /dev/null +++ b/frame/compat/extra/bla_gemm3m.h @@ -0,0 +1,59 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype BLAS-to-BLIS interfaces. +// +#undef GENTPROTCO +#define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ +\ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ); + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTPROTCO_BLAS( gemm3m ) +#endif + diff --git a/frame/compat/extra/bla_gemm_batch.c b/frame/compat/extra/bla_gemm_batch.c new file mode 100644 index 0000000000..4b2597e193 --- /dev/null +++ b/frame/compat/extra/bla_gemm_batch.c @@ -0,0 +1,251 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// + +#ifdef BLIS_BLAS3_CALLS_TAPI + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa_array, \ + const f77_char* transb_array, \ + const f77_int* m_array, \ + const f77_int* n_array, \ + const f77_int* k_array, \ + const ftype* alpha_array, \ + const ftype** a_array, const f77_int* lda_array, \ + const ftype** b_array, const f77_int* ldb_array, \ + const ftype* beta_array, \ + ftype** c_array, const f77_int* ldc_array, \ + const f77_int* group_count, \ + const f77_int* group_size \ + ) \ +{ \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + for ( f77_int gi = 0; gi < *group_count; gi++ ) \ + { \ + PASTEBLACHK(blisname) \ + ( \ + MKSTR(ch), \ + MKSTR(blisname), \ + transa_array+gi, \ + transb_array+gi, \ + m_array+gi, \ + n_array+gi, \ + k_array+gi, \ + lda_array+gi, \ + ldb_array+gi, \ + ldc_array+gi \ + ); \ + } \ +\ + f77_int idx = 0; \ +\ + for ( f77_int i = 0; i < *group_count; i++ ) \ + { \ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( transa_array[i], &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( transb_array[i], &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( m_array[i], m0 ); \ + bli_convert_blas_dim1( n_array[i], n0 ); \ + bli_convert_blas_dim1( k_array[i], k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = lda_array[i]; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = ldb_array[i]; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = ldc_array[i]; \ +\ + for ( f77_int j = 0; j < group_size[i]; j++ ) \ + { \ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transa, \ + blis_transb, \ + m0, \ + n0, \ + k0, \ + (ftype*)(alpha_array + i), \ + (ftype*)*(a_array + idx), rs_a, cs_a, \ + (ftype*)*(b_array + idx), rs_b, cs_b, \ + (ftype*)(beta_array + i), \ + (ftype*)*(c_array + idx), rs_c, cs_c, \ + NULL, \ + NULL \ + ); \ +\ + idx++; \ + } \ + } \ +\ + bli_finalize_auto(); \ +} + +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa_array, \ + const f77_char* transb_array, \ + const f77_int* m_array, \ + const f77_int* n_array, \ + const f77_int* k_array, \ + const ftype* alpha_array, \ + const ftype** a_array, const f77_int* lda_array, \ + const ftype** b_array, const f77_int* ldb_array, \ + const ftype* beta_array, \ + ftype** c_array, const f77_int* ldc_array, \ + const f77_int* group_count, \ + const f77_int* group_size ) \ +{ \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + for ( f77_int gi = 0; gi < *group_count; gi++ ) \ + { \ + PASTEBLACHK(blisname) \ + ( \ + MKSTR(ch), \ + MKSTR(blisname), \ + transa_array+gi, \ + transb_array+gi, \ + m_array+gi, \ + n_array+gi, \ + k_array+gi, \ + lda_array+gi, \ + ldb_array+gi, \ + ldc_array+gi \ + ); \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + f77_int idx = 0, i, j; \ +\ + for ( i = 0; i < *group_count; i++ ) \ + { \ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( transa_array[i], &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( transb_array[i], &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( m_array[i], m0 ); \ + bli_convert_blas_dim1( n_array[i], n0 ); \ + bli_convert_blas_dim1( k_array[i], k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = lda_array[i]; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = ldb_array[i]; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = ldc_array[i]; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)(alpha_array + i), &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)(beta_array + i), &betao ); \ +\ + for( j = 0; j < group_size[i]; j++ ) \ + { \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)*(a_array + idx), rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)*(b_array + idx), rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)*(c_array + idx), rs_c, cs_c, &co ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( blis_transb, &bo ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + idx++; \ + } \ + } \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTFUNC_BLAS( gemm_batch, gemm ) +#endif + diff --git a/frame/compat/extra/bla_gemm_batch.h b/frame/compat/extra/bla_gemm_batch.h new file mode 100644 index 0000000000..f997f4b8ee --- /dev/null +++ b/frame/compat/extra/bla_gemm_batch.h @@ -0,0 +1,61 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype BLAS-to-BLIS interfaces. +// +#undef GENTPROT +#define GENTPROT( ftype, ch, blasname ) \ +\ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa_array, \ + const f77_char* transb_array, \ + const f77_int* m_array, \ + const f77_int* n_array, \ + const f77_int* k_array, \ + const ftype* alpha_array, \ + const ftype** a_array, const f77_int* lda_array, \ + const ftype** b_array, const f77_int* ldb_array, \ + const ftype* beta_array, \ + ftype** c_array, const f77_int* ldc_array, \ + const f77_int* group_count, \ + const f77_int* group_size \ + ); + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTPROT_BLAS( gemm_batch ) +#endif + diff --git a/frame/compat/extra/bla_gemmt.c b/frame/compat/extra/bla_gemmt.c new file mode 100644 index 0000000000..101cc6d134 --- /dev/null +++ b/frame/compat/extra/bla_gemmt.c @@ -0,0 +1,231 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// + +#ifdef BLIS_BLAS3_CALLS_TAPI + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + uplo_t blis_uploc; \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + uploc, \ + transa, \ + transb, \ + m, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_uploc, \ + blis_transa, \ + blis_transb, \ + m0, \ + k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, cs_b, \ + (ftype*)beta, \ + (ftype*)c, rs_c, cs_c, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + uplo_t blis_uploc; \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + uploc, \ + transa, \ + transb, \ + m, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + const struc_t strucc = BLIS_SYMMETRIC; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( blis_transb, k0, m0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, m0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( blis_uploc, &co ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( blis_transb, &bo ); \ +\ + bli_obj_set_struc( strucc, &co ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTFUNC_BLAS( gemmt, gemmt ) +#endif + diff --git a/frame/compat/extra/bla_gemmt.h b/frame/compat/extra/bla_gemmt.h new file mode 100644 index 0000000000..3bef5a8981 --- /dev/null +++ b/frame/compat/extra/bla_gemmt.h @@ -0,0 +1,60 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype BLAS-to-BLIS interfaces. +// +#undef GENTPROT +#define GENTPROT( ftype, ch, blasname ) \ +\ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ); + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTPROT_BLAS( gemmt ) +#endif + diff --git a/frame/compat/f2c/bla_xerbla.h b/frame/compat/f2c/bla_xerbla.h index 44c168e584..f9f0a46410 100644 --- a/frame/compat/f2c/bla_xerbla.h +++ b/frame/compat/f2c/bla_xerbla.h @@ -34,6 +34,6 @@ #ifdef BLIS_ENABLE_BLAS -BLIS_EXPORT_BLAS int PASTEF770(xerbla)(const bla_character *srname, const bla_integer *info, ftnlen srname_len); +BLIS_EXPORT_BLAS BLIS_OVERRIDABLE int PASTEF770(xerbla)(const bla_character *srname, const bla_integer *info, ftnlen srname_len); #endif diff --git a/frame/compat/f2c/bla_xerbla_array.c b/frame/compat/f2c/bla_xerbla_array.c new file mode 100644 index 0000000000..722bb29144 --- /dev/null +++ b/frame/compat/f2c/bla_xerbla_array.c @@ -0,0 +1,74 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_BLAS + +#define MAX_NUM_CHARS 32 + +int PASTEF770(xerbla_array)(const bla_character *srname_array, const bla_integer srname_len, const bla_integer *info) +{ + int i; +#if 1 + // 01234567890123456789012345678901 + char srname[ MAX_NUM_CHARS + 1 ] = " "; +#else + char srname[ MAX_NUM_CHARS + 1 ]; + + // Initialize srname to contain blank characters. + for ( i = 0; i < MAX_NUM_CHARS; ++i ) srname[i] = ' '; +#endif + + // Compute the number of chars to copy as the minimum of the length of + // srname_array and MAX_NUM_CHARS. + const int n_copy = bli_min( srname_len, MAX_NUM_CHARS ); + + // Copy over each element of srname_array. + for ( i = 0; i < n_copy; ++i ) + { + srname[i] = srname_array[i]; + } + + // NULL terminate. + srname[i] = '\0'; + + // Call xerbla_(). + PASTEF770(xerbla)( srname, info, ( ftnlen )srname_len ); + + return 0; +} + +#endif + diff --git a/frame/3/syr2k/bli_syr2k.h b/frame/compat/f2c/bla_xerbla_array.h similarity index 91% rename from frame/3/syr2k/bli_syr2k.h rename to frame/compat/f2c/bla_xerbla_array.h index 680e6e3997..6a4b4e0598 100644 --- a/frame/3/syr2k/bli_syr2k.h +++ b/frame/compat/f2c/bla_xerbla_array.h @@ -32,5 +32,8 @@ */ -#include "bli_syr2k_front.h" +#ifdef BLIS_ENABLE_BLAS +BLIS_EXPORT_BLAS int PASTEF770(xerbla_array)(const bla_character *srname, const bla_integer srname_len, const bla_integer *info); + +#endif diff --git a/frame/include/bli_arch_config.h b/frame/include/bli_arch_config.h index f7ef3718ba..b300b27622 100644 --- a/frame/include/bli_arch_config.h +++ b/frame/include/bli_arch_config.h @@ -6,6 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,6 +42,7 @@ // // -- Intel64 architectures -- + #ifdef BLIS_CONFIG_SKX CNTX_INIT_PROTS( skx ) #endif @@ -62,6 +64,12 @@ CNTX_INIT_PROTS( penryn ) // -- AMD64 architectures -- +#ifdef BLIS_CONFIG_ZEN3 +CNTX_INIT_PROTS( zen3 ) +#endif +#ifdef BLIS_CONFIG_ZEN2 +CNTX_INIT_PROTS( zen2 ) +#endif #ifdef BLIS_CONFIG_ZEN CNTX_INIT_PROTS( zen ) #endif @@ -80,6 +88,15 @@ CNTX_INIT_PROTS( bulldozer ) // -- ARM architectures -- +#ifdef BLIS_CONFIG_ARMSVE +CNTX_INIT_PROTS( armsve ) +#endif +#ifdef BLIS_CONFIG_A64FX +CNTX_INIT_PROTS( a64fx ) +#endif +#ifdef BLIS_CONFIG_FIRESTORM +CNTX_INIT_PROTS( firestorm ) +#endif #ifdef BLIS_CONFIG_THUNDERX2 CNTX_INIT_PROTS( thunderx2 ) #endif @@ -96,14 +113,20 @@ CNTX_INIT_PROTS( cortexa15 ) CNTX_INIT_PROTS( cortexa9 ) #endif -// -- IBM BG/Q -- +// -- IBM Power -- +#ifdef BLIS_CONFIG_POWER10 +CNTX_INIT_PROTS( power10 ) +#endif #ifdef BLIS_CONFIG_POWER9 CNTX_INIT_PROTS( power9 ) #endif #ifdef BLIS_CONFIG_POWER7 CNTX_INIT_PROTS( power7 ) #endif + +// -- IBM BG/Q -- + #ifdef BLIS_CONFIG_BGQ CNTX_INIT_PROTS( bgq ) #endif @@ -127,11 +150,27 @@ CNTX_INIT_PROTS( generic ) #ifdef BLIS_FAMILY_AMD64 #include "bli_family_amd64.h" #endif +#ifdef BLIS_FAMILY_AMD64_LEGACY +#include "bli_family_amd64_legacy.h" +#endif #ifdef BLIS_FAMILY_X86_64 #include "bli_family_x86_64.h" #endif +#ifdef BLIS_FAMILY_X86_64_NO_SKX +#include "bli_family_x86_64_no_skx.h" +#endif + +#ifdef BLIS_FAMILY_X86_64_NO_ZEN2 +#include "bli_family_x86_64_no_zen2.h" +#endif + +#ifdef BLIS_FAMILY_X86_64_NO_ZEN3 +#include "bli_family_x86_64_no_zen3.h" +#endif + // -- Intel64 architectures -- + #ifdef BLIS_FAMILY_SKX #include "bli_family_skx.h" #endif @@ -153,6 +192,12 @@ CNTX_INIT_PROTS( generic ) // -- AMD64 architectures -- +#ifdef BLIS_FAMILY_ZEN3 +#include "bli_family_zen3.h" +#endif +#ifdef BLIS_FAMILY_ZEN2 +#include "bli_family_zen2.h" +#endif #ifdef BLIS_FAMILY_ZEN #include "bli_family_zen.h" #endif @@ -169,8 +214,28 @@ CNTX_INIT_PROTS( generic ) #include "bli_family_bulldozer.h" #endif +// -- ARM families -- +#ifdef BLIS_FAMILY_ARM64 +#include "bli_family_arm64.h" +#endif +#ifdef BLIS_FAMILY_ARM32 +#include "bli_family_arm32.h" +#endif + // -- ARM architectures -- +#ifdef BLIS_FAMILY_ARMSVE +#include "bli_family_armsve.h" +#endif +#ifdef BLIS_FAMILY_A64FX +#include "bli_family_a64fx.h" +#endif +#ifdef BLIS_FAMILY_FIRESTORM +#include "bli_family_firestorm.h" +#endif +#ifdef BLIS_FAMILY_THUNDERX2 +#include "bli_family_thunderx2.h" +#endif #ifdef BLIS_FAMILY_CORTEXA57 #include "bli_family_cortexa57.h" #endif @@ -184,14 +249,20 @@ CNTX_INIT_PROTS( generic ) #include "bli_family_cortexa9.h" #endif -// -- IBM BG/Q -- +// -- IBM Power -- +#ifdef BLIS_FAMILY_POWER10 +#include "bli_family_power10.h" +#endif #ifdef BLIS_FAMILY_POWER9 #include "bli_family_power9.h" #endif #ifdef BLIS_FAMILY_POWER7 #include "bli_family_power7.h" #endif + +// -- IBM BG/Q -- + #ifdef BLIS_FAMILY_BGQ #include "bli_family_bgq.h" #endif @@ -229,6 +300,9 @@ CNTX_INIT_PROTS( generic ) // -- AMD64 architectures -- +#ifdef BLIS_KERNELS_ZEN2 +#include "bli_kernels_zen2.h" +#endif #ifdef BLIS_KERNELS_ZEN #include "bli_kernels_zen.h" #endif @@ -247,6 +321,9 @@ CNTX_INIT_PROTS( generic ) // -- ARM architectures -- +#ifdef BLIS_KERNELS_ARMSVE +#include "bli_kernels_armsve.h" +#endif #ifdef BLIS_KERNELS_ARMV8A #include "bli_kernels_armv8a.h" #endif @@ -254,11 +331,20 @@ CNTX_INIT_PROTS( generic ) #include "bli_kernels_armv7a.h" #endif -// -- IBM BG/Q -- +// -- IBM Power -- +#ifdef BLIS_KERNELS_POWER10 +#include "bli_kernels_power10.h" +#endif +#ifdef BLIS_KERNELS_POWER9 +#include "bli_kernels_power9.h" +#endif #ifdef BLIS_KERNELS_POWER7 #include "bli_kernels_power7.h" #endif + +// -- IBM BG/Q -- + #ifdef BLIS_KERNELS_BGQ #include "bli_kernels_bgq.h" #endif diff --git a/frame/include/bli_arch_config_pre.h b/frame/include/bli_arch_config_pre.h index 1ab0561d83..86c5992306 100644 --- a/frame/include/bli_arch_config_pre.h +++ b/frame/include/bli_arch_config_pre.h @@ -69,7 +69,6 @@ void PASTEMAC2(cntx_init_,archname,BLIS_REF_SUFFIX) \ void PASTEMAC2(cntx_init_,archname,BLIS_IND_SUFFIX) \ ( \ ind_t method, \ - num_t dt, \ cntx_t* cntx \ ); diff --git a/frame/include/bli_config_macro_defs.h b/frame/include/bli_config_macro_defs.h index 46f78c27fd..0c75fb639a 100644 --- a/frame/include/bli_config_macro_defs.h +++ b/frame/include/bli_config_macro_defs.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -45,11 +46,11 @@ // internally within BLIS as well as those exposed in the native BLAS-like BLIS // interface. #ifndef BLIS_INT_TYPE_SIZE -#ifdef BLIS_ARCH_64 -#define BLIS_INT_TYPE_SIZE 64 -#else -#define BLIS_INT_TYPE_SIZE 32 -#endif + #ifdef BLIS_ARCH_64 + #define BLIS_INT_TYPE_SIZE 64 + #else + #define BLIS_INT_TYPE_SIZE 32 + #endif #endif @@ -98,6 +99,26 @@ #define BLIS_ENABLE_MULTITHREADING #endif +// Enable the use of prime numbers of threads when requesting automatic thread +// factorization. When disabled, requesting a prime number of threads will +// result in a reduction (by one) of the number of threads, provided that the +// prime number exceeds a minimum threshold (see below). +#ifdef BLIS_ENABLE_AUTO_PRIME_NUM_THREADS + #undef BLIS_DISABLE_AUTO_PRIME_NUM_THREADS +#else + // Default behavior is disabled. + #undef BLIS_DISABLE_AUTO_PRIME_NUM_THREADS // In case user explicitly disabled. + #define BLIS_DISABLE_AUTO_PRIME_NUM_THREADS +#endif + +// Set the maximum requested number of threads that BLIS will accept from the +// user that may be prime. If a larger prime number of threads is requested, +// it will be reduced by one to allow for more efficient thread factorizations. +// This value will only be used if BLIS_ENABLE_AUTO_PRIME_NUM_THREADS is defined. +#ifndef BLIS_NT_MAX_PRIME + #define BLIS_NT_MAX_PRIME 11 +#endif + // -- MIXED DATATYPE SUPPORT --------------------------------------------------- @@ -128,16 +149,6 @@ #define BLIS_RELAX_MCNR_NCMR_CONSTRAINTS #endif -// Stay initialized after auto-initialization, unless and until the user -// explicitly calls bli_finalize(). -#ifdef BLIS_DISABLE_STAY_AUTO_INITIALIZED - #undef BLIS_ENABLE_STAY_AUTO_INITIALIZED -#else - // Default behavior is enabled. - #undef BLIS_ENABLE_STAY_AUTO_INITIALIZED // In case user explicitly enabled. - #define BLIS_ENABLE_STAY_AUTO_INITIALIZED -#endif - // -- BLAS COMPATIBILITY LAYER ------------------------------------------------- @@ -157,7 +168,19 @@ // C99 type "long int". Note that this ONLY affects integers used within the // BLAS compatibility layer. #ifndef BLIS_BLAS_INT_TYPE_SIZE -#define BLIS_BLAS_INT_TYPE_SIZE 32 + #define BLIS_BLAS_INT_TYPE_SIZE 32 +#endif + +// By default, the level-3 BLAS routines are implemented by directly calling +// the BLIS object API. Alternatively, they may first call the typed BLIS +// API, which will then call the object API. +//#define BLIS_BLAS3_CALLS_TAPI +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef BLIS_BLAS3_CALLS_OAPI +#else + // Default behavior is to call object API directly. + #undef BLIS_BLAS3_CALLS_OAPI // In case user explicitly enabled. + #define BLIS_BLAS3_CALLS_OAPI #endif @@ -176,5 +199,73 @@ #endif +// -- SHARED LIBRARY SYMBOL EXPORT --------------------------------------------- + +// When building shared libraries, we can control which symbols are exported for +// linking by external applications. BLIS annotates all function prototypes that +// are meant to be "public" with BLIS_EXPORT_BLIS (with BLIS_EXPORT_BLAS playing +// a similar role for BLAS compatibility routines). Which symbols are exported +// is controlled by the default symbol visibility, as specifed by the gcc option +// -fvisibility=[default|hidden]. The default for this option is 'default', or, +// "public", which, if allowed to stand, causes all symbols in BLIS to be +// linkable from the outside. But when compiling with -fvisibility=hidden, all +// symbols start out hidden (that is, restricted only for internal use by BLIS), +// with that setting overridden only for function prototypes or variable +// declarations that are annotated with BLIS_EXPORT_BLIS. + +#ifndef BLIS_EXPORT + #if !defined(BLIS_ENABLE_SHARED) + #define BLIS_EXPORT + #else + #if defined(_WIN32) || defined(__CYGWIN__) + #ifdef BLIS_IS_BUILDING_LIBRARY + #define BLIS_EXPORT __declspec(dllexport) + #else + #define BLIS_EXPORT __declspec(dllimport) + #endif + #elif defined(__GNUC__) && __GNUC__ >= 4 + #define BLIS_EXPORT __attribute__ ((visibility ("default"))) + #else + #define BLIS_EXPORT + #endif + #endif +#endif + +#define BLIS_EXPORT_BLIS BLIS_EXPORT +#define BLIS_EXPORT_BLAS BLIS_EXPORT +#define BLIS_EXPORT_ADDON BLIS_EXPORT + + +// -- OVERRIDABLE (WEAK) SYMBOLS ----------------------------------------------- + +// On Linux, functions called from a shared library can be overriden by the main +// program simply by providing a new definition. However, macOS uses a "two-level +// namespace" which causes calls to shared library functions to be tied to the +// library and not overridable. As a workaround, certain symbols can be defined +// as "weak" and are given lower preference during linking. +#ifndef BLIS_OVERRIDABLE +#if BLIS_OS_OSX +#define BLIS_OVERRIDABLE __attribute__((weak)) +#else +#define BLIS_OVERRIDABLE +#endif +#endif + + +// -- STATIC INLINE FUNCTIONS -------------------------------------------------- + +// C and C++ have different semantics for defining "inline" functions. In C, +// the keyword phrase "static inline" accomplishes this, though the "inline" +// is optional. In C++, the "inline" keyword is required and obviates "static" +// altogether. Why does this matter? While BLIS is compiled in C99, blis.h may +// be #included by a source file that is compiled with C++. +#ifdef __cplusplus + #define BLIS_INLINE inline +#else + //#define BLIS_INLINE static inline + #define BLIS_INLINE static +#endif + + #endif diff --git a/frame/include/bli_edge_case_macro_defs.h b/frame/include/bli_edge_case_macro_defs.h new file mode 100644 index 0000000000..70d97d5d10 --- /dev/null +++ b/frame/include/bli_edge_case_macro_defs.h @@ -0,0 +1,215 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_EDGE_CASE_MACRO_DEFS_H +#define BLIS_EDGE_CASE_MACRO_DEFS_H + +// +// Macros for edge-case handling within gemm microkernels. +// + +// -- Setup helper macros -- + +#define GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major,alignment) \ +\ + PASTEMAC(ch,ctype)* restrict _beta = beta; \ + PASTEMAC(ch,ctype)* restrict _c = c; \ + const inc_t _rs_c = rs_c; \ + const inc_t _cs_c = cs_c; \ + PASTEMAC(ch,ctype) _ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( PASTEMAC(ch,type) ) ] \ + __attribute__((aligned(alignment))); \ + const inc_t _rs_ct = row_major ? nr : 1; \ + const inc_t _cs_ct = row_major ? 1 : mr; + +#define GEMM_UKR_SETUP_CT_POST(ch) \ +\ + PASTEMAC(ch,ctype) _zero; \ + PASTEMAC(ch,set0s)( _zero ); \ + \ + if ( _use_ct ) \ + { \ + c = _ct; \ + rs_c = _rs_ct; \ + cs_c = _cs_ct; \ + beta = &_zero; \ + } + +// -- Setup macros -- + +#define GEMM_UKR_SETUP_CT(ch,mr,nr,row_major) \ +\ + /* Scenario 1: the ukernel contains assembly-level support only for its + IO preference (e.g. only row-oriented or only column-oriented IO). + Use a temporary microtile for the other two cases as well as edge + cases. */ \ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major,1); \ + const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ + m != mr || n != nr; \ + GEMM_UKR_SETUP_CT_POST(ch); + +#define GEMM_UKR_SETUP_CT_AMBI(ch,mr,nr,row_major) \ +\ + /* Scenario 2: the ukernel contains assembly-level support for its IO + preference as well as its opposite via in-register transpose + (e.g. both row- and column-oriented IO). Use a temporary microtile + for the general stride case as well as edge cases. */ \ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major,1); \ + const bool _use_ct = ( cs_c != 1 && rs_c != 1 ) || \ + m != mr || n != nr; \ + GEMM_UKR_SETUP_CT_POST(ch); + +#define GEMM_UKR_SETUP_CT_ANY(ch,mr,nr,row_major) \ +\ + /* Scenario 3: Similar to (2) where the assembly region also supports + general stride I0. Use a temporary microtile only for edge cases. */ \ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major,1); \ + const bool _use_ct = ( m != mr || n != nr ); \ + GEMM_UKR_SETUP_CT_POST(ch); + +#define GEMM_UKR_SETUP_CT_ALIGNED(ch,mr,nr,row_major,alignment) \ +\ + /* Scenario 4: Similar to (1), but uses temporary microtile to handle + cases where the pointer to the C microtile is not aligned. */ \ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major,alignment); \ + const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ + m != mr || n != nr || \ + ( (uintptr_t)_c % alignment ) || \ + ( ( ( row_major ? _rs_c : _cs_c )*sizeof( PASTEMAC(ch,ctype) ) ) % alignment ); \ + GEMM_UKR_SETUP_CT_POST(ch); + +// -- Flush macros -- + +#define GEMM_UKR_FLUSH_CT(ch) \ +\ + /* If we actually used the temporary microtile, accumulate it to the output + microtile. */ \ + if ( _use_ct ) \ + { \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + m, n, \ + _ct, _rs_ct, _cs_ct, \ + _beta, \ + _c, _rs_c, _cs_c \ + ); \ + } \ + + +// +// Macros for edge-case handling within gemmtrsm microkernels. +// + +// -- Setup helper macros -- + +#define GEMMTRSM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major,alignment) \ +\ + PASTEMAC(ch,ctype)* restrict _c = c11; \ + const inc_t _rs_c = rs_c; \ + const inc_t _cs_c = cs_c; \ + PASTEMAC(ch,ctype) _ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( PASTEMAC(ch,type) ) ] \ + __attribute__((aligned(alignment))); \ + const inc_t _rs_ct = row_major ? nr : 1; \ + const inc_t _cs_ct = row_major ? 1 : mr; + +#define GEMMTRSM_UKR_SETUP_CT_POST(ch) \ +\ + if ( _use_ct ) \ + { \ + c11 = _ct; \ + rs_c = _rs_ct; \ + cs_c = _cs_ct; \ + } + +// -- Setup macros -- + +#define GEMMTRSM_UKR_SETUP_CT(ch,mr,nr,row_major) \ +\ + /* Scenario 1: the ukernel contains assembly-level support only for its + IO preference (e.g. only row-oriented or only column-oriented IO). + Use a temporary microtile for the other two cases as well as edge + cases. */ \ + GEMMTRSM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major,1); \ + const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ + m != mr || n != nr; \ + GEMMTRSM_UKR_SETUP_CT_POST(ch); + +#define GEMMTRSM_UKR_SETUP_CT_AMBI(ch,mr,nr,row_major) \ +\ + /* Scenario 2: the ukernel contains assembly-level support for its IO + preference as well as its opposite via in-register transpose + (e.g. both row- and column-oriented IO). Use a temporary microtile + for the general stride case as well as edge cases. */ \ + GEMMTRSM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major,1); \ + const bool _use_ct = ( cs_c != 1 && rs_c != 1 ) || \ + m != mr || n != nr; \ + GEMMTRSM_UKR_SETUP_CT_POST(ch); + +#define GEMMTRSM_UKR_SETUP_CT_ANY(ch,mr,nr,row_major) \ +\ + /* Scenario 3: Similar to (2) where the assembly region also supports + general stride I0. Use a temporary microtile only for edge cases. */ \ + GEMMTRSM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major,1); \ + const bool _use_ct = ( m != mr || n != nr ); \ + GEMMTRSM_UKR_SETUP_CT_POST(ch); + +#define GEMMTRSM_UKR_SETUP_CT_ALIGNED(ch,mr,nr,row_major,alignment) \ +\ + /* Scenario 4: Similar to (1), but uses temporary microtile to handle + cases where the pointer to the C microtile is not aligned. */ \ + GEMMTRSM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major,alignment); \ + const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ + m != mr || n != nr || \ + ( (uintptr_t)_c % alignment ) || \ + ( ( ( row_major ? _rs_c : _cs_c )*sizeof( PASTEMAC(ch,ctype) ) ) % alignment ); \ + GEMMTRSM_UKR_SETUP_CT_POST(ch); + +// -- Flush macros -- + +#define GEMMTRSM_UKR_FLUSH_CT(ch) \ +\ + /* If we actually used the temporary microtile, use it to overwrite the + output microtile. Used by trsm. */ \ + if ( _use_ct ) \ + { \ + PASTEMAC(ch,copys_mxn) \ + ( \ + m, n, \ + _ct, _rs_ct, _cs_ct, \ + _c, _rs_c, _cs_c \ + ); \ + } \ + + +#endif + diff --git a/frame/include/bli_error_macro_defs.h b/frame/include/bli_error_macro_defs.h index a0c9ea6ab3..00d8acdcb8 100644 --- a/frame/include/bli_error_macro_defs.h +++ b/frame/include/bli_error_macro_defs.h @@ -35,12 +35,6 @@ #ifndef BLIS_ERROR_MACRO_DEFS_H #define BLIS_ERROR_MACRO_DEFS_H -// -- Error-related macros -- - -// Used to determine the size of the array of error strings. -#define BLIS_MAX_NUM_ERR_MSGS 200 -#define BLIS_MAX_ERR_MSG_LENGTH 200 - // Used to insert filenames and line numbers into error-checking code. #define bli_check_error_code( code ) \ bli_check_error_code_helper( code, __FILE__, __LINE__ ) diff --git a/frame/include/bli_genarray_macro_defs.h b/frame/include/bli_genarray_macro_defs.h index 23cee1064b..eb932c5582 100644 --- a/frame/include/bli_genarray_macro_defs.h +++ b/frame/include/bli_genarray_macro_defs.h @@ -128,6 +128,20 @@ arrayname[BLIS_NUM_FP_TYPES][BLIS_NUM_FP_TYPES] = \ +// -- One-operand macro (with custom prefix) -- + +#define GENARRAY_PREF(arrayname,prefix,op) \ +\ +arrayname[BLIS_NUM_FP_TYPES] = \ +{ \ + PASTECH2(prefix,s,op), \ + PASTECH2(prefix,c,op), \ + PASTECH2(prefix,d,op), \ + PASTECH2(prefix,z,op) \ +} + + + // -- Two-operand macros -- diff --git a/frame/include/bli_gentfunc_macro_defs.h b/frame/include/bli_gentfunc_macro_defs.h index 47276b0764..011ebcdfbb 100644 --- a/frame/include/bli_gentfunc_macro_defs.h +++ b/frame/include/bli_gentfunc_macro_defs.h @@ -74,19 +74,35 @@ GENTFUNCCO( scomplex, float, c, s, blasname, blisname ) \ GENTFUNCCO( dcomplex, double, z, d, blasname, blisname ) -// -- Basic one-operand macro with conjugation (used only for dot, ger) -- +// -- Basic one-operand macro with conjugation (real funcs only, used only for dot, ger) -- -#define INSERT_GENTFUNCDOT_BLAS( blasname, blisname ) \ +#define INSERT_GENTFUNCDOTR_BLAS( blasname, blisname ) \ \ GENTFUNCDOT( float, s, , BLIS_NO_CONJUGATE, blasname, blisname ) \ -GENTFUNCDOT( double, d, , BLIS_NO_CONJUGATE, blasname, blisname ) \ +GENTFUNCDOT( double, d, , BLIS_NO_CONJUGATE, blasname, blisname ) + + +// -- Basic one-operand macro with conjugation (complex funcs only, used only for dot, ger) -- + + +#define INSERT_GENTFUNCDOTC_BLAS( blasname, blisname ) \ +\ GENTFUNCDOT( scomplex, c, c, BLIS_CONJUGATE, blasname, blisname ) \ GENTFUNCDOT( scomplex, c, u, BLIS_NO_CONJUGATE, blasname, blisname ) \ GENTFUNCDOT( dcomplex, z, c, BLIS_CONJUGATE, blasname, blisname ) \ GENTFUNCDOT( dcomplex, z, u, BLIS_NO_CONJUGATE, blasname, blisname ) +// -- Basic one-operand macro with conjugation (used only for dot, ger) -- + + +#define INSERT_GENTFUNCDOT_BLAS( blasname, blisname ) \ +\ +INSERT_GENTFUNCDOTR_BLAS( blasname, blisname ) \ +INSERT_GENTFUNCDOTC_BLAS( blasname, blisname ) + + // -- Basic one-operand macro with real projection -- @@ -225,6 +241,24 @@ GENTFUNCR( dcomplex, double, z, d, tfuncname, varname1, varname2, varname3, varn +// -- Basic one-operand macro with real domain only -- + +// -- (no auxiliary arguments) -- + +#define INSERT_GENTFUNCRO_BASIC0( tfuncname ) \ +\ +GENTFUNCRO( float, s, tfuncname ) \ +GENTFUNCRO( double, d, tfuncname ) \ + +// -- (one auxiliary argument) -- + +#define INSERT_GENTFUNCRO_BASIC( tfuncname, varname ) \ +\ +GENTFUNCRO( float, s, tfuncname, varname ) \ +GENTFUNCRO( double, d, tfuncname, varname ) \ + + + // -- Basic one-operand macro with complex domain only and real projection -- // -- (no auxiliary arguments) -- diff --git a/frame/include/bli_gentprot_macro_defs.h b/frame/include/bli_gentprot_macro_defs.h index f6aa70946f..3db9cdc480 100644 --- a/frame/include/bli_gentprot_macro_defs.h +++ b/frame/include/bli_gentprot_macro_defs.h @@ -74,19 +74,35 @@ GENTPROTCO( scomplex, float, c, s, blasname ) \ GENTPROTCO( dcomplex, double, z, d, blasname ) -// -- Basic one-operand macro with conjugation (used only for dot, ger) -- +// -- Basic one-operand macro with conjugation (real funcs only, used only for dot, ger) -- -#define INSERT_GENTPROTDOT_BLAS( blasname ) \ +#define INSERT_GENTPROTDOTR_BLAS( blasname ) \ \ GENTPROTDOT( float, s, , blasname ) \ -GENTPROTDOT( double, d, , blasname ) \ +GENTPROTDOT( double, d, , blasname ) + + +// -- Basic one-operand macro with conjugation (complex funcs only, used only for dot, ger) -- + + +#define INSERT_GENTPROTDOTC_BLAS( blasname ) \ +\ GENTPROTDOT( scomplex, c, c, blasname ) \ GENTPROTDOT( scomplex, c, u, blasname ) \ GENTPROTDOT( dcomplex, z, c, blasname ) \ GENTPROTDOT( dcomplex, z, u, blasname ) +// -- Basic one-operand macro with conjugation (used only for dot, ger) -- + + +#define INSERT_GENTPROTDOT_BLAS( blasname ) \ +\ +INSERT_GENTPROTDOTR_BLAS( blasname ) \ +INSERT_GENTPROTDOTC_BLAS( blasname ) + + // -- Basic one-operand macro with real projection -- diff --git a/frame/include/bli_kernel_macro_defs.h b/frame/include/bli_kernel_macro_defs.h index cea176a812..4de624f98c 100644 --- a/frame/include/bli_kernel_macro_defs.h +++ b/frame/include/bli_kernel_macro_defs.h @@ -38,14 +38,23 @@ // -- Define default threading parameters -------------------------------------- +// -- Conventional (large code path) values -- + +// These BLIS_THREAD_RATIO_? macros distort the amount of work in the m and n +// dimensions for the purposes of factorizing the total number of threads into +// ways of parallelism in the ic and jc loops. See bli_rntm.c to see how these +// macros are used. #ifndef BLIS_THREAD_RATIO_M -#define BLIS_THREAD_RATIO_M 2 +#define BLIS_THREAD_RATIO_M 1 #endif #ifndef BLIS_THREAD_RATIO_N #define BLIS_THREAD_RATIO_N 1 #endif +// These BLIS_THREAD_MAX_?R macros place a ceiling on the maximum amount of +// parallelism allowed when performing automatic factorization. See bli_rntm.c +// to see how these macros are used. #ifndef BLIS_THREAD_MAX_IR #define BLIS_THREAD_MAX_IR 1 #endif @@ -54,6 +63,26 @@ #define BLIS_THREAD_MAX_JR 4 #endif +#if 0 +// -- Skinny/small possibly-unpacked (sup code path) values -- + +#ifndef BLIS_THREAD_SUP_RATIO_M +#define BLIS_THREAD_SUP_RATIO_M 1 +#endif + +#ifndef BLIS_THREAD_SUP_RATIO_N +#define BLIS_THREAD_SUP_RATIO_N 2 +#endif + +#ifndef BLIS_THREAD_SUP_MAX_IR +#define BLIS_THREAD_SUP_MAX_IR 1 +#endif + +#ifndef BLIS_THREAD_SUP_MAX_JR +#define BLIS_THREAD_SUP_MAX_JR 8 +#endif +#endif + // -- Memory allocation -------------------------------------------------------- @@ -134,21 +163,21 @@ // When configuring with umbrella configuration families, this should be // set to the maximum number of registers across all sub-configurations in // the family. -#ifndef BLIS_SIMD_NUM_REGISTERS -#define BLIS_SIMD_NUM_REGISTERS 32 +#ifndef BLIS_SIMD_MAX_NUM_REGISTERS +#define BLIS_SIMD_MAX_NUM_REGISTERS 32 #endif // The maximum size (in bytes) of each SIMD vector. // When configuring with umbrella configuration families, this should be // set to the maximum SIMD size across all sub-configurations in the family. -#ifndef BLIS_SIMD_SIZE -#define BLIS_SIMD_SIZE 64 +#ifndef BLIS_SIMD_MAX_SIZE +#define BLIS_SIMD_MAX_SIZE 64 #endif // Alignment size (in bytes) needed by the instruction set for aligned // SIMD/vector instructions. #ifndef BLIS_SIMD_ALIGN_SIZE -#define BLIS_SIMD_ALIGN_SIZE BLIS_SIMD_SIZE +#define BLIS_SIMD_ALIGN_SIZE BLIS_SIMD_MAX_SIZE #endif // The maximum size in bytes of local stack buffers within macro-kernel @@ -159,25 +188,62 @@ // micro-tile footprint, even though the virtual micro-kernels will only // ever be writing to half (real or imaginary part) at a time. #ifndef BLIS_STACK_BUF_MAX_SIZE -#define BLIS_STACK_BUF_MAX_SIZE ( BLIS_SIMD_NUM_REGISTERS * \ - BLIS_SIMD_SIZE * 2 ) +#define BLIS_STACK_BUF_MAX_SIZE ( BLIS_SIMD_MAX_NUM_REGISTERS * \ + BLIS_SIMD_MAX_SIZE * 2 ) #endif // Alignment size used to align local stack buffers within macro-kernel // functions. +#ifndef BLIS_STACK_BUF_ALIGN_SIZE #define BLIS_STACK_BUF_ALIGN_SIZE BLIS_SIMD_ALIGN_SIZE +#endif // Alignment size used when allocating memory via BLIS_MALLOC_USER. // To disable heap alignment, set this to 1. +#ifndef BLIS_HEAP_ADDR_ALIGN_SIZE #define BLIS_HEAP_ADDR_ALIGN_SIZE BLIS_SIMD_ALIGN_SIZE +#endif // Alignment size used when sizing leading dimensions of memory allocated // via BLIS_MALLOC_USER. +#ifndef BLIS_HEAP_STRIDE_ALIGN_SIZE #define BLIS_HEAP_STRIDE_ALIGN_SIZE BLIS_SIMD_ALIGN_SIZE +#endif -// Alignment size used when allocating blocks to the internal memory +// Alignment sizes used when allocating blocks to the internal memory // pool, via BLIS_MALLOC_POOL. -#define BLIS_POOL_ADDR_ALIGN_SIZE BLIS_PAGE_SIZE +#ifndef BLIS_POOL_ADDR_ALIGN_SIZE_A +#define BLIS_POOL_ADDR_ALIGN_SIZE_A BLIS_PAGE_SIZE +#endif + +#ifndef BLIS_POOL_ADDR_ALIGN_SIZE_B +#define BLIS_POOL_ADDR_ALIGN_SIZE_B BLIS_PAGE_SIZE +#endif + +#ifndef BLIS_POOL_ADDR_ALIGN_SIZE_C +#define BLIS_POOL_ADDR_ALIGN_SIZE_C BLIS_PAGE_SIZE +#endif + +#ifndef BLIS_POOL_ADDR_ALIGN_SIZE_GEN +#define BLIS_POOL_ADDR_ALIGN_SIZE_GEN BLIS_PAGE_SIZE +#endif + +// Offsets from alignment specified by BLIS_POOL_ADDR_ALIGN_SIZE_*. +#ifndef BLIS_POOL_ADDR_OFFSET_SIZE_A +#define BLIS_POOL_ADDR_OFFSET_SIZE_A 0 +#endif + +#ifndef BLIS_POOL_ADDR_OFFSET_SIZE_B +#define BLIS_POOL_ADDR_OFFSET_SIZE_B 0 +#endif + +#ifndef BLIS_POOL_ADDR_OFFSET_SIZE_C +#define BLIS_POOL_ADDR_OFFSET_SIZE_C 0 +#endif + +#ifndef BLIS_POOL_ADDR_OFFSET_SIZE_GEN +#define BLIS_POOL_ADDR_OFFSET_SIZE_GEN 0 +#endif diff --git a/frame/include/bli_lang_defs.h b/frame/include/bli_lang_defs.h new file mode 100644 index 0000000000..8cf3f99862 --- /dev/null +++ b/frame/include/bli_lang_defs.h @@ -0,0 +1,111 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_LANG_DEFS_H +#define BLIS_LANG_DEFS_H + + +// -- Undefine restrict for C++ and C89/90 -- + +#ifdef __cplusplus + // Language is C++; define restrict as nothing. + #ifndef restrict + #define restrict + #endif +#elif __STDC_VERSION__ >= 199901L + // Language is C99 (or later); do nothing since restrict is recognized. +#else + // Language is pre-C99; define restrict as nothing. + #ifndef restrict + #define restrict + #endif +#endif + + +// -- Define typeof() operator if using non-GNU compiler -- + +#ifndef __GNUC__ + #define typeof __typeof__ +#else + #ifndef typeof + #define typeof __typeof__ + #endif +#endif + + +// -- BLIS Thread Local Storage Keyword -- + +// __thread for TLS is supported by GCC, CLANG, ICC, and IBMC. +// There is a small risk here as __GNUC__ can also be defined by some other +// compiler (other than ICC and CLANG which we know define it) that +// doesn't support __thread, as __GNUC__ is not quite unique to GCC. +// But the possibility of someone using such non-main-stream compiler +// for building BLIS is low. +#if defined(__GNUC__) || defined(__clang__) || defined(__ICC) || defined(__IBMC__) + #define BLIS_THREAD_LOCAL __thread +#else + #define BLIS_THREAD_LOCAL +#endif + + +// -- BLIS constructor/destructor function attribute -- + +// __attribute__((constructor/destructor)) is supported by GCC only. +// There is a small risk here as __GNUC__ can also be defined by some other +// compiler (other than ICC and CLANG which we know define it) that +// doesn't support this, as __GNUC__ is not quite unique to GCC. +// But the possibility of someone using such non-main-stream compiler +// for building BLIS is low. + +#if defined(__ICC) || defined(__INTEL_COMPILER) + // ICC defines __GNUC__ but doesn't support this + #define BLIS_ATTRIB_CTOR + #define BLIS_ATTRIB_DTOR +#elif defined(__clang__) + // CLANG supports __attribute__, but its documentation doesn't + // mention support for constructor/destructor. Compiling with + // clang and testing shows that it does support. + #define BLIS_ATTRIB_CTOR __attribute__((constructor)) + #define BLIS_ATTRIB_DTOR __attribute__((destructor)) +#elif defined(__GNUC__) + #define BLIS_ATTRIB_CTOR __attribute__((constructor)) + #define BLIS_ATTRIB_DTOR __attribute__((destructor)) +#else + #define BLIS_ATTRIB_CTOR + #define BLIS_ATTRIB_DTOR +#endif + + +#endif diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h index c25d84c995..be45a12e3f 100644 --- a/frame/include/bli_macro_defs.h +++ b/frame/include/bli_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,77 +37,6 @@ #define BLIS_MACRO_DEFS_H -// -- Undefine restrict for C++ and C89/90 -- - -#ifdef __cplusplus - // Language is C++; define restrict as nothing. - #ifndef restrict - #define restrict - #endif -#elif __STDC_VERSION__ >= 199901L - // Language is C99 (or later); do nothing since restrict is recognized. -#else - // Language is pre-C99; define restrict as nothing. - #ifndef restrict - #define restrict - #endif -#endif - - -// -- Define typeof() operator if using non-GNU compiler -- - -#ifndef __GNUC__ - #define typeof __typeof__ -#else - #ifndef typeof - #define typeof __typeof__ - #endif -#endif - - -// -- BLIS Thread Local Storage Keyword -- - -// __thread for TLS is supported by GCC, CLANG, ICC, and IBMC. -// There is a small risk here as __GNUC__ can also be defined by some other -// compiler (other than ICC and CLANG which we know define it) that -// doesn't support __thread, as __GNUC__ is not quite unique to GCC. -// But the possibility of someone using such non-main-stream compiler -// for building BLIS is low. -#if defined(__GNUC__) || defined(__clang__) || defined(__ICC) || defined(__IBMC__) - #define BLIS_THREAD_LOCAL __thread -#else - #define BLIS_THREAD_LOCAL -#endif - - -// -- BLIS constructor/destructor function attribute -- - -// __attribute__((constructor/destructor)) is supported by GCC only. -// There is a small risk here as __GNUC__ can also be defined by some other -// compiler (other than ICC and CLANG which we know define it) that -// doesn't support this, as __GNUC__ is not quite unique to GCC. -// But the possibility of someone using such non-main-stream compiler -// for building BLIS is low. - -#if defined(__ICC) || defined(__INTEL_COMPILER) - // ICC defines __GNUC__ but doesn't support this - #define BLIS_ATTRIB_CTOR - #define BLIS_ATTRIB_DTOR -#elif defined(__clang__) - // CLANG supports __attribute__, but its documentation doesn't - // mention support for constructor/destructor. Compiling with - // clang and testing shows that it does support. - #define BLIS_ATTRIB_CTOR __attribute__((constructor)) - #define BLIS_ATTRIB_DTOR __attribute__((destructor)) -#elif defined(__GNUC__) - #define BLIS_ATTRIB_CTOR __attribute__((constructor)) - #define BLIS_ATTRIB_DTOR __attribute__((destructor)) -#else - #define BLIS_ATTRIB_CTOR - #define BLIS_ATTRIB_DTOR -#endif - - // -- Concatenation macros -- #define BLIS_FUNC_PREFIX_STR "bli" @@ -140,6 +69,9 @@ #define PASTEBLACHK_(op) bla_ ## op ## _check #define PASTEBLACHK(op) PASTEBLACHK_(op) +#define PASTECH0_(op) op +#define PASTECH0(op) PASTECH0_(op) + #define PASTECH_(ch,op) ch ## op #define PASTECH(ch,op) PASTECH_(ch,op) @@ -166,6 +98,7 @@ #include "bli_gentprot_macro_defs.h" #include "bli_misc_macro_defs.h" +#include "bli_edge_case_macro_defs.h" #include "bli_param_macro_defs.h" #include "bli_obj_macro_defs.h" #include "bli_complex_macro_defs.h" diff --git a/frame/include/bli_misc_macro_defs.h b/frame/include/bli_misc_macro_defs.h index 6805960f25..120338beba 100644 --- a/frame/include/bli_misc_macro_defs.h +++ b/frame/include/bli_misc_macro_defs.h @@ -67,14 +67,14 @@ // round -static double bli_round( double a ) +BLIS_INLINE double bli_round( double a ) { return round( a ); } // round_to_mult -static guint_t bli_round_to_mult( guint_t val, guint_t mult ) +BLIS_INLINE guint_t bli_round_to_mult( guint_t val, guint_t mult ) { return ( guint_t ) ( ( ( ( guint_t )val + @@ -94,19 +94,19 @@ static guint_t bli_round_to_mult( guint_t val, guint_t mult ) // is_odd, is_even -static bool_t bli_is_odd( gint_t a ) +BLIS_INLINE bool bli_is_odd( gint_t a ) { - return ( a % 2 == 1 ); + return ( bool )( a % 2 == 1 ); } -static bool_t bli_is_even( gint_t a ) +BLIS_INLINE bool bli_is_even( gint_t a ) { - return ( a % 2 == 0 ); + return ( bool )( a % 2 == 0 ); } // swap_dims -static void bli_swap_dims( dim_t* dim1, dim_t* dim2 ) +BLIS_INLINE void bli_swap_dims( dim_t* dim1, dim_t* dim2 ) { dim_t temp = *dim1; *dim1 = *dim2; @@ -115,7 +115,7 @@ static void bli_swap_dims( dim_t* dim1, dim_t* dim2 ) // swap_incs -static void bli_swap_incs( inc_t* inc1, inc_t* inc2 ) +BLIS_INLINE void bli_swap_incs( inc_t* inc1, inc_t* inc2 ) { inc_t temp = *inc1; *inc1 = *inc2; @@ -124,7 +124,7 @@ static void bli_swap_incs( inc_t* inc1, inc_t* inc2 ) // toggle_bool -static void bli_toggle_bool( bool_t* b ) +BLIS_INLINE void bli_toggle_bool( bool* b ) { if ( *b == TRUE ) *b = FALSE; else *b = TRUE; diff --git a/frame/include/bli_oapi_ba.h b/frame/include/bli_oapi_ba.h index 3f0bfa35a8..dc17507d11 100644 --- a/frame/include/bli_oapi_ba.h +++ b/frame/include/bli_oapi_ba.h @@ -35,7 +35,12 @@ // This file defines macros used to allow the _oapi.c files to produce // object APIs that omit expert parameters. -// Define the macro to remove the function name suffix (in function +// Define a macro that allows the source code to determine which interface +// (basic or expert) we are compiling. +#undef BLIS_OAPI_BASIC +#define BLIS_OAPI_BASIC + +// Define the macro to omit a suffix from the function names (in function // definitions). #undef EX_SUF #define EX_SUF @@ -45,14 +50,10 @@ #undef BLIS_OAPI_EX_PARAMS #define BLIS_OAPI_EX_PARAMS -// Define the macro to declare local expert variables that are initialized +// Define the macro to add local expert variables that are initialized // to NULL. The "( void )" statements are to prevent unused variable // warnings by the compiler. #undef BLIS_OAPI_EX_DECLS #define BLIS_OAPI_EX_DECLS cntx_t* cntx = NULL; ( void )cntx; \ rntm_t* rntm = NULL; ( void )rntm; -// Define the macro to pass the local expert variables to another function. -//#undef BLIS_TAPI_EX_VARS -//#define BLIS_TAPI_EX_VARS - diff --git a/frame/include/bli_oapi_ex.h b/frame/include/bli_oapi_ex.h index 7acaf36230..0eb5eb2a1e 100644 --- a/frame/include/bli_oapi_ex.h +++ b/frame/include/bli_oapi_ex.h @@ -35,8 +35,13 @@ // This file defines macros used to allow the _oapi.c files to produce // object APIs that contain context parameters. -// Define the macro to add a suffix to the object API function names -// (in function definitions). +// Define a macro that allows the source code to determine which interface +// (basic or expert) we are compiling. +#undef BLIS_OAPI_EXPERT +#define BLIS_OAPI_EXPERT + +// Define the macro to add a suffix to the function names (in function +// definitions). #undef EX_SUF #define EX_SUF BLIS_OAPI_EX_SUF @@ -50,7 +55,3 @@ #undef BLIS_OAPI_EX_DECLS #define BLIS_OAPI_EX_DECLS -// Define the macro to pass the local expert variables to another function. -//#undef BLIS_TAPI_EX_VARS -//#define BLIS_TAPI_EX_VARS ,cntx, rntm - diff --git a/frame/include/bli_obj_macro_defs.h b/frame/include/bli_obj_macro_defs.h index e3eb2b8749..fe174202cf 100644 --- a/frame/include/bli_obj_macro_defs.h +++ b/frame/include/bli_obj_macro_defs.h @@ -6,6 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,556 +42,556 @@ // Info query -static num_t bli_obj_dt( obj_t* obj ) +BLIS_INLINE num_t bli_obj_dt( obj_t* obj ) { return ( num_t ) ( obj->info & BLIS_DATATYPE_BITS ); } -static bool_t bli_obj_is_float( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_float( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_dt( obj ) == BLIS_BITVAL_FLOAT_TYPE ); } -static bool_t bli_obj_is_double( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_double( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_dt( obj ) == BLIS_BITVAL_DOUBLE_TYPE ); } -static bool_t bli_obj_is_scomplex( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_scomplex( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_dt( obj ) == BLIS_BITVAL_SCOMPLEX_TYPE ); } -static bool_t bli_obj_is_dcomplex( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_dcomplex( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_dt( obj ) == BLIS_BITVAL_DCOMPLEX_TYPE ); } -static bool_t bli_obj_is_int( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_int( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_dt( obj ) == BLIS_BITVAL_INT_TYPE ); } -static bool_t bli_obj_is_const( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_const( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_dt( obj ) == BLIS_BITVAL_CONST_TYPE ); } -static dom_t bli_obj_domain( obj_t* obj ) +BLIS_INLINE dom_t bli_obj_domain( obj_t* obj ) { return ( dom_t ) ( obj->info & BLIS_DOMAIN_BIT ); } -static prec_t bli_obj_prec( obj_t* obj ) +BLIS_INLINE prec_t bli_obj_prec( obj_t* obj ) { return ( prec_t ) ( obj->info & BLIS_PRECISION_BIT ); } -static bool_t bli_obj_is_single_prec( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_single_prec( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_prec( obj ) == BLIS_BITVAL_SINGLE_PREC ); } -static bool_t bli_obj_is_double_prec( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_double_prec( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_prec( obj ) == BLIS_BITVAL_DOUBLE_PREC ); } -static num_t bli_obj_dt_proj_to_single_prec( obj_t* obj ) +BLIS_INLINE num_t bli_obj_dt_proj_to_single_prec( obj_t* obj ) { return ( num_t ) ( bli_obj_dt( obj ) & ~BLIS_BITVAL_SINGLE_PREC ); } -static num_t bli_obj_dt_proj_to_double_prec( obj_t* obj ) +BLIS_INLINE num_t bli_obj_dt_proj_to_double_prec( obj_t* obj ) { return ( num_t ) ( bli_obj_dt( obj ) | BLIS_BITVAL_DOUBLE_PREC ); } -static bool_t bli_obj_is_real( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_real( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_domain( obj ) == BLIS_BITVAL_REAL && !bli_obj_is_const( obj ) ); } -static bool_t bli_obj_is_complex( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_complex( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_domain( obj ) == BLIS_BITVAL_COMPLEX && !bli_obj_is_const( obj ) ); } -static num_t bli_obj_dt_proj_to_real( obj_t* obj ) +BLIS_INLINE num_t bli_obj_dt_proj_to_real( obj_t* obj ) { return ( num_t ) ( bli_obj_dt( obj ) & ~BLIS_BITVAL_COMPLEX ); } -static num_t bli_obj_dt_proj_to_complex( obj_t* obj ) +BLIS_INLINE num_t bli_obj_dt_proj_to_complex( obj_t* obj ) { return ( num_t ) ( bli_obj_dt( obj ) | BLIS_BITVAL_COMPLEX ); } -static num_t bli_obj_target_dt( obj_t* obj ) +BLIS_INLINE num_t bli_obj_target_dt( obj_t* obj ) { return ( num_t ) ( ( obj->info & BLIS_TARGET_DT_BITS ) >> BLIS_TARGET_DT_SHIFT ); } -static dom_t bli_obj_target_domain( obj_t* obj ) +BLIS_INLINE dom_t bli_obj_target_domain( obj_t* obj ) { return ( dom_t ) ( ( obj->info & BLIS_TARGET_DOMAIN_BIT ) >> BLIS_TARGET_DT_SHIFT ); } -static prec_t bli_obj_target_prec( obj_t* obj ) +BLIS_INLINE prec_t bli_obj_target_prec( obj_t* obj ) { return ( prec_t ) ( ( obj->info & BLIS_TARGET_PREC_BIT ) >> BLIS_TARGET_DT_SHIFT ); } -static num_t bli_obj_exec_dt( obj_t* obj ) +BLIS_INLINE num_t bli_obj_exec_dt( obj_t* obj ) { return ( num_t ) ( ( obj->info & BLIS_EXEC_DT_BITS ) >> BLIS_EXEC_DT_SHIFT ); } -static dom_t bli_obj_exec_domain( obj_t* obj ) +BLIS_INLINE dom_t bli_obj_exec_domain( obj_t* obj ) { return ( dom_t ) ( ( obj->info & BLIS_EXEC_DOMAIN_BIT ) >> BLIS_EXEC_DT_SHIFT ); } -static prec_t bli_obj_exec_prec( obj_t* obj ) +BLIS_INLINE prec_t bli_obj_exec_prec( obj_t* obj ) { return ( prec_t ) ( ( obj->info & BLIS_EXEC_PREC_BIT ) >> BLIS_EXEC_DT_SHIFT ); } -static num_t bli_obj_comp_dt( obj_t* obj ) +BLIS_INLINE num_t bli_obj_comp_dt( obj_t* obj ) { return ( num_t ) ( ( obj->info & BLIS_COMP_DT_BITS ) >> BLIS_COMP_DT_SHIFT ); } -static dom_t bli_obj_comp_domain( obj_t* obj ) +BLIS_INLINE dom_t bli_obj_comp_domain( obj_t* obj ) { return ( dom_t ) ( ( obj->info & BLIS_COMP_DOMAIN_BIT ) >> BLIS_COMP_DT_SHIFT ); } -static prec_t bli_obj_comp_prec( obj_t* obj ) +BLIS_INLINE prec_t bli_obj_comp_prec( obj_t* obj ) { return ( prec_t ) ( ( obj->info & BLIS_COMP_PREC_BIT ) >> BLIS_COMP_DT_SHIFT ); } // NOTE: This function queries info2. -static num_t bli_obj_scalar_dt( obj_t* obj ) +BLIS_INLINE num_t bli_obj_scalar_dt( obj_t* obj ) { return ( num_t ) ( ( obj->info2 & BLIS_SCALAR_DT_BITS ) >> BLIS_SCALAR_DT_SHIFT ); } // NOTE: This function queries info2. -static dom_t bli_obj_scalar_domain( obj_t* obj ) +BLIS_INLINE dom_t bli_obj_scalar_domain( obj_t* obj ) { return ( dom_t ) ( ( obj->info2 & BLIS_SCALAR_DOMAIN_BIT ) >> BLIS_SCALAR_DT_SHIFT ); } // NOTE: This function queries info2. -static prec_t bli_obj_scalar_prec( obj_t* obj ) +BLIS_INLINE prec_t bli_obj_scalar_prec( obj_t* obj ) { return ( prec_t ) ( ( obj->info2 & BLIS_SCALAR_PREC_BIT ) >> BLIS_SCALAR_DT_SHIFT ); } -static trans_t bli_obj_conjtrans_status( obj_t* obj ) +BLIS_INLINE trans_t bli_obj_conjtrans_status( obj_t* obj ) { return ( trans_t ) ( obj->info & BLIS_CONJTRANS_BITS ); } -static trans_t bli_obj_onlytrans_status( obj_t* obj ) +BLIS_INLINE trans_t bli_obj_onlytrans_status( obj_t* obj ) { return ( trans_t ) ( obj->info & BLIS_TRANS_BIT ); } -static bool_t bli_obj_has_trans( obj_t* obj ) +BLIS_INLINE bool bli_obj_has_trans( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_onlytrans_status( obj ) == BLIS_BITVAL_TRANS ); } -static bool_t bli_obj_has_notrans( obj_t* obj ) +BLIS_INLINE bool bli_obj_has_notrans( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_onlytrans_status( obj ) == BLIS_BITVAL_NO_TRANS ); } -static conj_t bli_obj_conj_status( obj_t* obj ) +BLIS_INLINE conj_t bli_obj_conj_status( obj_t* obj ) { return ( conj_t ) ( obj->info & BLIS_CONJ_BIT ); } -static bool_t bli_obj_has_conj( obj_t* obj ) +BLIS_INLINE bool bli_obj_has_conj( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_conj_status( obj ) == BLIS_BITVAL_CONJ ); } -static bool_t bli_obj_has_noconj( obj_t* obj ) +BLIS_INLINE bool bli_obj_has_noconj( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_conj_status( obj ) == BLIS_BITVAL_NO_CONJ ); } -static uplo_t bli_obj_uplo( obj_t* obj ) +BLIS_INLINE uplo_t bli_obj_uplo( obj_t* obj ) { return ( uplo_t ) ( obj->info & BLIS_UPLO_BITS ); } -static bool_t bli_obj_is_upper( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_upper( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_uplo( obj ) == BLIS_BITVAL_UPPER ); } -static bool_t bli_obj_is_lower( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_lower( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_uplo( obj ) == BLIS_BITVAL_LOWER ); } -static bool_t bli_obj_is_upper_or_lower( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_upper_or_lower( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_is_upper( obj ) || bli_obj_is_lower( obj ) ); } -static bool_t bli_obj_is_dense( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_dense( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_uplo( obj ) == BLIS_BITVAL_DENSE ); } -static bool_t bli_obj_is_zeros( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_zeros( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_uplo( obj ) == BLIS_BITVAL_ZEROS ); } -static diag_t bli_obj_diag( obj_t* obj ) +BLIS_INLINE diag_t bli_obj_diag( obj_t* obj ) { return ( diag_t ) ( obj->info & BLIS_UNIT_DIAG_BIT ); } -static bool_t bli_obj_has_nonunit_diag( obj_t* obj ) +BLIS_INLINE bool bli_obj_has_nonunit_diag( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_diag( obj ) == BLIS_BITVAL_NONUNIT_DIAG ); } -static bool_t bli_obj_has_unit_diag( obj_t* obj ) +BLIS_INLINE bool bli_obj_has_unit_diag( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_diag( obj ) == BLIS_BITVAL_UNIT_DIAG ); } -static bool_t bli_obj_has_inverted_diag( obj_t* obj ) +BLIS_INLINE bool bli_obj_has_inverted_diag( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( ( obj->info & BLIS_INVERT_DIAG_BIT ) == BLIS_BITVAL_INVERT_DIAG ); } -static bool_t bli_obj_is_pack_rev_if_upper( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_pack_rev_if_upper( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( ( obj->info & BLIS_PACK_REV_IF_UPPER_BIT ) == BLIS_BITVAL_PACK_REV_IF_UPPER ); } -static bool_t bli_obj_is_pack_rev_if_lower( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_pack_rev_if_lower( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( ( obj->info & BLIS_PACK_REV_IF_LOWER_BIT ) == BLIS_BITVAL_PACK_REV_IF_LOWER ); } -static pack_t bli_obj_pack_schema( obj_t* obj ) +BLIS_INLINE pack_t bli_obj_pack_schema( obj_t* obj ) { return ( pack_t ) ( obj->info & BLIS_PACK_SCHEMA_BITS ); } -static bool_t bli_obj_is_packed( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_packed( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( obj->info & BLIS_PACK_BIT ); } -static bool_t bli_obj_is_row_packed( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_row_packed( obj_t* obj ) { - return ( bool_t ) - ( obj->info & BLIS_PACK_RC_BIT ) == ( BLIS_BITVAL_PACKED_UNSPEC ^ - BLIS_BITVAL_PACKED_ROWS ); + return ( bool ) + ( ( obj->info & BLIS_PACK_RC_BIT ) == ( BLIS_BITVAL_PACKED_UNSPEC ^ + BLIS_BITVAL_PACKED_ROWS ) ); } -static bool_t bli_obj_is_col_packed( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_col_packed( obj_t* obj ) { - return ( bool_t ) - ( obj->info & BLIS_PACK_RC_BIT ) == ( BLIS_BITVAL_PACKED_UNSPEC ^ - BLIS_BITVAL_PACKED_COLUMNS ); + return ( bool ) + ( ( obj->info & BLIS_PACK_RC_BIT ) == ( BLIS_BITVAL_PACKED_UNSPEC ^ + BLIS_BITVAL_PACKED_COLUMNS ) ); } -static bool_t bli_obj_is_panel_packed( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_panel_packed( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( obj->info & BLIS_PACK_PANEL_BIT ); } -static packbuf_t bli_obj_pack_buffer_type( obj_t* obj ) +BLIS_INLINE packbuf_t bli_obj_pack_buffer_type( obj_t* obj ) { return ( packbuf_t ) ( obj->info & BLIS_PACK_BUFFER_BITS ); } -static struc_t bli_obj_struc( obj_t* obj ) +BLIS_INLINE struc_t bli_obj_struc( obj_t* obj ) { return ( struc_t ) ( obj->info & BLIS_STRUC_BITS ); } -static bool_t bli_obj_is_general( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_general( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_struc( obj ) == BLIS_BITVAL_GENERAL ); } -static bool_t bli_obj_is_hermitian( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_hermitian( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_struc( obj ) == BLIS_BITVAL_HERMITIAN ); } -static bool_t bli_obj_is_symmetric( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_symmetric( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_struc( obj ) == BLIS_BITVAL_SYMMETRIC ); } -static bool_t bli_obj_is_triangular( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_triangular( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_struc( obj ) == BLIS_BITVAL_TRIANGULAR ); } // Info modification -static void bli_obj_apply_trans( trans_t trans, obj_t* obj ) +BLIS_INLINE void bli_obj_apply_trans( trans_t trans, obj_t* obj ) { obj->info = ( objbits_t ) ( obj->info ^ trans ); } -static void bli_obj_apply_conj( conj_t conj, obj_t* obj ) +BLIS_INLINE void bli_obj_apply_conj( conj_t conj, obj_t* obj ) { obj->info = ( objbits_t ) ( obj->info ^ conj ); } -static void bli_obj_set_conjtrans( trans_t trans, obj_t* obj ) +BLIS_INLINE void bli_obj_set_conjtrans( trans_t trans, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_CONJTRANS_BITS ) | trans; + ( ( obj->info & ~BLIS_CONJTRANS_BITS ) | trans ); } -static void bli_obj_set_onlytrans( trans_t trans, obj_t* obj ) +BLIS_INLINE void bli_obj_set_onlytrans( trans_t trans, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_TRANS_BIT ) | trans; + ( ( obj->info & ~BLIS_TRANS_BIT ) | trans ); } -static void bli_obj_set_conj( conj_t conj, obj_t* obj ) +BLIS_INLINE void bli_obj_set_conj( conj_t conj, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_CONJ_BIT ) | conj; + ( ( obj->info & ~BLIS_CONJ_BIT ) | conj ); } -static void bli_obj_set_uplo( uplo_t uplo, obj_t* obj ) +BLIS_INLINE void bli_obj_set_uplo( uplo_t uplo, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_UPLO_BITS ) | uplo; + ( ( obj->info & ~BLIS_UPLO_BITS ) | uplo ); } -static void bli_obj_set_diag( diag_t diag, obj_t* obj ) +BLIS_INLINE void bli_obj_set_diag( diag_t diag, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_UNIT_DIAG_BIT ) | diag; + ( ( obj->info & ~BLIS_UNIT_DIAG_BIT ) | diag ); } -static void bli_obj_set_invert_diag( invdiag_t invdiag, obj_t* obj ) +BLIS_INLINE void bli_obj_set_invert_diag( invdiag_t invdiag, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_INVERT_DIAG_BIT ) | invdiag; + ( ( obj->info & ~BLIS_INVERT_DIAG_BIT ) | invdiag ); } -static void bli_obj_set_dt( num_t dt, obj_t* obj ) +BLIS_INLINE void bli_obj_set_dt( num_t dt, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_DATATYPE_BITS ) | dt; + ( ( obj->info & ~BLIS_DATATYPE_BITS ) | dt ); } -static void bli_obj_set_target_dt( num_t dt, obj_t* obj ) +BLIS_INLINE void bli_obj_set_target_dt( num_t dt, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_TARGET_DT_BITS ) | - ( dt << BLIS_TARGET_DT_SHIFT ); + ( ( obj->info & ~BLIS_TARGET_DT_BITS ) | + ( dt << BLIS_TARGET_DT_SHIFT ) ); } -static void bli_obj_set_target_domain( dom_t dt, obj_t* obj ) +BLIS_INLINE void bli_obj_set_target_domain( dom_t dt, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_TARGET_DOMAIN_BIT ) | - ( dt << BLIS_TARGET_DT_SHIFT ); + ( ( obj->info & ~BLIS_TARGET_DOMAIN_BIT ) | + ( dt << BLIS_TARGET_DT_SHIFT ) ); } -static void bli_obj_set_target_prec( prec_t dt, obj_t* obj ) +BLIS_INLINE void bli_obj_set_target_prec( prec_t dt, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_TARGET_PREC_BIT ) | - ( dt << BLIS_TARGET_DT_SHIFT ); + ( ( obj->info & ~BLIS_TARGET_PREC_BIT ) | + ( dt << BLIS_TARGET_DT_SHIFT ) ); } -static void bli_obj_set_exec_dt( num_t dt, obj_t* obj ) +BLIS_INLINE void bli_obj_set_exec_dt( num_t dt, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_EXEC_DT_BITS ) | - ( dt << BLIS_EXEC_DT_SHIFT ); + ( ( obj->info & ~BLIS_EXEC_DT_BITS ) | + ( dt << BLIS_EXEC_DT_SHIFT ) ); } -static void bli_obj_set_exec_domain( dom_t dt, obj_t* obj ) +BLIS_INLINE void bli_obj_set_exec_domain( dom_t dt, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_EXEC_DOMAIN_BIT ) | - ( dt << BLIS_EXEC_DT_SHIFT ); + ( ( obj->info & ~BLIS_EXEC_DOMAIN_BIT ) | + ( dt << BLIS_EXEC_DT_SHIFT ) ); } -static void bli_obj_set_exec_prec( prec_t dt, obj_t* obj ) +BLIS_INLINE void bli_obj_set_exec_prec( prec_t dt, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_EXEC_PREC_BIT ) | - ( dt << BLIS_EXEC_DT_SHIFT ); + ( ( obj->info & ~BLIS_EXEC_PREC_BIT ) | + ( dt << BLIS_EXEC_DT_SHIFT ) ); } -static void bli_obj_set_comp_dt( num_t dt, obj_t* obj ) +BLIS_INLINE void bli_obj_set_comp_dt( num_t dt, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_COMP_DT_BITS ) | - ( dt << BLIS_COMP_DT_SHIFT ); + ( ( obj->info & ~BLIS_COMP_DT_BITS ) | + ( dt << BLIS_COMP_DT_SHIFT ) ); } -static void bli_obj_set_comp_domain( dom_t dt, obj_t* obj ) +BLIS_INLINE void bli_obj_set_comp_domain( dom_t dt, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_COMP_DOMAIN_BIT ) | - ( dt << BLIS_COMP_DT_SHIFT ); + ( ( obj->info & ~BLIS_COMP_DOMAIN_BIT ) | + ( dt << BLIS_COMP_DT_SHIFT ) ); } -static void bli_obj_set_comp_prec( prec_t dt, obj_t* obj ) +BLIS_INLINE void bli_obj_set_comp_prec( prec_t dt, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_COMP_PREC_BIT ) | - ( dt << BLIS_COMP_DT_SHIFT ); + ( ( obj->info & ~BLIS_COMP_PREC_BIT ) | + ( dt << BLIS_COMP_DT_SHIFT ) ); } // NOTE: This function queries and modifies info2. -static void bli_obj_set_scalar_dt( num_t dt, obj_t* obj ) +BLIS_INLINE void bli_obj_set_scalar_dt( num_t dt, obj_t* obj ) { obj->info2 = ( objbits_t ) - ( obj->info2 & ~BLIS_SCALAR_DT_BITS ) | - ( dt << BLIS_SCALAR_DT_SHIFT ); + ( ( obj->info2 & ~BLIS_SCALAR_DT_BITS ) | + ( dt << BLIS_SCALAR_DT_SHIFT ) ); } // NOTE: This function queries and modifies info2. -static void bli_obj_set_scalar_domain( dom_t dt, obj_t* obj ) +BLIS_INLINE void bli_obj_set_scalar_domain( dom_t dt, obj_t* obj ) { obj->info2 = ( objbits_t ) - ( obj->info2 & ~BLIS_SCALAR_DOMAIN_BIT ) | - ( dt << BLIS_SCALAR_DT_SHIFT ); + ( ( obj->info2 & ~BLIS_SCALAR_DOMAIN_BIT ) | + ( dt << BLIS_SCALAR_DT_SHIFT ) ); } // NOTE: This function queries and modifies info2. -static void bli_obj_set_scalar_prec( prec_t dt, obj_t* obj ) +BLIS_INLINE void bli_obj_set_scalar_prec( prec_t dt, obj_t* obj ) { obj->info2 = ( objbits_t ) - ( obj->info2 & ~BLIS_SCALAR_PREC_BIT ) | - ( dt << BLIS_SCALAR_DT_SHIFT ); + ( ( obj->info2 & ~BLIS_SCALAR_PREC_BIT ) | + ( dt << BLIS_SCALAR_DT_SHIFT ) ); } -static void bli_obj_set_pack_schema( pack_t schema, obj_t* obj ) +BLIS_INLINE void bli_obj_set_pack_schema( pack_t schema, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_PACK_SCHEMA_BITS ) | schema; + ( ( obj->info & ~BLIS_PACK_SCHEMA_BITS ) | schema ); } -static void bli_obj_set_pack_order_if_upper( packord_t ordif, obj_t* obj ) +BLIS_INLINE void bli_obj_set_pack_order_if_upper( packord_t ordif, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_PACK_REV_IF_UPPER_BIT ) | ordif; + ( ( obj->info & ~BLIS_PACK_REV_IF_UPPER_BIT ) | ordif ); } -static void bli_obj_set_pack_order_if_lower( packord_t ordif, obj_t* obj ) +BLIS_INLINE void bli_obj_set_pack_order_if_lower( packord_t ordif, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_PACK_REV_IF_LOWER_BIT ) | ordif; + ( ( obj->info & ~BLIS_PACK_REV_IF_LOWER_BIT ) | ordif ); } // NOTE: The packbuf_t bitfield in the obj_t is currently unused. Instead, // packbuf_t is stored/used from the context in order to support various // induced methods. (Though ideally the packbuf_t field would only be // present in the control tree). -static void bli_obj_set_pack_buffer_type( packbuf_t buf_type, obj_t* obj ) +BLIS_INLINE void bli_obj_set_pack_buffer_type( packbuf_t buf_type, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_PACK_BUFFER_BITS ) | buf_type; + ( ( obj->info & ~BLIS_PACK_BUFFER_BITS ) | buf_type ); } -static void bli_obj_set_struc( struc_t struc, obj_t* obj ) +BLIS_INLINE void bli_obj_set_struc( struc_t struc, obj_t* obj ) { obj->info = ( objbits_t ) - ( obj->info & ~BLIS_STRUC_BITS ) | struc; + ( ( obj->info & ~BLIS_STRUC_BITS ) | struc ); } -static void bli_obj_toggle_trans( obj_t* obj ) +BLIS_INLINE void bli_obj_toggle_trans( obj_t* obj ) { bli_obj_apply_trans( BLIS_TRANSPOSE, obj ); } -static void bli_obj_toggle_conj( obj_t* obj ) +BLIS_INLINE void bli_obj_toggle_conj( obj_t* obj ) { bli_obj_apply_conj( BLIS_CONJUGATE, obj ); } -static void bli_obj_toggle_uplo( obj_t* obj ) +BLIS_INLINE void bli_obj_toggle_uplo( obj_t* obj ) { obj->info = ( objbits_t ) ( obj->info ^ BLIS_LOWER_BIT ) ^ BLIS_UPPER_BIT; @@ -598,63 +599,70 @@ static void bli_obj_toggle_uplo( obj_t* obj ) // Root matrix query -static obj_t* bli_obj_root( obj_t* obj ) +BLIS_INLINE obj_t* bli_obj_root( obj_t* obj ) { - return ( obj->root ); + return ( obj_t* )( obj->root ); } -static bool_t bli_obj_root_is_general( obj_t* obj ) +BLIS_INLINE bool bli_obj_root_is_general( obj_t* obj ) { - return bli_obj_is_general( bli_obj_root( obj ) ); + return ( bool ) + ( bli_obj_is_general( bli_obj_root( obj ) ) ); } -static bool_t bli_obj_root_is_hermitian( obj_t* obj ) +BLIS_INLINE bool bli_obj_root_is_hermitian( obj_t* obj ) { - return bli_obj_is_hermitian( bli_obj_root( obj ) ); + return ( bool ) + ( bli_obj_is_hermitian( bli_obj_root( obj ) ) ); } -static bool_t bli_obj_root_is_symmetric( obj_t* obj ) +BLIS_INLINE bool bli_obj_root_is_symmetric( obj_t* obj ) { - return bli_obj_is_symmetric( bli_obj_root( obj ) ); + return ( bool ) + ( bli_obj_is_symmetric( bli_obj_root( obj ) ) ); } -static bool_t bli_obj_root_is_triangular( obj_t* obj ) +BLIS_INLINE bool bli_obj_root_is_triangular( obj_t* obj ) { - return bli_obj_is_triangular( bli_obj_root( obj ) ); + return ( bool ) + ( bli_obj_is_triangular( bli_obj_root( obj ) ) ); } -static bool_t bli_obj_root_is_herm_or_symm( obj_t* obj ) +BLIS_INLINE bool bli_obj_root_is_herm_or_symm( obj_t* obj ) { - return bli_obj_is_hermitian( bli_obj_root( obj ) ) || - bli_obj_is_symmetric( bli_obj_root( obj ) ); + return ( bool ) + ( bli_obj_is_hermitian( bli_obj_root( obj ) ) || + bli_obj_is_symmetric( bli_obj_root( obj ) ) ); } -static bool_t bli_obj_root_is_upper( obj_t* obj ) +BLIS_INLINE bool bli_obj_root_is_upper( obj_t* obj ) { - return bli_obj_is_upper( bli_obj_root( obj ) ); + return ( bool ) + ( bli_obj_is_upper( bli_obj_root( obj ) ) ); } -static bool_t bli_obj_root_is_lower( obj_t* obj ) +BLIS_INLINE bool bli_obj_root_is_lower( obj_t* obj ) { - return bli_obj_is_lower( bli_obj_root( obj ) ); + return ( bool ) + ( bli_obj_is_lower( bli_obj_root( obj ) ) ); } // Root matrix modification -static void bli_obj_set_as_root( obj_t* obj ) +BLIS_INLINE void bli_obj_set_as_root( obj_t* obj ) { obj->root = obj; } // Diagonal offset query -static doff_t bli_obj_diag_offset( obj_t* obj ) +BLIS_INLINE doff_t bli_obj_diag_offset( obj_t* obj ) { return ( doff_t ) ( obj->diag_off ); } -static doff_t bli_obj_diag_offset_after_trans( obj_t* obj ) +BLIS_INLINE doff_t bli_obj_diag_offset_after_trans( obj_t* obj ) { return ( doff_t ) ( bli_obj_has_trans( obj ) ? -bli_obj_diag_offset( obj ) @@ -663,106 +671,109 @@ static doff_t bli_obj_diag_offset_after_trans( obj_t* obj ) // Diagonal offset modification -static void bli_obj_set_diag_offset( doff_t offset, obj_t* obj ) +BLIS_INLINE void bli_obj_set_diag_offset( doff_t offset, obj_t* obj ) { obj->diag_off = ( doff_t )offset; } -static void bli_obj_negate_diag_offset( obj_t* obj ) +BLIS_INLINE void bli_obj_negate_diag_offset( obj_t* obj ) { obj->diag_off = -(obj->diag_off); } -static void bli_obj_inc_diag_offset( doff_t offset, obj_t* obj ) +BLIS_INLINE void bli_obj_inc_diag_offset( doff_t offset, obj_t* obj ) { obj->diag_off += ( doff_t )offset; } // Dimension query -static dim_t bli_obj_length( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_length( obj_t* obj ) { return ( obj->dim[ BLIS_M ] ); } -static dim_t bli_obj_width( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_width( obj_t* obj ) { return ( obj->dim[ BLIS_N ] ); } -static dim_t bli_obj_dim( mdim_t mdim, obj_t* obj ) +BLIS_INLINE dim_t bli_obj_dim( mdim_t mdim, obj_t* obj ) { return ( obj->dim[ mdim ] ); } -static dim_t bli_obj_min_dim( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_min_dim( obj_t* obj ) { return bli_min( bli_obj_length( obj ), - bli_obj_width( obj ) ); + bli_obj_width( obj ) ); } -static dim_t bli_obj_max_dim( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_max_dim( obj_t* obj ) { return bli_max( bli_obj_length( obj ), - bli_obj_width( obj ) ); + bli_obj_width( obj ) ); } -static dim_t bli_obj_length_after_trans( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_length_after_trans( obj_t* obj ) { - return ( bli_obj_has_trans( obj ) ? bli_obj_width( obj ) + return ( bli_obj_has_trans( obj ) ? bli_obj_width( obj ) : bli_obj_length( obj ) ); } -static dim_t bli_obj_width_after_trans( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_width_after_trans( obj_t* obj ) { return ( bli_obj_has_trans( obj ) ? bli_obj_length( obj ) - : bli_obj_width( obj ) ); + : bli_obj_width( obj ) ); } -static bool_t bli_obj_is_1x1( obj_t* x ) +BLIS_INLINE bool bli_obj_is_1x1( obj_t* x ) { - return ( bool_t ) + return ( bool ) ( bli_obj_length( x ) == 1 && - bli_obj_width( x ) == 1 ); + bli_obj_width( x ) == 1 ); } // Stride/increment query -static inc_t bli_obj_row_stride( obj_t* obj ) +BLIS_INLINE inc_t bli_obj_row_stride( obj_t* obj ) { return ( obj->rs ); } -static inc_t bli_obj_col_stride( obj_t* obj ) +BLIS_INLINE inc_t bli_obj_col_stride( obj_t* obj ) { return ( obj->cs ); } -static inc_t bli_obj_imag_stride( obj_t* obj ) +BLIS_INLINE inc_t bli_obj_imag_stride( obj_t* obj ) { return ( obj->is ); } -static inc_t bli_obj_row_stride_mag( obj_t* obj ) +BLIS_INLINE inc_t bli_obj_row_stride_mag( obj_t* obj ) { - return ( bli_abs( obj->rs ) ); + return ( inc_t ) + ( bli_abs( obj->rs ) ); } -static inc_t bli_obj_col_stride_mag( obj_t* obj ) +BLIS_INLINE inc_t bli_obj_col_stride_mag( obj_t* obj ) { - return ( bli_abs( obj->cs ) ); + return ( inc_t ) + ( bli_abs( obj->cs ) ); } -static inc_t bli_obj_imag_stride_mag( obj_t* obj ) +BLIS_INLINE inc_t bli_obj_imag_stride_mag( obj_t* obj ) { - return ( bli_abs( obj->is ) ); + return ( inc_t ) + ( bli_abs( obj->is ) ); } // Note: The purpose of these functions is to obtain the length and width // of the smallest submatrices of an object that could still encompass // the stored data above (if obj is upper) or below (if obj is lower) // the diagonal. -static dim_t bli_obj_length_stored( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_length_stored( obj_t* obj ) { return ( dim_t ) ( bli_obj_is_upper( obj ) @@ -773,7 +784,7 @@ static dim_t bli_obj_length_stored( obj_t* obj ) ); } -static dim_t bli_obj_width_stored( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_width_stored( obj_t* obj ) { return ( dim_t ) ( bli_obj_is_lower( obj ) @@ -784,90 +795,89 @@ static dim_t bli_obj_width_stored( obj_t* obj ) ); } -static dim_t bli_obj_length_stored_after_trans( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_length_stored_after_trans( obj_t* obj ) { - return ( bli_obj_has_trans( obj ) ? bli_obj_width_stored( obj ) + return ( bli_obj_has_trans( obj ) ? bli_obj_width_stored( obj ) : bli_obj_length_stored( obj ) ); } -static dim_t bli_obj_width_stored_after_trans( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_width_stored_after_trans( obj_t* obj ) { return ( bli_obj_has_trans( obj ) ? bli_obj_length_stored( obj ) - : bli_obj_width_stored( obj ) ); + : bli_obj_width_stored( obj ) ); } -static dim_t bli_obj_vector_dim( obj_t* x ) +BLIS_INLINE dim_t bli_obj_vector_dim( obj_t* x ) { - return ( bli_obj_length( x ) == 1 ? bli_obj_width( x ) + return ( bli_obj_length( x ) == 1 ? bli_obj_width( x ) : bli_obj_length( x ) ); } -static inc_t bli_obj_vector_inc( obj_t* x ) +BLIS_INLINE inc_t bli_obj_vector_inc( obj_t* x ) { - return ( bli_obj_is_1x1( x ) ? 1 : \ + return ( bli_obj_is_1x1( x ) ? 1 : ( bli_obj_length( x ) == 1 ? bli_obj_col_stride( x ) : bli_obj_row_stride( x ) ) ); } -static bool_t bli_obj_is_vector( obj_t* x ) +BLIS_INLINE bool bli_obj_is_vector( obj_t* x ) { - return ( bool_t ) + return ( bool ) ( bli_obj_length( x ) == 1 || - bli_obj_width( x ) == 1 ); + bli_obj_width( x ) == 1 ); } -static bool_t bli_obj_is_row_vector( obj_t* x ) +BLIS_INLINE bool bli_obj_is_row_vector( obj_t* x ) { - return ( bool_t ) + return ( bool ) ( bli_obj_length( x ) == 1 ); } -static bool_t bli_obj_is_col_vector( obj_t* x ) +BLIS_INLINE bool bli_obj_is_col_vector( obj_t* x ) { - return ( bool_t ) + return ( bool ) ( bli_obj_width( x ) == 1 ); } -static bool_t bli_obj_has_zero_dim( obj_t* x ) +BLIS_INLINE bool bli_obj_has_zero_dim( obj_t* x ) { - return ( bool_t ) + return ( bool ) ( bli_obj_length( x ) == 0 || - bli_obj_width( x ) == 0 ); + bli_obj_width( x ) == 0 ); } // Dimension modification -static void bli_obj_set_length( dim_t m, obj_t* obj ) +BLIS_INLINE void bli_obj_set_length( dim_t m, obj_t* obj ) { obj->dim[ BLIS_M ] = m; } -static void bli_obj_set_width( dim_t n, obj_t* obj ) +BLIS_INLINE void bli_obj_set_width( dim_t n, obj_t* obj ) { obj->dim[ BLIS_N ] = n; } -static void bli_obj_set_dim( mdim_t mdim, dim_t dim_val, obj_t* obj ) +BLIS_INLINE void bli_obj_set_dim( mdim_t mdim, dim_t dim_val, obj_t* obj ) { obj->dim[ mdim ] = dim_val; } -static void bli_obj_set_dims( dim_t m, dim_t n, obj_t* obj ) +BLIS_INLINE void bli_obj_set_dims( dim_t m, dim_t n, obj_t* obj ) { bli_obj_set_length( m, obj ); bli_obj_set_width( n, obj ); } -static void bli_obj_set_dims_with_trans( trans_t trans, dim_t m, dim_t n, obj_t* obj ) +BLIS_INLINE void bli_obj_set_dims_with_trans( trans_t trans, dim_t m, dim_t n, obj_t* obj ) { - //if ( bli_does_notrans( trans ) ) - if ( ( ~trans & BLIS_TRANS_BIT ) == BLIS_BITVAL_TRANS ) + if ( bli_does_notrans( trans ) ) { bli_obj_set_length( m, obj ); bli_obj_set_width( n, obj ); } - else + else // if ( bli_does_trans( trans ) ) { bli_obj_set_length( n, obj ); bli_obj_set_width( m, obj ); @@ -884,86 +894,96 @@ static void bli_obj_set_dims_with_trans( trans_t trans, dim_t m, dim_t n, obj_t* // "obj" macros are used on packed matrices. // -static bool_t bli_obj_is_row_stored( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_row_stored( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_col_stride_mag( obj ) == 1 ); } -static bool_t bli_obj_is_col_stored( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_col_stored( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_row_stride_mag( obj ) == 1 ); } -static bool_t bli_obj_is_gen_stored( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_gen_stored( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_row_stride_mag( obj ) != 1 && bli_obj_col_stride_mag( obj ) != 1 ); } -static bool_t bli_obj_is_row_tilted( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_row_tilted( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_col_stride_mag( obj ) < bli_obj_row_stride_mag( obj ) ); } -static bool_t bli_obj_is_col_tilted( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_col_tilted( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_row_stride_mag( obj ) < bli_obj_col_stride_mag( obj ) ); } // Stride/increment modification -static void bli_obj_set_strides( inc_t rs, inc_t cs, obj_t* obj ) +BLIS_INLINE void bli_obj_set_row_stride( inc_t rs, obj_t* obj ) { obj->rs = rs; +} + +BLIS_INLINE void bli_obj_set_col_stride( inc_t cs, obj_t* obj ) +{ obj->cs = cs; } -static void bli_obj_set_imag_stride( inc_t is, obj_t* obj ) +BLIS_INLINE void bli_obj_set_strides( inc_t rs, inc_t cs, obj_t* obj ) +{ + bli_obj_set_row_stride( rs, obj ); + bli_obj_set_col_stride( cs, obj ); +} + +BLIS_INLINE void bli_obj_set_imag_stride( inc_t is, obj_t* obj ) { obj->is = is; } // Offset query -static dim_t bli_obj_row_off( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_row_off( obj_t* obj ) { return ( obj->off[ BLIS_M ] ); } -static dim_t bli_obj_col_off( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_col_off( obj_t* obj ) { return ( obj->off[ BLIS_N ] ); } -static dim_t bli_obj_off( mdim_t mdim, obj_t* obj ) +BLIS_INLINE dim_t bli_obj_off( mdim_t mdim, obj_t* obj ) { return ( obj->off[ mdim ] ); } // Offset modification -static void bli_obj_set_off( mdim_t mdim, dim_t offset, obj_t* obj ) +BLIS_INLINE void bli_obj_set_off( mdim_t mdim, dim_t offset, obj_t* obj ) { obj->off[ mdim ] = offset; } -static void bli_obj_set_offs( dim_t offm, dim_t offn, obj_t* obj ) +BLIS_INLINE void bli_obj_set_offs( dim_t offm, dim_t offn, obj_t* obj ) { bli_obj_set_off( BLIS_M, offm, obj ); bli_obj_set_off( BLIS_N, offn, obj ); } -static void bli_obj_inc_off( mdim_t mdim, dim_t offset, obj_t* obj ) +BLIS_INLINE void bli_obj_inc_off( mdim_t mdim, dim_t offset, obj_t* obj ) { obj->off[ mdim ] += offset; } -static void bli_obj_inc_offs( dim_t offm, dim_t offn, obj_t* obj ) +BLIS_INLINE void bli_obj_inc_offs( dim_t offm, dim_t offn, obj_t* obj ) { bli_obj_inc_off( BLIS_M, offm, obj ); bli_obj_inc_off( BLIS_N, offn, obj ); @@ -971,56 +991,57 @@ static void bli_obj_inc_offs( dim_t offm, dim_t offn, obj_t* obj ) // Diagonal offset predicates -static bool_t bli_obj_is_strictly_above_diag( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_strictly_above_diag( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( ( doff_t )bli_obj_length( obj ) <= -bli_obj_diag_offset( obj ) ); } -static bool_t bli_obj_is_strictly_below_diag( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_strictly_below_diag( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( ( doff_t )bli_obj_width( obj ) <= bli_obj_diag_offset( obj ) ); } -static bool_t bli_obj_is_outside_diag( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_outside_diag( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( bli_obj_is_strictly_above_diag( obj ) || bli_obj_is_strictly_below_diag( obj ) ); } -static bool_t bli_obj_intersects_diag( obj_t* obj ) +BLIS_INLINE bool bli_obj_intersects_diag( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( !bli_obj_is_strictly_above_diag( obj ) && !bli_obj_is_strictly_below_diag( obj ) ); } -static bool_t bli_obj_is_unstored_subpart( obj_t* obj ) +BLIS_INLINE bool bli_obj_is_unstored_subpart( obj_t* obj ) { - return ( bool_t ) + return ( bool ) ( ( bli_obj_root_is_lower( obj ) && bli_obj_is_strictly_above_diag( obj ) ) || ( bli_obj_root_is_upper( obj ) && bli_obj_is_strictly_below_diag( obj ) ) ); } // Buffer address query -static void* bli_obj_buffer( obj_t* obj ) +BLIS_INLINE void* bli_obj_buffer( obj_t* obj ) { - return ( obj->buffer ); + return ( void* ) + ( obj->buffer ); } // Buffer address modification -static void bli_obj_set_buffer( void* p, obj_t* obj ) +BLIS_INLINE void bli_obj_set_buffer( void* p, obj_t* obj ) { obj->buffer = p; } // Bufferless scalar field query -static void* bli_obj_internal_scalar_buffer( obj_t* obj ) +BLIS_INLINE void* bli_obj_internal_scalar_buffer( obj_t* obj ) { return ( void* ) ( &( obj->scalar ) ); @@ -1028,50 +1049,51 @@ static void* bli_obj_internal_scalar_buffer( obj_t* obj ) // Bufferless scalar field modification -static void bli_obj_copy_internal_scalar( obj_t* a, obj_t* b ) +BLIS_INLINE void bli_obj_copy_internal_scalar( obj_t* a, obj_t* b ) { b->scalar = a->scalar; } // Element size query -static siz_t bli_obj_elem_size( obj_t* obj ) +BLIS_INLINE siz_t bli_obj_elem_size( obj_t* obj ) { - return ( obj->elem_size ); + return ( siz_t ) + ( obj->elem_size ); } // Element size modification -static void bli_obj_set_elem_size( siz_t size, obj_t* obj ) +BLIS_INLINE void bli_obj_set_elem_size( siz_t size, obj_t* obj ) { obj->elem_size = size; } // Packed matrix info query -static dim_t bli_obj_padded_length( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_padded_length( obj_t* obj ) { return ( obj->m_padded ); } -static dim_t bli_obj_padded_width( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_padded_width( obj_t* obj ) { return ( obj->n_padded ); } // Packed matrix info modification -static void bli_obj_set_padded_length( dim_t m, obj_t* obj ) +BLIS_INLINE void bli_obj_set_padded_length( dim_t m, obj_t* obj ) { obj->m_padded = m; } -static void bli_obj_set_padded_width( dim_t n, obj_t* obj ) +BLIS_INLINE void bli_obj_set_padded_width( dim_t n, obj_t* obj ) { obj->n_padded = n; } -static void bli_obj_set_padded_dims( dim_t m, dim_t n, obj_t* obj ) +BLIS_INLINE void bli_obj_set_padded_dims( dim_t m, dim_t n, obj_t* obj ) { bli_obj_set_padded_length( m, obj ); bli_obj_set_padded_width( n, obj ); @@ -1079,59 +1101,193 @@ static void bli_obj_set_padded_dims( dim_t m, dim_t n, obj_t* obj ) // Packed panel info query -static dim_t bli_obj_panel_length( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_panel_length( obj_t* obj ) { return ( obj->m_panel ); } -static dim_t bli_obj_panel_width( obj_t* obj ) +BLIS_INLINE dim_t bli_obj_panel_width( obj_t* obj ) { return ( obj->n_panel ); } -static inc_t bli_obj_panel_dim( obj_t* obj ) +BLIS_INLINE inc_t bli_obj_panel_dim( obj_t* obj ) { return ( obj->pd ); } -static inc_t bli_obj_panel_stride( obj_t* obj ) +BLIS_INLINE inc_t bli_obj_panel_stride( obj_t* obj ) { return ( obj->ps ); } // Packed panel info modification -static void bli_obj_set_panel_length( dim_t m, obj_t* obj ) +BLIS_INLINE void bli_obj_set_panel_length( dim_t m, obj_t* obj ) { obj->m_panel = m; } -static void bli_obj_set_panel_width( dim_t n, obj_t* obj ) +BLIS_INLINE void bli_obj_set_panel_width( dim_t n, obj_t* obj ) { obj->n_panel = n; } -static void bli_obj_set_panel_dims( dim_t m, dim_t n, obj_t* obj ) +BLIS_INLINE void bli_obj_set_panel_dims( dim_t m, dim_t n, obj_t* obj ) { bli_obj_set_panel_length( m, obj ); bli_obj_set_panel_width( n, obj ); } -static void bli_obj_set_panel_dim( inc_t pd, obj_t* obj ) +BLIS_INLINE void bli_obj_set_panel_dim( inc_t pd, obj_t* obj ) { obj->pd = pd; } -static void bli_obj_set_panel_stride( inc_t ps, obj_t* obj ) +BLIS_INLINE void bli_obj_set_panel_stride( inc_t ps, obj_t* obj ) { obj->ps = ps; } +// stor3_t-related + +BLIS_INLINE stor3_t bli_obj_stor3_from_strides( obj_t* c, obj_t* a, obj_t* b ) +{ + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + inc_t rs_a, cs_a; + inc_t rs_b, cs_b; + + if ( bli_obj_has_notrans( a ) ) + { + rs_a = bli_obj_row_stride( a ); + cs_a = bli_obj_col_stride( a ); + } + else + { + rs_a = bli_obj_col_stride( a ); + cs_a = bli_obj_row_stride( a ); + } + + if ( bli_obj_has_notrans( b ) ) + { + rs_b = bli_obj_row_stride( b ); + cs_b = bli_obj_col_stride( b ); + } + else + { + rs_b = bli_obj_col_stride( b ); + cs_b = bli_obj_row_stride( b ); + } + + return bli_stor3_from_strides( rs_c, cs_c, + rs_a, cs_a, + rs_b, cs_b ); +} + + +// -- User-provided information macros -- + +// Function pointer query + +BLIS_INLINE obj_pack_fn_t bli_obj_pack_fn( obj_t* obj ) +{ + return obj->pack_fn; +} + +BLIS_INLINE void* bli_obj_pack_params( obj_t* obj ) +{ + return obj->pack_params; +} + +BLIS_INLINE obj_ker_fn_t bli_obj_ker_fn( obj_t* obj ) +{ + return obj->ker_fn; +} + +BLIS_INLINE void* bli_obj_ker_params( obj_t* obj ) +{ + return obj->ker_params; +} + +// Function pointer modification + +BLIS_INLINE void bli_obj_set_pack_fn( obj_pack_fn_t pack_fn, obj_t* obj ) +{ + obj->pack_fn = pack_fn; +} + +BLIS_INLINE void bli_obj_set_pack_params( void* params, obj_t* obj ) +{ + obj->pack_params = params; +} + +BLIS_INLINE void bli_obj_set_ker_fn( obj_ker_fn_t ker_fn, obj_t* obj ) +{ + obj->ker_fn = ker_fn; +} + +BLIS_INLINE void bli_obj_set_ker_params( void* params, obj_t* obj ) +{ + obj->ker_params = params; +} + + +// -- Initialization-related macros -- + +// Finish the initialization started by the matrix-specific static initializer +// (e.g. BLIS_OBJECT_INITIALIZER) +// NOTE: This is intended only for use in the BLAS compatibility API and typed +// BLIS API. + +BLIS_INLINE void bli_obj_init_finish( num_t dt, dim_t m, dim_t n, void* p, inc_t rs, inc_t cs, obj_t* obj ) +{ + bli_obj_set_as_root( obj ); + + bli_obj_set_dt( dt, obj ); + bli_obj_set_target_dt( dt, obj ); + bli_obj_set_exec_dt( dt, obj ); + bli_obj_set_comp_dt( dt, obj ); + + bli_obj_set_dims( m, n, obj ); + bli_obj_set_strides( rs, cs, obj ); + + siz_t elem_size = sizeof( float ); + if ( bli_dt_prec_is_double( dt ) ) elem_size *= 2; + if ( bli_dt_dom_is_complex( dt ) ) elem_size *= 2; + bli_obj_set_elem_size( elem_size, obj ); + + bli_obj_set_buffer( p, obj ); + + bli_obj_set_scalar_dt( dt, obj ); + void* restrict s = bli_obj_internal_scalar_buffer( obj ); + + if ( bli_dt_prec_is_single( dt ) ) { (( scomplex* )s)->real = 1.0F; + (( scomplex* )s)->imag = 0.0F; } + else if ( bli_dt_prec_is_double( dt ) ) { (( dcomplex* )s)->real = 1.0; + (( dcomplex* )s)->imag = 0.0; } +} + +// Finish the initialization started by the 1x1-specific static initializer +// (e.g. BLIS_OBJECT_INITIALIZER_1X1) +// NOTE: This is intended only for use in the BLAS compatibility API and typed +// BLIS API. + +BLIS_INLINE void bli_obj_init_finish_1x1( num_t dt, void* p, obj_t* obj ) +{ + bli_obj_set_as_root( obj ); + + bli_obj_set_dt( dt, obj ); + + bli_obj_set_buffer( p, obj ); +} + // -- Miscellaneous object macros -- // Toggle the region referenced (or "stored"). -static void bli_obj_toggle_region_ref( obj_t* obj ) +BLIS_INLINE void bli_obj_toggle_region_ref( obj_t* obj ) { if ( bli_obj_is_upper( obj ) ) bli_obj_inc_diag_offset( -1, obj ); else if ( bli_obj_is_lower( obj ) ) bli_obj_inc_diag_offset( 1, obj ); @@ -1139,10 +1295,9 @@ static void bli_obj_toggle_region_ref( obj_t* obj ) bli_obj_toggle_uplo( obj ); } -static void bli_obj_toggle_uplo_if_trans( trans_t trans, obj_t* obj ) +BLIS_INLINE void bli_obj_toggle_uplo_if_trans( trans_t trans, obj_t* obj ) { - //if ( bli_does_trans( trans ) && - if ( ( trans & BLIS_TRANS_BIT ) == BLIS_BITVAL_TRANS && + if ( bli_does_trans( trans ) && bli_obj_is_upper_or_lower( obj ) ) { bli_obj_toggle_uplo( obj ); @@ -1152,47 +1307,15 @@ static void bli_obj_toggle_uplo_if_trans( trans_t trans, obj_t* obj ) // Initialize object with default properties (info field). -static void bli_obj_set_defaults( obj_t* obj ) +BLIS_INLINE void bli_obj_set_defaults( obj_t* obj ) { obj->info = 0x0; obj->info = obj->info | BLIS_BITVAL_DENSE | BLIS_BITVAL_GENERAL; } -// Initializors for global scalar constants. -// NOTE: These must remain cpp macros since they are initializor -// expressions, not functions. - -#define bli_obj_init_const( buffer0 ) \ -{ \ - .root = NULL, \ -\ - .off = { 0, 0 }, \ - .dim = { 1, 1 }, \ - .diag_off = 0, \ -\ - .info = 0x0 | BLIS_BITVAL_CONST_TYPE | \ - BLIS_BITVAL_DENSE | \ - BLIS_BITVAL_GENERAL, \ - .elem_size = sizeof( constdata_t ), \ -\ - .buffer = buffer0, \ - .rs = 1, \ - .cs = 1, \ - .is = 1 \ -} - -#define bli_obj_init_constdata( val ) \ -{ \ - .s = ( float )val, \ - .d = ( double )val, \ - .c = { .real = ( float )val, .imag = 0.0f }, \ - .z = { .real = ( double )val, .imag = 0.0 }, \ - .i = ( gint_t )val, \ -} - // Acquire buffer at object's submatrix offset (offset-aware buffer query). -static void* bli_obj_buffer_at_off( obj_t* obj ) +BLIS_INLINE void* bli_obj_buffer_at_off( obj_t* obj ) { return ( void* ) ( @@ -1207,7 +1330,7 @@ static void* bli_obj_buffer_at_off( obj_t* obj ) // Acquire buffer from BLIS_CONSTANT object. -static void* bli_obj_buffer_for_const( num_t dt, obj_t* obj ) +BLIS_INLINE void* bli_obj_buffer_for_const( num_t dt, obj_t* obj ) { void* p; @@ -1222,7 +1345,7 @@ static void* bli_obj_buffer_for_const( num_t dt, obj_t* obj ) // Acquire buffer from scalar (1x1) object, including BLIS_CONSTANT objects. -static void* bli_obj_buffer_for_1x1( num_t dt, obj_t* obj ) +BLIS_INLINE void* bli_obj_buffer_for_1x1( num_t dt, obj_t* obj ) { return ( void* ) ( bli_obj_is_const( obj ) ? bli_obj_buffer_for_const( dt, obj ) @@ -1230,18 +1353,30 @@ static void* bli_obj_buffer_for_1x1( num_t dt, obj_t* obj ) ); } +// Adjust the pointer based on current offsets, zero the offsets, and then +// set the current object as the root. For obj_t's with at least one non-zero +// offset, this effectively makes the obj_t "forget" that it was ever a view +// into a larger matrix. + +BLIS_INLINE void bli_obj_reset_origin( obj_t* obj ) +{ + bli_obj_set_buffer( bli_obj_buffer_at_off( obj ), obj ); + bli_obj_set_offs( 0, 0, obj ); + bli_obj_set_as_root( obj ); +} + // Make a full alias (shallow copy). -static void bli_obj_alias_to( obj_t* a, obj_t* b ) +BLIS_INLINE void bli_obj_alias_to( obj_t* a, obj_t* b ) { bli_obj_init_full_shallow_copy_of( a, b ); } // Check if two objects are aliases of one another. -static bool_t bli_obj_is_alias_of( obj_t* a, obj_t* b ) +BLIS_INLINE bool bli_obj_is_alias_of( obj_t* a, obj_t* b ) { - return ( bool_t ) + return ( bool ) ( bli_obj_buffer( a ) == bli_obj_buffer( b ) ); } @@ -1249,7 +1384,7 @@ static bool_t bli_obj_is_alias_of( obj_t* a, obj_t* b ) // Create an alias with a trans value applied. // (Note: trans may include a conj component.) -static void bli_obj_alias_with_trans( trans_t trans, obj_t* a, obj_t* b ) +BLIS_INLINE void bli_obj_alias_with_trans( trans_t trans, obj_t* a, obj_t* b ) { bli_obj_alias_to( a, b ); bli_obj_apply_trans( trans, b ); @@ -1257,7 +1392,7 @@ static void bli_obj_alias_with_trans( trans_t trans, obj_t* a, obj_t* b ) // Create an alias with a conj value applied. -static void bli_obj_alias_with_conj( conj_t conja, obj_t* a, obj_t* b ) +BLIS_INLINE void bli_obj_alias_with_conj( conj_t conja, obj_t* a, obj_t* b ) { bli_obj_alias_to( a, b ); bli_obj_apply_conj( conja, b ); @@ -1265,7 +1400,7 @@ static void bli_obj_alias_with_conj( conj_t conja, obj_t* a, obj_t* b ) // Alias only the real part. -static void bli_obj_real_part( obj_t* c, obj_t* r ) +BLIS_INLINE void bli_obj_real_part( obj_t* c, obj_t* r ) { bli_obj_alias_to( c, r ); @@ -1298,7 +1433,7 @@ static void bli_obj_real_part( obj_t* c, obj_t* r ) // Alias only the imaginary part. -static void bli_obj_imag_part( obj_t* c, obj_t* i ) +BLIS_INLINE void bli_obj_imag_part( obj_t* c, obj_t* i ) { if ( bli_obj_is_complex( c ) ) { @@ -1337,7 +1472,7 @@ static void bli_obj_imag_part( obj_t* c, obj_t* i ) // chosen buffer (possibly using an auxiliary datatype if the object is // BLIS_CONSTANT). -static void bli_obj_scalar_set_dt_buffer( obj_t* obj, num_t dt_aux, num_t* dt, void** buf ) +BLIS_INLINE void bli_obj_scalar_set_dt_buffer( obj_t* obj, num_t dt_aux, num_t* dt, void** buf ) { if ( bli_obj_is_const( obj ) ) { @@ -1353,14 +1488,20 @@ static void bli_obj_scalar_set_dt_buffer( obj_t* obj, num_t dt_aux, num_t* dt, v // Swap all object fields (metadata/properties). -static void bli_obj_swap( obj_t* a, obj_t* b ) +BLIS_INLINE void bli_obj_swap( obj_t* a, obj_t* b ) { + bool a_root_is_self = ( bli_obj_root( a ) == a ); + bool b_root_is_self = ( bli_obj_root( b ) == b ); + obj_t t = *b; *b = *a; *a = t; + + if ( a_root_is_self ) bli_obj_set_as_root( b ); + if ( b_root_is_self ) bli_obj_set_as_root( a ); } // Swap object pack schemas. -static void bli_obj_swap_pack_schemas( obj_t* a, obj_t* b ) +BLIS_INLINE void bli_obj_swap_pack_schemas( obj_t* a, obj_t* b ) { const pack_t schema_a = bli_obj_pack_schema( a ); const pack_t schema_b = bli_obj_pack_schema( b ); @@ -1372,7 +1513,7 @@ static void bli_obj_swap_pack_schemas( obj_t* a, obj_t* b ) // Induce a transposition on an object: swap dimensions, increments, and // offsets, then clear the trans bit. -static void bli_obj_induce_trans( obj_t* obj ) +BLIS_INLINE void bli_obj_induce_trans( obj_t* obj ) { // Induce transposition among basic fields. dim_t m = bli_obj_length( obj ); @@ -1401,7 +1542,32 @@ static void bli_obj_induce_trans( obj_t* obj ) bli_obj_set_panel_dims( n_panel, m_panel, obj ); // Note that this macro DOES NOT touch the transposition bit! If - // the calling code is using this macro to handle an object whose + // the calling code is using this function to handle an object whose + // transposition bit is set prior to computation, that code needs + // to manually clear or toggle the bit, via + // bli_obj_set_onlytrans() or bli_obj_toggle_trans(), + // respectively. +} + +BLIS_INLINE void bli_obj_induce_fast_trans( obj_t* obj ) +{ + // NOTE: This function is only used in situations where the matrices + // are guaranteed to not have structure or be packed. + + // Induce transposition among basic fields. + dim_t m = bli_obj_length( obj ); + dim_t n = bli_obj_width( obj ); + inc_t rs = bli_obj_row_stride( obj ); + inc_t cs = bli_obj_col_stride( obj ); + dim_t offm = bli_obj_row_off( obj ); + dim_t offn = bli_obj_col_off( obj ); + + bli_obj_set_dims( n, m, obj ); + bli_obj_set_strides( cs, rs, obj ); + bli_obj_set_offs( offn, offm, obj ); + + // Note that this macro DOES NOT touch the transposition bit! If + // the calling code is using this function to handle an object whose // transposition bit is set prior to computation, that code needs // to manually clear or toggle the bit, via // bli_obj_set_onlytrans() or bli_obj_toggle_trans(), @@ -1414,7 +1580,7 @@ static void bli_obj_induce_trans( obj_t* obj ) // and column strides are left unchanged (which, of course, drastically // changes the effect of the macro). -static void bli_obj_reflect_about_diag( obj_t* obj ) +BLIS_INLINE void bli_obj_reflect_about_diag( obj_t* obj ) { dim_t m = bli_obj_length( obj ); dim_t n = bli_obj_width( obj ); diff --git a/frame/include/bli_param_macro_defs.h b/frame/include/bli_param_macro_defs.h index f4e7e775f3..286e79e2b7 100644 --- a/frame/include/bli_param_macro_defs.h +++ b/frame/include/bli_param_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,19 +41,19 @@ // buffer -static bool_t bli_is_aligned_to( siz_t p, siz_t size ) +BLIS_INLINE bool bli_is_aligned_to( siz_t p, siz_t size ) { - return ( bool_t ) + return ( bool ) ( p % size == 0 ); } -static bool_t bli_is_unaligned_to( siz_t p, siz_t size ) +BLIS_INLINE bool bli_is_unaligned_to( siz_t p, siz_t size ) { - return ( bool_t ) + return ( bool ) ( p % size != 0 ); } -static siz_t bli_offset_past_alignment( siz_t p, siz_t size ) +BLIS_INLINE siz_t bli_offset_past_alignment( siz_t p, siz_t size ) { return ( siz_t ) ( p % size ); @@ -62,101 +62,125 @@ static siz_t bli_offset_past_alignment( siz_t p, siz_t size ) // datatype -static bool_t bli_is_float( num_t dt ) +BLIS_INLINE bool bli_is_float( num_t dt ) { - return ( bool_t ) + return ( bool ) ( dt == BLIS_FLOAT ); } -static bool_t bli_is_double( num_t dt ) +BLIS_INLINE bool bli_is_double( num_t dt ) { - return ( bool_t ) + return ( bool ) ( dt == BLIS_DOUBLE ); } -static bool_t bli_is_scomplex( num_t dt ) +BLIS_INLINE bool bli_is_scomplex( num_t dt ) { - return ( bool_t ) + return ( bool ) ( dt == BLIS_SCOMPLEX ); } -static bool_t bli_is_dcomplex( num_t dt ) +BLIS_INLINE bool bli_is_dcomplex( num_t dt ) { - return ( bool_t ) + return ( bool ) ( dt == BLIS_DCOMPLEX ); } -static bool_t bli_is_constant( num_t dt ) +BLIS_INLINE bool bli_is_constant( num_t dt ) { - return ( bool_t ) + return ( bool ) ( dt == BLIS_CONSTANT ); } -static bool_t bli_is_int( num_t dt ) +BLIS_INLINE bool bli_is_int( num_t dt ) { - return ( bool_t ) + return ( bool ) ( dt == BLIS_INT ); } -static bool_t bli_is_real( num_t dt ) +BLIS_INLINE bool bli_is_real( num_t dt ) { - return ( bool_t ) + return ( bool ) ( bli_is_float( dt ) || bli_is_double( dt ) ); } -static bool_t bli_is_complex( num_t dt ) +BLIS_INLINE bool bli_is_complex( num_t dt ) { - return ( bool_t ) + return ( bool ) ( bli_is_scomplex( dt ) || bli_is_dcomplex( dt ) ); } -static bool_t bli_is_single_prec( num_t dt ) +BLIS_INLINE bool bli_is_single_prec( num_t dt ) { - return ( bool_t ) + return ( bool ) ( bli_is_float( dt ) || bli_is_scomplex( dt ) ); } -static bool_t bli_is_double_prec( num_t dt ) +BLIS_INLINE bool bli_is_double_prec( num_t dt ) { - return ( bool_t ) + return ( bool ) ( bli_is_double( dt ) || bli_is_dcomplex( dt ) ); } -static dom_t bli_dt_domain( num_t dt ) +BLIS_INLINE dom_t bli_dt_domain( num_t dt ) { return ( dom_t ) ( dt & BLIS_DOMAIN_BIT ); } -static prec_t bli_dt_prec( num_t dt ) +BLIS_INLINE bool bli_dt_dom_is_real( num_t dt ) +{ + return ( bool ) + ( ( dt & BLIS_DOMAIN_BIT ) == BLIS_REAL ); +} + +BLIS_INLINE bool bli_dt_dom_is_complex( num_t dt ) +{ + return ( bool ) + ( ( dt & BLIS_DOMAIN_BIT ) == BLIS_COMPLEX ); +} + +BLIS_INLINE prec_t bli_dt_prec( num_t dt ) { return ( prec_t ) ( dt & BLIS_PRECISION_BIT ); } -static num_t bli_dt_proj_to_real( num_t dt ) +BLIS_INLINE bool bli_dt_prec_is_single( num_t dt ) +{ + return ( bool ) + ( ( dt & BLIS_PRECISION_BIT ) == BLIS_SINGLE_PREC ); +} + +BLIS_INLINE bool bli_dt_prec_is_double( num_t dt ) +{ + return ( bool ) + ( ( dt & BLIS_PRECISION_BIT ) == BLIS_DOUBLE_PREC ); +} + +BLIS_INLINE num_t bli_dt_proj_to_real( num_t dt ) { return ( num_t ) ( dt & ~BLIS_BITVAL_COMPLEX ); } -static num_t bli_dt_proj_to_complex( num_t dt ) +BLIS_INLINE num_t bli_dt_proj_to_complex( num_t dt ) { return ( num_t ) ( dt | BLIS_BITVAL_COMPLEX ); } -static num_t bli_dt_proj_to_single_prec( num_t dt ) +BLIS_INLINE num_t bli_dt_proj_to_single_prec( num_t dt ) { return ( num_t ) ( dt & ~BLIS_BITVAL_DOUBLE_PREC ); } -static num_t bli_dt_proj_to_double_prec( num_t dt ) +BLIS_INLINE num_t bli_dt_proj_to_double_prec( num_t dt ) { return ( num_t ) ( dt | BLIS_BITVAL_DOUBLE_PREC ); @@ -165,79 +189,85 @@ static num_t bli_dt_proj_to_double_prec( num_t dt ) // trans -static bool_t bli_is_notrans( trans_t trans ) +BLIS_INLINE bool bli_is_notrans( trans_t trans ) { - return ( bool_t ) + return ( bool ) ( trans == BLIS_NO_TRANSPOSE ); } -static bool_t bli_is_trans( trans_t trans ) +BLIS_INLINE bool bli_is_trans( trans_t trans ) { - return ( bool_t ) + return ( bool ) ( trans == BLIS_TRANSPOSE ); } -static bool_t bli_is_conjnotrans( trans_t trans ) +BLIS_INLINE bool bli_is_conjnotrans( trans_t trans ) { - return ( bool_t ) + return ( bool ) ( trans == BLIS_CONJ_NO_TRANSPOSE ); } -static bool_t bli_is_conjtrans( trans_t trans ) +BLIS_INLINE bool bli_is_conjtrans( trans_t trans ) { - return ( bool_t ) + return ( bool ) ( trans == BLIS_CONJ_TRANSPOSE ); } -static bool_t bli_does_notrans( trans_t trans ) +BLIS_INLINE bool bli_does_notrans( trans_t trans ) { - return ( bool_t ) + return ( bool ) ( (~trans & BLIS_TRANS_BIT ) == BLIS_BITVAL_TRANS ); } -static bool_t bli_does_trans( trans_t trans ) +BLIS_INLINE bool bli_does_trans( trans_t trans ) { - return ( bool_t ) + return ( bool ) ( ( trans & BLIS_TRANS_BIT ) == BLIS_BITVAL_TRANS ); } -static bool_t bli_does_noconj( trans_t trans ) +BLIS_INLINE bool bli_does_noconj( trans_t trans ) { - return ( bool_t ) + return ( bool ) ( (~trans & BLIS_CONJ_BIT ) == BLIS_BITVAL_CONJ ); } -static bool_t bli_does_conj( trans_t trans ) +BLIS_INLINE bool bli_does_conj( trans_t trans ) { - return ( bool_t ) + return ( bool ) ( ( trans & BLIS_CONJ_BIT ) == BLIS_BITVAL_CONJ ); } -static trans_t bli_extract_trans( trans_t trans ) +BLIS_INLINE trans_t bli_extract_trans( trans_t trans ) { return ( trans_t ) ( trans & BLIS_TRANS_BIT ); } -static conj_t bli_extract_conj( trans_t trans ) +BLIS_INLINE conj_t bli_extract_conj( trans_t trans ) { return ( conj_t ) ( trans & BLIS_CONJ_BIT ); } -static trans_t bli_trans_toggled( trans_t trans ) +BLIS_INLINE trans_t bli_trans_toggled( trans_t trans ) { return ( trans_t ) ( trans ^ BLIS_TRANS_BIT ); } -static trans_t bli_trans_toggled_conj( trans_t trans ) +BLIS_INLINE trans_t bli_trans_toggled_conj( trans_t trans ) { return ( trans_t ) ( trans ^ BLIS_CONJ_BIT ); } -static void bli_toggle_trans( trans_t* trans ) +BLIS_INLINE trans_t bli_apply_trans( trans_t transapp, trans_t trans ) +{ + return ( trans_t ) + ( trans ^ transapp ); +} + +BLIS_INLINE void bli_toggle_trans( trans_t* trans ) { *trans = bli_trans_toggled( *trans ); } @@ -245,24 +275,24 @@ static void bli_toggle_trans( trans_t* trans ) // side -static bool_t bli_is_left( side_t side ) +BLIS_INLINE bool bli_is_left( side_t side ) { - return ( bool_t ) + return ( bool ) ( side == BLIS_LEFT ); } -static bool_t bli_is_right( side_t side ) +BLIS_INLINE bool bli_is_right( side_t side ) { - return ( bool_t ) + return ( bool ) ( side == BLIS_RIGHT ); } -static side_t bli_side_toggled( side_t side ) +BLIS_INLINE side_t bli_side_toggled( side_t side ) { return ( bli_is_left( side ) ? BLIS_RIGHT : BLIS_LEFT ); } -static void bli_toggle_side( side_t* side ) +BLIS_INLINE void bli_toggle_side( side_t* side ) { *side = bli_side_toggled( *side ); } @@ -270,46 +300,47 @@ static void bli_toggle_side( side_t* side ) // uplo -static bool_t bli_is_lower( uplo_t uplo ) +BLIS_INLINE bool bli_is_lower( uplo_t uplo ) { - return ( bool_t ) + return ( bool ) ( uplo == BLIS_LOWER ); } -static bool_t bli_is_upper( uplo_t uplo ) +BLIS_INLINE bool bli_is_upper( uplo_t uplo ) { - return ( bool_t ) + return ( bool ) ( uplo == BLIS_UPPER ); } -static bool_t bli_is_upper_or_lower( uplo_t uplo ) +BLIS_INLINE bool bli_is_upper_or_lower( uplo_t uplo ) { - return ( bool_t ) + return ( bool ) ( bli_is_upper( uplo ) || - bli_is_lower( uplo ) ); + bli_is_lower( uplo ) ); } -static bool_t bli_is_dense( uplo_t uplo ) +BLIS_INLINE bool bli_is_dense( uplo_t uplo ) { - return ( bool_t ) + return ( bool ) ( uplo == BLIS_DENSE ); } -static bool_t bli_is_zeros( uplo_t uplo ) +BLIS_INLINE bool bli_is_zeros( uplo_t uplo ) { - return ( bool_t ) + return ( bool ) ( uplo == BLIS_ZEROS ); } -static uplo_t bli_uplo_toggled( uplo_t uplo ) +BLIS_INLINE uplo_t bli_uplo_toggled( uplo_t uplo ) { return ( uplo_t ) - ( bli_is_upper_or_lower( uplo ) ? - ( ( uplo ^ BLIS_LOWER_BIT ) ^ BLIS_UPPER_BIT ) : uplo + ( bli_is_upper_or_lower( uplo ) + ? ( ( uplo ^ BLIS_LOWER_BIT ) ^ BLIS_UPPER_BIT ) + : uplo ); } -static void bli_toggle_uplo( uplo_t* uplo ) +BLIS_INLINE void bli_toggle_uplo( uplo_t* uplo ) { *uplo = bli_uplo_toggled( *uplo ); } @@ -317,65 +348,65 @@ static void bli_toggle_uplo( uplo_t* uplo ) // structure -static bool_t bli_is_general( struc_t struc ) +BLIS_INLINE bool bli_is_general( struc_t struc ) { - return ( bool_t ) + return ( bool ) ( struc == BLIS_GENERAL ); } -static bool_t bli_is_hermitian( struc_t struc ) +BLIS_INLINE bool bli_is_hermitian( struc_t struc ) { - return ( bool_t ) + return ( bool ) ( struc == BLIS_HERMITIAN ); } -static bool_t bli_is_symmetric( struc_t struc ) +BLIS_INLINE bool bli_is_symmetric( struc_t struc ) { - return ( bool_t ) + return ( bool ) ( struc == BLIS_SYMMETRIC ); } -static bool_t bli_is_triangular( struc_t struc ) +BLIS_INLINE bool bli_is_triangular( struc_t struc ) { - return ( bool_t ) + return ( bool ) ( struc == BLIS_TRIANGULAR ); } -static bool_t bli_is_herm_or_symm( struc_t struc ) +BLIS_INLINE bool bli_is_herm_or_symm( struc_t struc ) { - return ( bool_t ) + return ( bool ) ( bli_is_hermitian( struc ) || - bli_is_symmetric( struc ) ); + bli_is_symmetric( struc ) ); } // conj -static bool_t bli_is_noconj( conj_t conj ) +BLIS_INLINE bool bli_is_noconj( conj_t conj ) { - return ( bool_t ) + return ( bool ) ( conj == BLIS_NO_CONJUGATE ); } -static bool_t bli_is_conj( conj_t conj ) +BLIS_INLINE bool bli_is_conj( conj_t conj ) { - return ( bool_t ) + return ( bool ) ( conj == BLIS_CONJUGATE ); } -static conj_t bli_conj_toggled( conj_t conj ) +BLIS_INLINE conj_t bli_conj_toggled( conj_t conj ) { return ( conj_t ) ( conj ^ BLIS_CONJ_BIT ); } -static conj_t bli_apply_conj( conj_t conjapp, conj_t conj ) +BLIS_INLINE conj_t bli_apply_conj( conj_t conjapp, conj_t conj ) { return ( conj_t ) ( conj ^ conjapp ); } -static void bli_toggle_conj( conj_t* conj ) +BLIS_INLINE void bli_toggle_conj( conj_t* conj ) { *conj = bli_conj_toggled( *conj ); } @@ -383,82 +414,97 @@ static void bli_toggle_conj( conj_t* conj ) // diag -static bool_t bli_is_nonunit_diag( diag_t diag ) +BLIS_INLINE bool bli_is_nonunit_diag( diag_t diag ) { - return ( bool_t ) + return ( bool ) ( diag == BLIS_NONUNIT_DIAG ); } -static bool_t bli_is_unit_diag( diag_t diag ) +BLIS_INLINE bool bli_is_unit_diag( diag_t diag ) { - return ( bool_t ) + return ( bool ) ( diag == BLIS_UNIT_DIAG ); } +// err_t-related + +BLIS_INLINE bool bli_is_success( err_t err ) +{ + return ( bool ) + ( err == BLIS_SUCCESS ); +} + +BLIS_INLINE bool bli_is_failure( err_t err ) +{ + return ( bool ) + ( err != BLIS_SUCCESS ); +} + + // dimension-related -static bool_t bli_zero_dim1( dim_t m ) +BLIS_INLINE bool bli_zero_dim1( dim_t m ) { - return ( bool_t ) + return ( bool ) ( m == 0 ); } -static bool_t bli_zero_dim2( dim_t m, dim_t n ) +BLIS_INLINE bool bli_zero_dim2( dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( m == 0 || n == 0 ); } -static bool_t bli_zero_dim3( dim_t m, dim_t n, dim_t k ) +BLIS_INLINE bool bli_zero_dim3( dim_t m, dim_t n, dim_t k ) { - return ( bool_t ) + return ( bool ) ( m == 0 || n == 0 || k == 0 ); } -static bool_t bli_nonzero_dim( dim_t m ) +BLIS_INLINE bool bli_nonzero_dim( dim_t m ) { - return ( bool_t ) + return ( bool ) ( m > 0 ); } -static bool_t bli_vector_dim( dim_t m, dim_t n ) +BLIS_INLINE bool bli_vector_dim( dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( m == 1 ? n : m ); } -static bool_t bli_is_vector( dim_t m, dim_t n ) +BLIS_INLINE bool bli_is_vector( dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( m == 1 || n == 1 ); } -static bool_t bli_is_row_vector( dim_t m, dim_t n ) +BLIS_INLINE bool bli_is_row_vector( dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( m == 1 ); } -static bool_t bli_is_col_vector( dim_t m, dim_t n ) +BLIS_INLINE bool bli_is_col_vector( dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( n == 1 ); } -static void bli_set_dim_with_side( side_t side, dim_t m, dim_t n, dim_t* dim ) +BLIS_INLINE void bli_set_dim_with_side( side_t side, dim_t m, dim_t n, dim_t* dim ) { if ( bli_is_left( side ) ) *dim = m; else *dim = n; } -static void bli_set_dims_with_trans( trans_t trans, dim_t m, dim_t n, dim_t* mt, dim_t* nt ) +BLIS_INLINE void bli_set_dims_with_trans( trans_t trans, dim_t m, dim_t n, dim_t* mt, dim_t* nt ) { if ( bli_does_notrans( trans ) ) { *mt = m; *nt = n; } else { *mt = n; *nt = m; } } -static void bli_set_dims_incs_with_trans( trans_t trans, +BLIS_INLINE void bli_set_dims_incs_with_trans( trans_t trans, dim_t m, dim_t n, inc_t rs, inc_t cs, dim_t* mt, dim_t* nt, inc_t* rst, inc_t* cst ) { @@ -469,193 +515,193 @@ static void bli_set_dims_incs_with_trans( trans_t trans, // blocksize-related -static dim_t bli_determine_blocksize_dim_f( dim_t i, dim_t dim, dim_t b_alg ) +BLIS_INLINE dim_t bli_determine_blocksize_dim_f( dim_t i, dim_t dim, dim_t b_alg ) { return ( dim_t ) ( bli_min( b_alg, dim - i ) ); } -static dim_t bli_determine_blocksize_dim_b( dim_t i, dim_t dim, dim_t b_alg ) +BLIS_INLINE dim_t bli_determine_blocksize_dim_b( dim_t i, dim_t dim, dim_t b_alg ) { return ( dim_t ) ( i == 0 && dim % b_alg != 0 ? dim % b_alg - : b_alg ); + : b_alg ); } // stride-related -static inc_t bli_vector_inc( trans_t trans, dim_t m, dim_t n, inc_t rs, inc_t cs ) +BLIS_INLINE inc_t bli_vector_inc( trans_t trans, dim_t m, dim_t n, inc_t rs, inc_t cs ) { return ( inc_t ) ( bli_does_notrans( trans ) ? ( m == 1 ? cs : rs ) : ( m == 1 ? rs : cs ) ); } -static bool_t bli_is_row_stored( inc_t rs, inc_t cs ) +BLIS_INLINE bool bli_is_row_stored( inc_t rs, inc_t cs ) { - return ( bool_t ) + return ( bool ) ( bli_abs( cs ) == 1 ); } -static bool_t bli_is_col_stored( inc_t rs, inc_t cs ) +BLIS_INLINE bool bli_is_col_stored( inc_t rs, inc_t cs ) { - return ( bool_t ) + return ( bool ) ( bli_abs( rs ) == 1 ); } -static bool_t bli_is_row_stored_f( dim_t m, dim_t n, inc_t rs, inc_t cs ) +BLIS_INLINE bool bli_is_row_stored_f( dim_t m, dim_t n, inc_t rs, inc_t cs ) { - return ( bool_t ) + return ( bool ) ( cs == 1 && ( rs > 1 || n == 1 ) ); } -static bool_t bli_is_col_stored_f( dim_t m, dim_t n, inc_t rs, inc_t cs ) +BLIS_INLINE bool bli_is_col_stored_f( dim_t m, dim_t n, inc_t rs, inc_t cs ) { - return ( bool_t ) + return ( bool ) ( rs == 1 && ( cs > 1 || m == 1 ) ); } -static bool_t bli_is_gen_stored( inc_t rs, inc_t cs ) +BLIS_INLINE bool bli_is_gen_stored( inc_t rs, inc_t cs ) { - return ( bool_t ) + return ( bool ) ( bli_abs( rs ) != 1 && - bli_abs( cs ) != 1 ); + bli_abs( cs ) != 1 ); } -static bool_t bli_is_row_tilted( dim_t m, dim_t n, inc_t rs, inc_t cs ) +BLIS_INLINE bool bli_is_row_tilted( dim_t m, dim_t n, inc_t rs, inc_t cs ) { - return ( bool_t ) + return ( bool ) ( bli_abs( cs ) == bli_abs( rs ) - ? n < m - : bli_abs( cs ) < bli_abs( rs ) ); + ? n < m + : bli_abs( cs ) < bli_abs( rs ) ); } -static bool_t bli_is_col_tilted( dim_t m, dim_t n, inc_t rs, inc_t cs ) +BLIS_INLINE bool bli_is_col_tilted( dim_t m, dim_t n, inc_t rs, inc_t cs ) { - return ( bool_t ) + return ( bool ) ( bli_abs( rs ) == bli_abs( cs ) - ? m < n - : bli_abs( rs ) < bli_abs( cs ) ); + ? m < n + : bli_abs( rs ) < bli_abs( cs ) ); } -static bool_t bli_has_nonunit_inc1( inc_t s1 ) +BLIS_INLINE bool bli_has_nonunit_inc1( inc_t s1 ) { - return ( bool_t ) + return ( bool ) ( s1 != 1 ); } -static bool_t bli_has_nonunit_inc2( inc_t s1, inc_t s2 ) +BLIS_INLINE bool bli_has_nonunit_inc2( inc_t s1, inc_t s2 ) { - return ( bool_t ) + return ( bool ) ( s1 != 1 || s2 != 1 ); } -static bool_t bli_has_nonunit_inc3( inc_t s1, inc_t s2, inc_t s3 ) +BLIS_INLINE bool bli_has_nonunit_inc3( inc_t s1, inc_t s2, inc_t s3 ) { - return ( bool_t ) + return ( bool ) ( s1 != 1 || s2 != 1 || s3 != 1 ); } // diag offset-related -static void bli_negate_diag_offset( doff_t* diagoff ) +BLIS_INLINE void bli_negate_diag_offset( doff_t* diagoff ) { *diagoff = -(*diagoff); } -static void bli_shift_diag_offset_to_grow_uplo( uplo_t uplo, doff_t* diagoff ) +BLIS_INLINE void bli_shift_diag_offset_to_grow_uplo( uplo_t uplo, doff_t* diagoff ) { if ( bli_is_upper( uplo ) ) *diagoff -= 1; else if ( bli_is_lower( uplo ) ) *diagoff += 1; } -static void bli_shift_diag_offset_to_shrink_uplo( uplo_t uplo, doff_t* diagoff ) +BLIS_INLINE void bli_shift_diag_offset_to_shrink_uplo( uplo_t uplo, doff_t* diagoff ) { if ( bli_is_upper( uplo ) ) *diagoff += 1; else if ( bli_is_lower( uplo ) ) *diagoff -= 1; } -static bool_t bli_diag_offset_with_trans( trans_t trans, doff_t diagoff ) +BLIS_INLINE doff_t bli_diag_offset_with_trans( trans_t trans, doff_t diagoff ) { - return ( bool_t ) + return ( doff_t ) ( bli_does_trans( trans ) ? -diagoff : diagoff ); } -static bool_t bli_is_strictly_above_diag( doff_t diagoff, trans_t trans, dim_t m, dim_t n ) +BLIS_INLINE bool bli_is_strictly_above_diag( doff_t diagoff, trans_t trans, dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( bli_does_trans( trans ) - ? ( ( doff_t )n <= -diagoff ) - : ( ( doff_t )m <= -diagoff ) ); + ? ( ( doff_t )n <= -diagoff ) + : ( ( doff_t )m <= -diagoff ) ); } -static bool_t bli_is_strictly_below_diag( doff_t diagoff, trans_t trans, dim_t m, dim_t n ) +BLIS_INLINE bool bli_is_strictly_below_diag( doff_t diagoff, trans_t trans, dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( bli_does_trans( trans ) - ? ( ( doff_t )m <= diagoff ) - : ( ( doff_t )n <= diagoff ) ); + ? ( ( doff_t )m <= diagoff ) + : ( ( doff_t )n <= diagoff ) ); } -static bool_t bli_is_outside_diag( doff_t diagoff, trans_t trans, dim_t m, dim_t n ) +BLIS_INLINE bool bli_is_outside_diag( doff_t diagoff, trans_t trans, dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( bli_is_strictly_above_diag( diagoff, trans, m, n ) || bli_is_strictly_below_diag( diagoff, trans, m, n ) ); } -static bool_t bli_is_stored_subpart( doff_t diagoff, trans_t trans, uplo_t uplo, dim_t m, dim_t n ) +BLIS_INLINE bool bli_is_stored_subpart( doff_t diagoff, trans_t trans, uplo_t uplo, dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( ( bli_is_upper( uplo ) && bli_is_strictly_above_diag( diagoff, trans, m, n ) ) || ( bli_is_lower( uplo ) && bli_is_strictly_below_diag( diagoff, trans, m, n ) ) ); } -static bool_t bli_is_unstored_subpart( doff_t diagoff, trans_t trans, uplo_t uplo, dim_t m, dim_t n ) +BLIS_INLINE bool bli_is_unstored_subpart( doff_t diagoff, trans_t trans, uplo_t uplo, dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( ( bli_is_upper( uplo ) && bli_is_strictly_below_diag( diagoff, trans, m, n ) ) || ( bli_is_lower( uplo ) && bli_is_strictly_above_diag( diagoff, trans, m, n ) ) ); } -static bool_t bli_is_strictly_above_diag_n( doff_t diagoff, dim_t m, dim_t n ) +BLIS_INLINE bool bli_is_strictly_above_diag_n( doff_t diagoff, dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( ( doff_t )m <= -diagoff ); } -static bool_t bli_is_strictly_below_diag_n( doff_t diagoff, dim_t m, dim_t n ) +BLIS_INLINE bool bli_is_strictly_below_diag_n( doff_t diagoff, dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( ( doff_t )n <= diagoff ); } -static bool_t bli_intersects_diag_n( doff_t diagoff, dim_t m, dim_t n ) +BLIS_INLINE bool bli_intersects_diag_n( doff_t diagoff, dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( !bli_is_strictly_above_diag_n( diagoff, m, n ) && !bli_is_strictly_below_diag_n( diagoff, m, n ) ); } -static bool_t bli_is_outside_diag_n( doff_t diagoff, dim_t m, dim_t n ) +BLIS_INLINE bool bli_is_outside_diag_n( doff_t diagoff, dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( bli_is_strictly_above_diag_n( diagoff, m, n ) || bli_is_strictly_below_diag_n( diagoff, m, n ) ); } -static bool_t bli_is_stored_subpart_n( doff_t diagoff, uplo_t uplo, dim_t m, dim_t n ) +BLIS_INLINE bool bli_is_stored_subpart_n( doff_t diagoff, uplo_t uplo, dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( ( bli_is_upper( uplo ) && bli_is_strictly_above_diag_n( diagoff, m, n ) ) || ( bli_is_lower( uplo ) && bli_is_strictly_below_diag_n( diagoff, m, n ) ) ); } -static bool_t bli_is_unstored_subpart_n( doff_t diagoff, uplo_t uplo, dim_t m, dim_t n ) +BLIS_INLINE bool bli_is_unstored_subpart_n( doff_t diagoff, uplo_t uplo, dim_t m, dim_t n ) { - return ( bool_t ) + return ( bool ) ( ( bli_is_upper( uplo ) && bli_is_strictly_below_diag_n( diagoff, m, n ) ) || ( bli_is_lower( uplo ) && bli_is_strictly_above_diag_n( diagoff, m, n ) ) ); } @@ -663,7 +709,7 @@ static bool_t bli_is_unstored_subpart_n( doff_t diagoff, uplo_t uplo, dim_t m, d // pruning-related -static void bli_prune_unstored_region_top_l( doff_t* diagoff, dim_t* m, dim_t* n, dim_t* offm_inc ) +BLIS_INLINE void bli_prune_unstored_region_top_l( doff_t* diagoff, dim_t* m, dim_t* n, dim_t* offm_inc ) { *offm_inc = 0; @@ -677,7 +723,7 @@ static void bli_prune_unstored_region_top_l( doff_t* diagoff, dim_t* m, dim_t* n } } -static void bli_prune_unstored_region_right_l( doff_t* diagoff, dim_t* m, dim_t* n, dim_t* offn_inc ) +BLIS_INLINE void bli_prune_unstored_region_right_l( doff_t* diagoff, dim_t* m, dim_t* n, dim_t* offn_inc ) { *offn_inc = 0; @@ -689,7 +735,7 @@ static void bli_prune_unstored_region_right_l( doff_t* diagoff, dim_t* m, dim_t* } } -static void bli_prune_unstored_region_left_u( doff_t* diagoff, dim_t* m, dim_t* n, dim_t* offn_inc ) +BLIS_INLINE void bli_prune_unstored_region_left_u( doff_t* diagoff, dim_t* m, dim_t* n, dim_t* offn_inc ) { *offn_inc = 0; @@ -703,7 +749,7 @@ static void bli_prune_unstored_region_left_u( doff_t* diagoff, dim_t* m, dim_t* } } -static void bli_prune_unstored_region_bottom_u( doff_t* diagoff, dim_t* m, dim_t* n, dim_t* offm_inc ) +BLIS_INLINE void bli_prune_unstored_region_bottom_u( doff_t* diagoff, dim_t* m, dim_t* n, dim_t* offm_inc ) { *offm_inc = 0; @@ -718,20 +764,20 @@ static void bli_prune_unstored_region_bottom_u( doff_t* diagoff, dim_t* m, dim_t // thread range-related -static void bli_rotate180_trapezoid( doff_t* diagoff, uplo_t* uplo, dim_t* m, dim_t* n ) +BLIS_INLINE void bli_rotate180_trapezoid( doff_t* diagoff, uplo_t* uplo, dim_t* m, dim_t* n ) { *diagoff = *n - *diagoff - *m; bli_toggle_uplo( uplo ); } -static void bli_reflect_about_diag( doff_t* diagoff, uplo_t* uplo, dim_t* m, dim_t* n ) +BLIS_INLINE void bli_reflect_about_diag( doff_t* diagoff, uplo_t* uplo, dim_t* m, dim_t* n ) { bli_swap_dims( m, n ); bli_negate_diag_offset( diagoff ); bli_toggle_uplo( uplo ); } -static void bli_reverse_index_direction( dim_t n, dim_t* start, dim_t* end ) +BLIS_INLINE void bli_reverse_index_direction( dim_t n, dim_t* start, dim_t* end ) { dim_t start2 = n - *start; dim_t end2 = n - *end; @@ -742,69 +788,161 @@ static void bli_reverse_index_direction( dim_t n, dim_t* start, dim_t* end ) // mdim_t-related -static bool_t bli_is_m_dim( mdim_t mdim ) +BLIS_INLINE bool bli_is_m_dim( mdim_t mdim ) { - return ( bool_t ) + return ( bool ) ( mdim == BLIS_M ); } -static bool_t bli_is_n_dim( mdim_t mdim ) +BLIS_INLINE bool bli_is_n_dim( mdim_t mdim ) { - return ( bool_t ) + return ( bool ) ( mdim == BLIS_N ); } -static mdim_t bli_dim_toggled( mdim_t mdim ) +BLIS_INLINE mdim_t bli_dim_toggled( mdim_t mdim ) { - return ( mdim == BLIS_M ? BLIS_N : BLIS_M ); + return ( mdim_t ) + ( mdim == BLIS_M ? BLIS_N : BLIS_M ); } -static void bli_toggle_dim( mdim_t* mdim ) +BLIS_INLINE void bli_toggle_dim( mdim_t* mdim ) { *mdim = bli_dim_toggled( *mdim ); } +// stor3_t-related + +BLIS_INLINE stor3_t bli_stor3_from_strides( inc_t rs_c, inc_t cs_c, + inc_t rs_a, inc_t cs_a, + inc_t rs_b, inc_t cs_b ) +{ + // If any matrix is general-stored, return the stor3_t id for the + // general-purpose sup microkernel. + if ( bli_is_gen_stored( rs_c, cs_c ) || + bli_is_gen_stored( rs_a, cs_a ) || + bli_is_gen_stored( rs_b, cs_b ) ) return BLIS_XXX; + + // Otherwise, compute and return the stor3_t id as follows. + const bool c_is_col = bli_is_col_stored( rs_c, cs_c ); + const bool a_is_col = bli_is_col_stored( rs_a, cs_a ); + const bool b_is_col = bli_is_col_stored( rs_b, cs_b ); + + return ( stor3_t )( 4 * c_is_col + + 2 * a_is_col + + 1 * b_is_col ); +} + +BLIS_INLINE stor3_t bli_stor3_trans( stor3_t id ) +{ +#if 1 + stor3_t map[ BLIS_NUM_3OP_RC_COMBOS ] + = + { + ( stor3_t )7, // BLIS_RRR = 0 -> BLIS_CCC = 7 + ( stor3_t )5, // BLIS_RRC = 1 -> BLIS_CRC = 5 + ( stor3_t )6, // BLIS_RCR = 2 -> BLIS_CCR = 6 + ( stor3_t )4, // BLIS_RCC = 3 -> BLIS_CRR = 4 + ( stor3_t )3, // BLIS_CRR = 4 -> BLIS_RCC = 3 + ( stor3_t )1, // BLIS_CRC = 5 -> BLIS_RRC = 1 + ( stor3_t )2, // BLIS_CCR = 6 -> BLIS_RCR = 2 + ( stor3_t )0, // BLIS_CCC = 7 -> BLIS_RRR = 0 + }; + + return map[id]; +#else + return ( ( id & 0x4 ) ^ 0x4 ) | // flip c bit + ( ( ( id & 0x1 ) ^ 0x1 ) << 1 ) | // flip b bit and move to a position + ( ( ( id & 0x2 ) ^ 0x2 ) >> 1 ); // flip a bit and move to b position +#endif +} + +BLIS_INLINE stor3_t bli_stor3_transa( stor3_t id ) +{ +#if 0 + stor3_t map[ BLIS_NUM_3OP_RC_COMBOS ] + = + { + ( stor3_t )1, // BLIS_RRR = 0 -> BLIS_RRC = 1 + ( stor3_t )0, // BLIS_RRC = 1 -> BLIS_RRR = 0 + ( stor3_t )3, // BLIS_RCR = 2 -> BLIS_RCC = 3 + ( stor3_t )2, // BLIS_RCC = 3 -> BLIS_RCR = 2 + ( stor3_t )5, // BLIS_CRR = 4 -> BLIS_CRC = 5 + ( stor3_t )4, // BLIS_CRC = 5 -> BLIS_CRR = 4 + ( stor3_t )7, // BLIS_CCR = 6 -> BLIS_CCC = 7 + ( stor3_t )6, // BLIS_CCC = 7 -> BLIS_CCR = 6 + }; + + return map[id]; +#else + return ( stor3_t )( id ^ 0x1 ); +#endif +} + +BLIS_INLINE stor3_t bli_stor3_transb( stor3_t id ) +{ +#if 0 + stor3_t map[ BLIS_NUM_3OP_RC_COMBOS ] + = + { + ( stor3_t )2, // BLIS_RRR = 0 -> BLIS_RCR = 2 + ( stor3_t )3, // BLIS_RRC = 1 -> BLIS_RCC = 3 + ( stor3_t )0, // BLIS_RCR = 2 -> BLIS_RRR = 0 + ( stor3_t )1, // BLIS_RCC = 3 -> BLIS_RRC = 1 + ( stor3_t )6, // BLIS_CRR = 4 -> BLIS_CCR = 6 + ( stor3_t )7, // BLIS_CRC = 5 -> BLIS_CCC = 7 + ( stor3_t )4, // BLIS_CCR = 6 -> BLIS_CRR = 4 + ( stor3_t )5, // BLIS_CCC = 7 -> BLIS_CRC = 5 + }; + + return map[id]; +#else + return ( stor3_t )( id ^ 0x2 ); +#endif +} + + // index-related -static bool_t bli_is_edge_f( dim_t i, dim_t n_iter, dim_t n_left ) +BLIS_INLINE bool bli_is_edge_f( dim_t i, dim_t n_iter, dim_t n_left ) { - return ( bool_t ) + return ( bool ) ( i == n_iter - 1 && n_left != 0 ); } -static bool_t bli_is_not_edge_f( dim_t i, dim_t n_iter, dim_t n_left ) +BLIS_INLINE bool bli_is_not_edge_f( dim_t i, dim_t n_iter, dim_t n_left ) { - return ( bool_t ) + return ( bool ) ( i != n_iter - 1 || n_left == 0 ); } -static bool_t bli_is_edge_b( dim_t i, dim_t n_iter, dim_t n_left ) +BLIS_INLINE bool bli_is_edge_b( dim_t i, dim_t n_iter, dim_t n_left ) { - return ( bool_t ) + return ( bool ) ( i == 0 && n_left != 0 ); } -static bool_t bli_is_not_edge_b( dim_t i, dim_t n_iter, dim_t n_left ) +BLIS_INLINE bool bli_is_not_edge_b( dim_t i, dim_t n_iter, dim_t n_left ) { - return ( bool_t ) + return ( bool ) ( i != 0 || n_left == 0 ); } -static bool_t bli_is_last_iter_sl( dim_t i, dim_t end_iter, dim_t tid, dim_t nth ) +BLIS_INLINE bool bli_is_last_iter_sl( dim_t i, dim_t end_iter, dim_t tid, dim_t nth ) { - return ( bool_t ) + return ( bool ) ( i == end_iter - 1 ); } -static bool_t bli_is_last_iter_rr( dim_t i, dim_t end_iter, dim_t tid, dim_t nth ) +BLIS_INLINE bool bli_is_last_iter_rr( dim_t i, dim_t end_iter, dim_t tid, dim_t nth ) { - return ( bool_t ) + return ( bool ) ( i == end_iter - 1 - ( ( end_iter - tid - 1 ) % nth ) ); } -static bool_t bli_is_last_iter( dim_t i, dim_t end_iter, dim_t tid, dim_t nth ) +BLIS_INLINE bool bli_is_last_iter( dim_t i, dim_t end_iter, dim_t tid, dim_t nth ) { #ifdef BLIS_ENABLE_JRIR_SLAB return bli_is_last_iter_sl( i, end_iter, tid, nth ); @@ -816,7 +954,7 @@ static bool_t bli_is_last_iter( dim_t i, dim_t end_iter, dim_t tid, dim_t nth ) // packbuf_t-related -static guint_t bli_packbuf_index( packbuf_t buf_type ) +BLIS_INLINE guint_t bli_packbuf_index( packbuf_t buf_type ) { return ( guint_t ) ( ( buf_type & BLIS_PACK_BUFFER_BITS ) >> BLIS_PACK_BUFFER_SHIFT ); @@ -824,144 +962,74 @@ static guint_t bli_packbuf_index( packbuf_t buf_type ) // pack_t-related -static bool_t bli_is_packed( pack_t schema ) +BLIS_INLINE bool bli_is_packed( pack_t schema ) { - return ( bool_t ) + return ( bool ) ( schema & BLIS_PACK_BIT ); } -static bool_t bli_is_row_packed( pack_t schema ) +BLIS_INLINE bool bli_is_row_packed( pack_t schema ) { - return ( bool_t ) - ( schema & BLIS_PACK_RC_BIT ) == ( BLIS_BITVAL_PACKED_UNSPEC ^ - BLIS_BITVAL_PACKED_ROWS ); + return ( bool ) + ( ( schema & BLIS_PACK_RC_BIT ) == ( BLIS_BITVAL_PACKED_UNSPEC ^ + BLIS_BITVAL_PACKED_ROWS ) ); } -static bool_t bli_is_col_packed( pack_t schema ) +BLIS_INLINE bool bli_is_col_packed( pack_t schema ) { - return ( bool_t ) - ( schema & BLIS_PACK_RC_BIT ) == ( BLIS_BITVAL_PACKED_UNSPEC ^ - BLIS_BITVAL_PACKED_COLUMNS ); + return ( bool ) + ( ( schema & BLIS_PACK_RC_BIT ) == ( BLIS_BITVAL_PACKED_UNSPEC ^ + BLIS_BITVAL_PACKED_COLUMNS ) ); } -static bool_t bli_is_panel_packed( pack_t schema ) +BLIS_INLINE bool bli_is_panel_packed( pack_t schema ) { - return ( bool_t ) + return ( bool ) ( schema & BLIS_PACK_PANEL_BIT ); } -static bool_t bli_is_4mi_packed( pack_t schema ) -{ - return ( bool_t ) - ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_4MI; -} - -static bool_t bli_is_3mi_packed( pack_t schema ) -{ - return ( bool_t ) - ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_3MI; -} - -static bool_t bli_is_3ms_packed( pack_t schema ) -{ - return ( bool_t ) - ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_3MS; -} - -static bool_t bli_is_ro_packed( pack_t schema ) +BLIS_INLINE bool bli_is_1r_packed( pack_t schema ) { - return ( bool_t ) - ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_RO; + return ( bool ) + ( ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_1R ); } -static bool_t bli_is_io_packed( pack_t schema ) +BLIS_INLINE bool bli_is_1e_packed( pack_t schema ) { - return ( bool_t ) - ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_IO; + return ( bool ) + ( ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_1E ); } -static bool_t bli_is_rpi_packed( pack_t schema ) +BLIS_INLINE bool bli_is_1m_packed( pack_t schema ) { - return ( bool_t ) - ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_RPI; -} - -static bool_t bli_is_rih_packed( pack_t schema ) -{ - return ( bool_t ) - ( bli_is_ro_packed( schema ) || - bli_is_io_packed( schema ) || - bli_is_rpi_packed( schema ) ); -} - -static bool_t bli_is_1r_packed( pack_t schema ) -{ - return ( bool_t ) - ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_1R; -} - -static bool_t bli_is_1e_packed( pack_t schema ) -{ - return ( bool_t ) - ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_1E; -} - -static bool_t bli_is_1m_packed( pack_t schema ) -{ - return ( bool_t ) + return ( bool ) ( bli_is_1r_packed( schema ) || bli_is_1e_packed( schema ) ); } -static bool_t bli_is_nat_packed( pack_t schema ) +BLIS_INLINE bool bli_is_nat_packed( pack_t schema ) { - return ( bool_t ) - ( schema & BLIS_PACK_FORMAT_BITS ) == 0; + return ( bool ) + ( ( schema & BLIS_PACK_FORMAT_BITS ) == 0 ); } -static bool_t bli_is_ind_packed( pack_t schema ) +BLIS_INLINE bool bli_is_ind_packed( pack_t schema ) { - return ( bool_t ) - ( schema & BLIS_PACK_FORMAT_BITS ) != 0; + return ( bool ) + ( ( schema & BLIS_PACK_FORMAT_BITS ) != 0 ); } -static guint_t bli_pack_schema_index( pack_t schema ) +BLIS_INLINE guint_t bli_pack_schema_index( pack_t schema ) { return ( guint_t ) - ( schema & BLIS_PACK_FORMAT_BITS ) >> BLIS_PACK_FORMAT_SHIFT; -} - - - -// pointer-related - -// Increment a pointer by an integer fraction: -// p0 + (num/dem) -// where p0 is a pointer to a datatype of size sizeof_p0. -static void* bli_ptr_inc_by_frac( void* p0, siz_t sizeof_p0, dim_t num, dim_t den ) -{ - return ( void* ) - ( ( char* )p0 + ( ( num * ( dim_t )sizeof_p0 ) / den ) ); -} - -static bool_t bli_is_null( void* p ) -{ - return ( bool_t ) - ( p == NULL ); -} - -static bool_t bli_is_nonnull( void* p ) -{ - return ( bool_t ) - ( p != NULL ); + ( ( schema & BLIS_PACK_FORMAT_BITS ) >> BLIS_PACK_FORMAT_SHIFT ); } // Set dimensions, increments, effective uplo/diagoff, etc for ONE matrix // argument. -static -void bli_set_dims_incs_uplo_1m +BLIS_INLINE void bli_set_dims_incs_uplo_1m ( doff_t diagoffa, diag_t diaga, uplo_t uploa, dim_t m, dim_t n, inc_t rs_a, inc_t cs_a, @@ -1055,8 +1123,7 @@ void bli_set_dims_incs_uplo_1m // Set dimensions, increments, effective uplo/diagoff, etc for ONE matrix // argument (without column-wise stride optimization). -static -void bli_set_dims_incs_uplo_1m_noswap +BLIS_INLINE void bli_set_dims_incs_uplo_1m_noswap ( doff_t diagoffa, diag_t diaga, uplo_t uploa, dim_t m, dim_t n, inc_t rs_a, inc_t cs_a, @@ -1141,8 +1208,7 @@ void bli_set_dims_incs_uplo_1m_noswap // Set dimensions and increments for TWO matrix arguments. -static -void bli_set_dims_incs_2m +BLIS_INLINE void bli_set_dims_incs_2m ( trans_t transa, dim_t m, dim_t n, inc_t rs_a, inc_t cs_a, @@ -1177,8 +1243,7 @@ void bli_set_dims_incs_2m // Set dimensions, increments, effective uplo/diagoff, etc for TWO matrix // arguments. -static -void bli_set_dims_incs_uplo_2m +BLIS_INLINE void bli_set_dims_incs_uplo_2m ( doff_t diagoffa, diag_t diaga, trans_t transa, uplo_t uploa, dim_t m, dim_t n, inc_t rs_a, inc_t cs_a, @@ -1285,8 +1350,7 @@ void bli_set_dims_incs_uplo_2m // Set dimensions, increments, etc for ONE matrix argument when operating // on the diagonal. -static -void bli_set_dims_incs_1d +BLIS_INLINE void bli_set_dims_incs_1d ( doff_t diagoffx, dim_t m, dim_t n, inc_t rs_x, inc_t cs_x, @@ -1309,8 +1373,7 @@ void bli_set_dims_incs_1d // Set dimensions, increments, etc for TWO matrix arguments when operating // on diagonals. -static -void bli_set_dims_incs_2d +BLIS_INLINE void bli_set_dims_incs_2d ( doff_t diagoffx, trans_t transx, dim_t m, dim_t n, inc_t rs_x, inc_t cs_x, diff --git a/frame/include/bli_scalar_macro_defs.h b/frame/include/bli_scalar_macro_defs.h index 23f62efeaa..293c80f910 100644 --- a/frame/include/bli_scalar_macro_defs.h +++ b/frame/include/bli_scalar_macro_defs.h @@ -195,39 +195,15 @@ #include "bli_adds_mxn_uplo.h" #include "bli_set0s_mxn.h" #include "bli_copys_mxn.h" +#include "bli_scal2s_mxn.h" #include "bli_xpbys_mxn.h" #include "bli_xpbys_mxn_uplo.h" +// -- "broadcast B" scalar macros -- -// -- 3m-specific scalar macros -- - -#include "bli_copyri3s.h" -#include "bli_copyjri3s.h" - -#include "bli_scal2ri3s.h" -#include "bli_scal2jri3s.h" - -#include "bli_scal2ri3s_mxn.h" - - -// -- 4mh/3mh-specific scalar macros -- - -// ro -#include "bli_scal2ros.h" -#include "bli_scal2jros.h" - -// io -#include "bli_scal2ios.h" -#include "bli_scal2jios.h" - -// rpi -#include "bli_scal2rpis.h" -#include "bli_scal2jrpis.h" - -#include "bli_scal2rihs_mxn.h" -#include "bli_scal2rihs_mxn_diag.h" -#include "bli_scal2rihs_mxn_uplo.h" -#include "bli_setrihs_mxn_diag.h" +#include "bli_bcastbbs_mxn.h" +#include "bli_scal2bbs_mxn.h" +#include "bli_set0bbs_mxn.h" // -- 1m-specific scalar macros -- diff --git a/frame/include/bli_system.h b/frame/include/bli_system.h index 173bbe1eda..79333017b9 100644 --- a/frame/include/bli_system.h +++ b/frame/include/bli_system.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,6 +36,12 @@ #ifndef BLIS_SYSTEM_H #define BLIS_SYSTEM_H +// NOTE: If not yet defined, we define _POSIX_C_SOURCE to make sure that +// various parts of POSIX are defined and made available. +#ifndef _POSIX_C_SOURCE +#define _POSIX_C_SOURCE 200809L +#endif + #include #include #include @@ -57,34 +63,39 @@ // Determine if we are on a 64-bit or 32-bit architecture. #if defined(_M_X64) || defined(__x86_64) || defined(__aarch64__) || \ - defined(_ARCH_PPC64) + defined(_ARCH_PPC64) || defined(__s390x__) || defined(_LP64) #define BLIS_ARCH_64 #else #define BLIS_ARCH_32 #endif // Determine the target operating system. -#if defined(_WIN32) || defined(__CYGWIN__) - #define BLIS_OS_WINDOWS 1 -#elif defined(__gnu_hurd__) - #define BLIS_OS_GNU 1 -#elif defined(__APPLE__) || defined(__MACH__) - #define BLIS_OS_OSX 1 -#elif defined(__ANDROID__) - #define BLIS_OS_ANDROID 1 -#elif defined(__linux__) - #define BLIS_OS_LINUX 1 -#elif defined(__bgq__) - #define BLIS_OS_BGQ 1 -#elif defined(__bg__) - #define BLIS_OS_BGP 1 -#elif defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \ - defined(__bsdi__) || defined(__DragonFly__) || defined(__FreeBSD_kernel__) - #define BLIS_OS_BSD 1 -#elif defined(EMSCRIPTEN) - #define BLIS_OS_EMSCRIPTEN -#else - #error "Cannot determine operating system" +#if defined(BLIS_ENABLE_SYSTEM) + #if defined(_WIN32) || defined(__CYGWIN__) + #define BLIS_OS_WINDOWS 1 + #elif defined(__gnu_hurd__) + #define BLIS_OS_GNU 1 + #elif defined(__APPLE__) || defined(__MACH__) + #define BLIS_OS_OSX 1 + #elif defined(__ANDROID__) + #define BLIS_OS_ANDROID 1 + #elif defined(__linux__) + #define BLIS_OS_LINUX 1 + #elif defined(__bgq__) + #define BLIS_OS_BGQ 1 + #elif defined(__bg__) + #define BLIS_OS_BGP 1 + #elif defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \ + defined(__bsdi__) || defined(__DragonFly__) || \ + defined(__FreeBSD_kernel__) || defined(__HAIKU__) + #define BLIS_OS_BSD 1 + #elif defined(EMSCRIPTEN) + #define BLIS_OS_EMSCRIPTEN + #else + #error "Cannot determine operating system" + #endif +#else // #if defined(BLIS_DISABLE_SYSTEM) + #define BLIS_OS_NONE #endif // A few changes that may be necessary in Windows environments. @@ -111,14 +122,10 @@ #elif BLIS_OS_OSX #include #else - #include + //#include + #include #endif -// POSIX threads are unconditionally required, regardless of whether -// multithreading is enabled via pthreads or OpenMP (or disabled). -// If pthreads is not available (Windows), then fake it. -//#include "bli_pthread_wrap.h" - #endif diff --git a/frame/include/bli_tapi_ba.h b/frame/include/bli_tapi_ba.h index 26356afe82..0177985d9d 100644 --- a/frame/include/bli_tapi_ba.h +++ b/frame/include/bli_tapi_ba.h @@ -35,7 +35,12 @@ // This file defines macros used to allow the _tapi.c files to produce // typed APIs that omit expert parameters. -// Define the macro to remove the function name suffix (in function +// Define a macro that allows the source code to determine which interface +// (basic or expert) we are compiling. +#undef BLIS_TAPI_BASIC +#define BLIS_TAPI_BASIC + +// Define the macro to omit a suffix from the function names (in function // definitions). #undef EX_SUF #define EX_SUF @@ -45,14 +50,10 @@ #undef BLIS_TAPI_EX_PARAMS #define BLIS_TAPI_EX_PARAMS -// Define the macro to declare local expert variables that are initialized +// Define the macro to add local expert variables that are initialized // to NULL. The "( void )" statements are to prevent unused variable // warnings by the compiler. #undef BLIS_TAPI_EX_DECLS #define BLIS_TAPI_EX_DECLS cntx_t* cntx = NULL; ( void )cntx; \ rntm_t* rntm = NULL; ( void )rntm; -// Define the macro to pass the local expert variables to another function. -//#undef BLIS_TAPI_EX_VARS -//#define BLIS_TAPI_EX_VARS - diff --git a/frame/include/bli_tapi_ex.h b/frame/include/bli_tapi_ex.h index 0e1b09226c..c999b0ae9e 100644 --- a/frame/include/bli_tapi_ex.h +++ b/frame/include/bli_tapi_ex.h @@ -35,8 +35,13 @@ // This file defines macros used to allow the _tapi.c files to produce // typed APIs that contain context parameters. -// Define the macro to add a suffix to the typed API function names -// (in function definitions). +// Define a macro that allows the source code to determine which interface +// (basic or expert) we are compiling. +#undef BLIS_TAPI_EXPERT +#define BLIS_TAPI_EXPERT + +// Define the macro to add a suffix to the function names (in function +// definitions). #undef EX_SUF #define EX_SUF BLIS_TAPI_EX_SUF @@ -50,7 +55,3 @@ #undef BLIS_TAPI_EX_DECLS #define BLIS_TAPI_EX_DECLS -// Define the macro to pass the local expert variables to another function. -//#undef BLIS_TAPI_EX_VARS -//#define BLIS_TAPI_EX_VARS ,cntx, rntm - diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index d2c3cb1895..c66505bde8 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -48,6 +48,7 @@ #elif __STDC_VERSION__ >= 199901L // For C99 (or later), include stdint.h. #include + #include #else // When stdint.h is not available, manually typedef the types we will use. #ifdef _WIN32 @@ -87,20 +88,20 @@ typedef unsigned long int guint_t; // -- Boolean type -- -typedef gint_t bool_t; - - -// -- Boolean values -- +// NOTE: bool_t is no longer used and has been replaced with C99's bool type. +//typedef bool bool_t; +// BLIS uses TRUE and FALSE macro constants as possible boolean values, but we +// define these macros in terms of true and false, respectively, which are +// defined by C99 in stdbool.h. #ifndef TRUE - #define TRUE 1 + #define TRUE true #endif #ifndef FALSE - #define FALSE 0 + #define FALSE false #endif - // -- Special-purpose integers -- // This cpp guard provides a temporary hack to allow libflame @@ -149,7 +150,7 @@ typedef uint32_t objbits_t; // object information bit field // interoperability with BLIS. #ifndef _DEFINED_SCOMPLEX #define _DEFINED_SCOMPLEX - typedef struct + typedef struct scomplex { float real; float imag; @@ -160,7 +161,7 @@ typedef uint32_t objbits_t; // object information bit field // interoperability with BLIS. #ifndef _DEFINED_DCOMPLEX #define _DEFINED_DCOMPLEX - typedef struct + typedef struct dcomplex { double real; double imag; @@ -197,6 +198,19 @@ typedef double f77_double; typedef scomplex f77_scomplex; typedef dcomplex f77_dcomplex; +// -- Misc. function pointer types -- + +// Note: This type should be used in any situation where the address of a +// *function* will be conveyed or stored prior to it being typecast back +// to the correct function type. It does not need to be used when conveying +// or storing the address of *data* (such as an array of float or double). +//typedef void (*void_fp)( void ); +typedef void* void_fp; + +// Typedef function pointer types for malloc() and free() substitutes. +typedef void* (*malloc_ft)( size_t size ); +typedef void (*free_ft) ( void* p ); + // // -- BLIS info bit field offsets ---------------------------------------------- @@ -234,24 +248,10 @@ typedef dcomplex f77_dcomplex; - 1 0000 01: packed by columns - 1 0000 10: packed by row panels - 1 0000 11: packed by column panels - - 1 0001 10: packed by 4m interleaved row panels - - 1 0001 11: packed by 4m interleaved column panels - - 1 0010 10: packed by 3m interleaved row panels - - 1 0010 11: packed by 3m interleaved column panels - - 1 0011 10: packed by 4m separated row panels (not used) - - 1 0011 11: packed by 4m separated column panels (not used) - - 1 0100 10: packed by 3m separated row panels - - 1 0100 11: packed by 3m separated column panels - - 1 0101 10: packed real-only row panels - - 1 0101 11: packed real-only column panels - - 1 0110 10: packed imag-only row panels - - 1 0110 11: packed imag-only column panels - - 1 0111 10: packed real+imag row panels - - 1 0111 11: packed real+imag column panels - - 1 1000 10: packed by 1m expanded row panels - - 1 1000 11: packed by 1m expanded column panels - - 1 1001 10: packed by 1m reordered row panels - - 1 1001 11: packed by 1m reordered column panels + - 1 0001 10: packed by 1m expanded row panels + - 1 0001 11: packed by 1m expanded column panels + - 1 0010 10: packed by 1m reordered row panels + - 1 0010 11: packed by 1m reordered column panels 23 Packed panel order if upper-stored - 0 == forward order if upper - 1 == reverse order if upper @@ -371,7 +371,7 @@ typedef dcomplex f77_dcomplex; #define BLIS_BITVAL_SINGLE_PREC 0x0 #define BLIS_BITVAL_DOUBLE_PREC BLIS_PRECISION_BIT #define BLIS_BITVAL_FLOAT_TYPE 0x0 -#define BLIS_BITVAL_SCOMPLEX_TYPE BLIS_DOMAIN_BIT +#define BLIS_BITVAL_SCOMPLEX_TYPE BLIS_DOMAIN_BIT #define BLIS_BITVAL_DOUBLE_TYPE BLIS_PRECISION_BIT #define BLIS_BITVAL_DCOMPLEX_TYPE ( BLIS_DOMAIN_BIT | BLIS_PRECISION_BIT ) #define BLIS_BITVAL_INT_TYPE 0x04 @@ -381,42 +381,21 @@ typedef dcomplex f77_dcomplex; #define BLIS_BITVAL_NO_CONJ 0x0 #define BLIS_BITVAL_CONJ BLIS_CONJ_BIT #define BLIS_BITVAL_CONJ_TRANS ( BLIS_CONJ_BIT | BLIS_TRANS_BIT ) -#define BLIS_BITVAL_ZEROS 0x0 +#define BLIS_BITVAL_ZEROS 0x0 #define BLIS_BITVAL_UPPER ( BLIS_UPPER_BIT | BLIS_DIAG_BIT ) #define BLIS_BITVAL_LOWER ( BLIS_LOWER_BIT | BLIS_DIAG_BIT ) -#define BLIS_BITVAL_DENSE BLIS_UPLO_BITS +#define BLIS_BITVAL_DENSE BLIS_UPLO_BITS #define BLIS_BITVAL_NONUNIT_DIAG 0x0 #define BLIS_BITVAL_UNIT_DIAG BLIS_UNIT_DIAG_BIT #define BLIS_BITVAL_INVERT_DIAG BLIS_INVERT_DIAG_BIT #define BLIS_BITVAL_NOT_PACKED 0x0 -#define BLIS_BITVAL_4MI ( 0x1 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_3MI ( 0x2 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_4MS ( 0x3 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_3MS ( 0x4 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_RO ( 0x5 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_IO ( 0x6 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_RPI ( 0x7 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_1E ( 0x8 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_1R ( 0x9 << BLIS_PACK_FORMAT_SHIFT ) +#define BLIS_BITVAL_1E ( 0x1 << BLIS_PACK_FORMAT_SHIFT ) +#define BLIS_BITVAL_1R ( 0x2 << BLIS_PACK_FORMAT_SHIFT ) #define BLIS_BITVAL_PACKED_UNSPEC ( BLIS_PACK_BIT ) #define BLIS_BITVAL_PACKED_ROWS ( BLIS_PACK_BIT ) #define BLIS_BITVAL_PACKED_COLUMNS ( BLIS_PACK_BIT | BLIS_PACK_RC_BIT ) #define BLIS_BITVAL_PACKED_ROW_PANELS ( BLIS_PACK_BIT | BLIS_PACK_PANEL_BIT ) #define BLIS_BITVAL_PACKED_COL_PANELS ( BLIS_PACK_BIT | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_4MI ( BLIS_PACK_BIT | BLIS_BITVAL_4MI | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_4MI ( BLIS_PACK_BIT | BLIS_BITVAL_4MI | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_3MI ( BLIS_PACK_BIT | BLIS_BITVAL_3MI | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_3MI ( BLIS_PACK_BIT | BLIS_BITVAL_3MI | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_4MS ( BLIS_PACK_BIT | BLIS_BITVAL_4MS | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_4MS ( BLIS_PACK_BIT | BLIS_BITVAL_4MS | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_3MS ( BLIS_PACK_BIT | BLIS_BITVAL_3MS | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_3MS ( BLIS_PACK_BIT | BLIS_BITVAL_3MS | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_RO ( BLIS_PACK_BIT | BLIS_BITVAL_RO | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_RO ( BLIS_PACK_BIT | BLIS_BITVAL_RO | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_IO ( BLIS_PACK_BIT | BLIS_BITVAL_IO | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_IO ( BLIS_PACK_BIT | BLIS_BITVAL_IO | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_RPI ( BLIS_PACK_BIT | BLIS_BITVAL_RPI | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_RPI ( BLIS_PACK_BIT | BLIS_BITVAL_RPI | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) #define BLIS_BITVAL_PACKED_ROW_PANELS_1E ( BLIS_PACK_BIT | BLIS_BITVAL_1E | BLIS_PACK_PANEL_BIT ) #define BLIS_BITVAL_PACKED_COL_PANELS_1E ( BLIS_PACK_BIT | BLIS_BITVAL_1E | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) #define BLIS_BITVAL_PACKED_ROW_PANELS_1R ( BLIS_PACK_BIT | BLIS_BITVAL_1R | BLIS_PACK_PANEL_BIT ) @@ -528,20 +507,6 @@ typedef enum BLIS_PACKED_COLUMNS = BLIS_BITVAL_PACKED_COLUMNS, BLIS_PACKED_ROW_PANELS = BLIS_BITVAL_PACKED_ROW_PANELS, BLIS_PACKED_COL_PANELS = BLIS_BITVAL_PACKED_COL_PANELS, - BLIS_PACKED_ROW_PANELS_4MI = BLIS_BITVAL_PACKED_ROW_PANELS_4MI, - BLIS_PACKED_COL_PANELS_4MI = BLIS_BITVAL_PACKED_COL_PANELS_4MI, - BLIS_PACKED_ROW_PANELS_3MI = BLIS_BITVAL_PACKED_ROW_PANELS_3MI, - BLIS_PACKED_COL_PANELS_3MI = BLIS_BITVAL_PACKED_COL_PANELS_3MI, - BLIS_PACKED_ROW_PANELS_4MS = BLIS_BITVAL_PACKED_ROW_PANELS_4MS, - BLIS_PACKED_COL_PANELS_4MS = BLIS_BITVAL_PACKED_COL_PANELS_4MS, - BLIS_PACKED_ROW_PANELS_3MS = BLIS_BITVAL_PACKED_ROW_PANELS_3MS, - BLIS_PACKED_COL_PANELS_3MS = BLIS_BITVAL_PACKED_COL_PANELS_3MS, - BLIS_PACKED_ROW_PANELS_RO = BLIS_BITVAL_PACKED_ROW_PANELS_RO, - BLIS_PACKED_COL_PANELS_RO = BLIS_BITVAL_PACKED_COL_PANELS_RO, - BLIS_PACKED_ROW_PANELS_IO = BLIS_BITVAL_PACKED_ROW_PANELS_IO, - BLIS_PACKED_COL_PANELS_IO = BLIS_BITVAL_PACKED_COL_PANELS_IO, - BLIS_PACKED_ROW_PANELS_RPI = BLIS_BITVAL_PACKED_ROW_PANELS_RPI, - BLIS_PACKED_COL_PANELS_RPI = BLIS_BITVAL_PACKED_COL_PANELS_RPI, BLIS_PACKED_ROW_PANELS_1E = BLIS_BITVAL_PACKED_ROW_PANELS_1E, BLIS_PACKED_COL_PANELS_1E = BLIS_BITVAL_PACKED_COL_PANELS_1E, BLIS_PACKED_ROW_PANELS_1R = BLIS_BITVAL_PACKED_ROW_PANELS_1R, @@ -549,10 +514,8 @@ typedef enum } pack_t; // We combine row and column packing into one "type", and we start -// with BLIS_PACKED_ROW_PANELS, _COLUMN_PANELS. We also count the -// schema pair for "4ms" (4m separated), because its bit value has -// been reserved, even though we don't use it. -#define BLIS_NUM_PACK_SCHEMA_TYPES 10 +// with BLIS_PACKED_ROW_PANELS, _COLUMN_PANELS. +#define BLIS_NUM_PACK_SCHEMA_TYPES 3 // -- Pack order type -- @@ -645,12 +608,7 @@ typedef enum typedef enum { - BLIS_3MH = 0, - BLIS_3M1, - BLIS_4MH, - BLIS_4M1B, - BLIS_4M1A, - BLIS_1M, + BLIS_1M = 0, BLIS_NAT, BLIS_IND_FIRST = 0, BLIS_IND_LAST = BLIS_NAT @@ -658,13 +616,8 @@ typedef enum #define BLIS_NUM_IND_METHODS (BLIS_NAT+1) -// These are used in bli_*_oapi.c to construct the ind_t values from +// These are used in bli_l3_*_oapi.c to construct the ind_t values from // the induced method substrings that go into function names. -#define bli_3mh BLIS_3MH -#define bli_3m1 BLIS_3M1 -#define bli_4mh BLIS_4MH -#define bli_4mb BLIS_4M1B -#define bli_4m1 BLIS_4M1A #define bli_1m BLIS_1M #define bli_nat BLIS_NAT @@ -801,6 +754,80 @@ typedef enum #define BLIS_NUM_UKR_IMPL_TYPES 4 +#if 0 +typedef enum +{ + // RV = row-stored, contiguous vector-loading + // RG = row-stored, non-contiguous gather-loading + // CV = column-stored, contiguous vector-loading + // CG = column-stored, non-contiguous gather-loading + + // RD = row-stored, dot-based + // CD = col-stored, dot-based + + // RC = row-stored, column-times-column + // CR = column-stored, row-times-row + + // GX = general-stored generic implementation + + BLIS_GEMMSUP_RV_UKR = 0, + BLIS_GEMMSUP_RG_UKR, + BLIS_GEMMSUP_CV_UKR, + BLIS_GEMMSUP_CG_UKR, + + BLIS_GEMMSUP_RD_UKR, + BLIS_GEMMSUP_CD_UKR, + + BLIS_GEMMSUP_RC_UKR, + BLIS_GEMMSUP_CR_UKR, + + BLIS_GEMMSUP_GX_UKR, +} l3sup_t; + +#define BLIS_NUM_LEVEL3_SUP_UKRS 9 +#endif + + +typedef enum +{ + // 3-operand storage combinations + BLIS_RRR = 0, + BLIS_RRC, // 1 + BLIS_RCR, // 2 + BLIS_RCC, // 3 + BLIS_CRR, // 4 + BLIS_CRC, // 5 + BLIS_CCR, // 6 + BLIS_CCC, // 7 + BLIS_XXX, // 8 + +#if 0 + BLIS_RRG, + BLIS_RCG, + BLIS_RGR, + BLIS_RGC, + BLIS_RGG, + BLIS_CRG, + BLIS_CCG, + BLIS_CGR, + BLIS_CGC, + BLIS_CGG, + BLIS_GRR, + BLIS_GRC, + BLIS_GRG, + BLIS_GCR, + BLIS_GCC, + BLIS_GCG, + BLIS_GGR, + BLIS_GGC, + BLIS_GGG, +#endif +} stor3_t; + +#define BLIS_NUM_3OP_RC_COMBOS 9 +//#define BLIS_NUM_3OP_RCG_COMBOS 27 + + #if 0 typedef enum { @@ -833,6 +860,7 @@ typedef enum // bli_l3_ind.c to index into arrays. // BLIS_GEMM = 0, + BLIS_GEMMT, BLIS_HEMM, BLIS_HERK, BLIS_HER2K, @@ -846,7 +874,7 @@ typedef enum BLIS_NOID } opid_t; -#define BLIS_NUM_LEVEL3_OPS 10 +#define BLIS_NUM_LEVEL3_OPS 11 // -- Blocksize ID type -- @@ -863,8 +891,10 @@ typedef enum BLIS_MC, BLIS_KC, BLIS_NC, + BLIS_M2, // level-2 blocksize in m dimension BLIS_N2, // level-2 blocksize in n dimension + BLIS_AF, // level-1f axpyf fusing factor BLIS_DF, // level-1f dotxf fusing factor BLIS_XF, // level-1f dotxaxpyf fusing factor @@ -875,6 +905,19 @@ typedef enum #define BLIS_NUM_BLKSZS 11 +// -- Threshold ID type -- + +typedef enum +{ + BLIS_MT = 0, // level-3 small/unpacked matrix threshold in m dimension + BLIS_NT, // level-3 small/unpacked matrix threshold in n dimension + BLIS_KT // level-3 small/unpacked matrix threshold in k dimension + +} threshid_t; + +#define BLIS_NUM_THRESH 3 + + // -- Architecture ID type -- // NOTE: This typedef enum must be kept up-to-date with the arch_t @@ -884,8 +927,11 @@ typedef enum typedef enum { + // NOTE: The C language standard guarantees that the first enum value + // starts at 0. + // Intel - BLIS_ARCH_SKX = 0, + BLIS_ARCH_SKX, BLIS_ARCH_KNL, BLIS_ARCH_KNC, BLIS_ARCH_HASWELL, @@ -893,6 +939,8 @@ typedef enum BLIS_ARCH_PENRYN, // AMD + BLIS_ARCH_ZEN3, + BLIS_ARCH_ZEN2, BLIS_ARCH_ZEN, BLIS_ARCH_EXCAVATOR, BLIS_ARCH_STEAMROLLER, @@ -900,6 +948,9 @@ typedef enum BLIS_ARCH_BULLDOZER, // ARM + BLIS_ARCH_ARMSVE, + BLIS_ARCH_A64FX, + BLIS_ARCH_FIRESTORM, BLIS_ARCH_THUNDERX2, BLIS_ARCH_CORTEXA57, BLIS_ARCH_CORTEXA53, @@ -907,26 +958,29 @@ typedef enum BLIS_ARCH_CORTEXA9, // IBM/Power + BLIS_ARCH_POWER10, BLIS_ARCH_POWER9, BLIS_ARCH_POWER7, BLIS_ARCH_BGQ, // Generic architecture/configuration - BLIS_ARCH_GENERIC + BLIS_ARCH_GENERIC, -} arch_t; + // The total number of defined architectures. This must be last in the + // list of enums since its definition assumes that the previous enum + // value (BLIS_ARCH_GENERIC) is given index num_archs-1. + BLIS_NUM_ARCHS -#define BLIS_NUM_ARCHS 20 +} arch_t; // // -- BLIS misc. structure types ----------------------------------------------- // -// These headers must be included here (or earlier) because definitions they -// provide are needed in the pool_t and related structs. +// This header must be included here (or earlier) because definitions it +// provides are needed in the pool_t and related structs. #include "bli_pthread.h" -#include "bli_malloc.h" // -- Pool block type -- @@ -950,6 +1004,7 @@ typedef struct siz_t block_size; siz_t align_size; + siz_t offset_size; malloc_ft malloc_fp; free_ft free_fp; @@ -983,7 +1038,7 @@ typedef struct // -- packing block allocator: Locked set of pools type -- -typedef struct membrk_s +typedef struct pba_s { pool_t pools[3]; bli_pthread_mutex_t mutex; @@ -993,7 +1048,7 @@ typedef struct membrk_s malloc_ft malloc_fp; free_ft free_fp; -} membrk_t; +} pba_t; // -- Memory object type -- @@ -1014,7 +1069,7 @@ struct cntl_s // Basic fields (usually required). opid_t family; bszid_t bszid; - void* var_func; + void_fp var_func; struct cntl_s* sub_prenode; struct cntl_s* sub_node; @@ -1047,7 +1102,7 @@ typedef struct blksz_s typedef struct func_s { // Kernel function address. - void* ptr[BLIS_NUM_FP_TYPES]; + void_fp ptr[BLIS_NUM_FP_TYPES]; } func_t; @@ -1056,7 +1111,7 @@ typedef struct func_s typedef struct mbool_s { - bool_t v[BLIS_NUM_FP_TYPES]; + bool v[BLIS_NUM_FP_TYPES]; } mbool_t; @@ -1082,9 +1137,20 @@ typedef struct inc_t is_a; inc_t is_b; + // The panel strides of A and B. + // NOTE: These are only used in situations where iteration over the + // micropanels takes place in part within the kernel code (e.g. sup + // millikernels). + inc_t ps_a; + inc_t ps_b; + // The type to convert to on output. //num_t dt_on_output; + // (Virtual) microkernel address and additional parameters. + void_fp ukr; + void* params; + } auxinfo_t; @@ -1107,6 +1173,33 @@ typedef struct constdata_s // -- BLIS object type definitions --------------------------------------------- // +// Forward declarations for function pointer types +struct obj_s; +struct cntx_s; +struct rntm_s; +struct thrinfo_s; + +typedef void (*obj_pack_fn_t) + ( + struct obj_s* a, + struct obj_s* ap, + struct cntx_s* cntx, + struct rntm_s* rntm, + struct cntl_s* cntl, + struct thrinfo_s* thread + ); + +typedef void (*obj_ker_fn_t) + ( + struct obj_s* a, + struct obj_s* b, + struct obj_s* c, + struct cntx_s* cntx, + struct rntm_s* rntm, + struct cntl_s* cntl, + struct thrinfo_s* thread + ); + typedef struct obj_s { // Basic fields @@ -1136,72 +1229,197 @@ typedef struct obj_s // usually MR or NR) dim_t m_panel; // m dimension of a "full" panel dim_t n_panel; // n dimension of a "full" panel + + // User-customizable fields + obj_pack_fn_t pack_fn; + void* pack_params; + obj_ker_fn_t ker_fn; + void* ker_params; + } obj_t; +// Pre-initializors. Things that must be set afterwards: +// - root object pointer +// - info bitfields: dt, target_dt, exec_dt, comp_dt +// - info2 bitfields: scalar_dt +// - elem_size +// - dims, strides +// - buffer +// - internal scalar buffer (must always set imaginary component) + +#define BLIS_OBJECT_INITIALIZER \ +{ \ + .root = NULL, \ +\ + .off = { 0, 0 }, \ + .dim = { 0, 0 }, \ + .diag_off = 0, \ +\ + .info = 0x0 | BLIS_BITVAL_DENSE | \ + BLIS_BITVAL_GENERAL, \ + .info2 = 0x0, \ + .elem_size = sizeof( float ), /* this is changed later. */ \ +\ + .buffer = NULL, \ + .rs = 0, \ + .cs = 0, \ + .is = 1, \ +\ + .scalar = { 0.0, 0.0 }, \ +\ + .m_padded = 0, \ + .n_padded = 0, \ + .ps = 0, \ + .pd = 0, \ + .m_panel = 0, \ + .n_panel = 0, \ +\ + .pack_fn = NULL, \ + .pack_params = NULL, \ + .ker_fn = NULL, \ + .ker_params = NULL \ +} + +#define BLIS_OBJECT_INITIALIZER_1X1 \ +{ \ + .root = NULL, \ +\ + .off = { 0, 0 }, \ + .dim = { 1, 1 }, \ + .diag_off = 0, \ +\ + .info = 0x0 | BLIS_BITVAL_DENSE | \ + BLIS_BITVAL_GENERAL, \ + .info2 = 0x0, \ + .elem_size = sizeof( float ), /* this is changed later. */ \ +\ + .buffer = NULL, \ + .rs = 0, \ + .cs = 0, \ + .is = 1, \ +\ + .scalar = { 0.0, 0.0 }, \ +\ + .m_padded = 0, \ + .n_padded = 0, \ + .ps = 0, \ + .pd = 0, \ + .m_panel = 0, \ + .n_panel = 0, \ +\ + .pack_fn = NULL, \ + .pack_params = NULL, \ + .ker_fn = NULL, \ + .ker_params = NULL \ +} + // Define these macros here since they must be updated if contents of // obj_t changes. -static void bli_obj_init_full_shallow_copy_of( obj_t* a, obj_t* b ) +BLIS_INLINE void bli_obj_init_full_shallow_copy_of( obj_t* a, obj_t* b ) { - b->root = a->root; - - b->off[0] = a->off[0]; - b->off[1] = a->off[1]; - b->dim[0] = a->dim[0]; - b->dim[1] = a->dim[1]; - b->diag_off = a->diag_off; - - b->info = a->info; - b->info2 = a->info2; - b->elem_size = a->elem_size; - - b->buffer = a->buffer; - b->rs = a->rs; - b->cs = a->cs; - b->is = a->is; - - b->scalar = a->scalar; - - //b->pack_mem = a->pack_mem; - b->m_padded = a->m_padded; - b->n_padded = a->n_padded; - b->ps = a->ps; - b->pd = a->pd; - b->m_panel = a->m_panel; - b->n_panel = a->n_panel; + b->root = a->root; + + b->off[0] = a->off[0]; + b->off[1] = a->off[1]; + b->dim[0] = a->dim[0]; + b->dim[1] = a->dim[1]; + b->diag_off = a->diag_off; + + b->info = a->info; + b->info2 = a->info2; + b->elem_size = a->elem_size; + + b->buffer = a->buffer; + b->rs = a->rs; + b->cs = a->cs; + b->is = a->is; + + b->scalar = a->scalar; + + //b->pack_mem = a->pack_mem; + b->m_padded = a->m_padded; + b->n_padded = a->n_padded; + b->ps = a->ps; + b->pd = a->pd; + b->m_panel = a->m_panel; + b->n_panel = a->n_panel; + + b->pack_fn = a->pack_fn; + b->pack_params = a->pack_params; + b->ker_fn = a->ker_fn; + b->ker_params = a->ker_params; } -static void bli_obj_init_subpart_from( obj_t* a, obj_t* b ) +BLIS_INLINE void bli_obj_init_subpart_from( obj_t* a, obj_t* b ) { - b->root = a->root; + b->root = a->root; - b->off[0] = a->off[0]; - b->off[1] = a->off[1]; + b->off[0] = a->off[0]; + b->off[1] = a->off[1]; // Avoid copying m and n since they will be overwritten. - //b->dim[0] = a->dim[0]; - //b->dim[1] = a->dim[1]; - b->diag_off = a->diag_off; + //b->dim[0] = a->dim[0]; + //b->dim[1] = a->dim[1]; + b->diag_off = a->diag_off; - b->info = a->info; - b->info2 = a->info2; - b->elem_size = a->elem_size; + b->info = a->info; + b->info2 = a->info2; + b->elem_size = a->elem_size; - b->buffer = a->buffer; - b->rs = a->rs; - b->cs = a->cs; - b->is = a->is; + b->buffer = a->buffer; + b->rs = a->rs; + b->cs = a->cs; + b->is = a->is; - b->scalar = a->scalar; + b->scalar = a->scalar; // Avoid copying pack_mem entry. // FGVZ: You should probably make sure this is right. - //b->pack_mem = a->pack_mem; - b->m_padded = a->m_padded; - b->n_padded = a->n_padded; - b->ps = a->ps; - b->pd = a->pd; - b->m_panel = a->m_panel; - b->n_panel = a->n_panel; + //b->pack_mem = a->pack_mem; + b->m_padded = a->m_padded; + b->n_padded = a->n_padded; + b->ps = a->ps; + b->pd = a->pd; + b->m_panel = a->m_panel; + b->n_panel = a->n_panel; + + b->pack_fn = a->pack_fn; + b->pack_params = a->pack_params; + b->ker_fn = a->ker_fn; + b->ker_params = a->ker_params; +} + +// Initializors for global scalar constants. +// NOTE: These must remain cpp macros since they are initializor +// expressions, not functions. + +#define bli_obj_init_const( buffer0 ) \ +{ \ + .root = NULL, \ +\ + .off = { 0, 0 }, \ + .dim = { 1, 1 }, \ + .diag_off = 0, \ +\ + .info = 0x0 | BLIS_BITVAL_CONST_TYPE | \ + BLIS_BITVAL_DENSE | \ + BLIS_BITVAL_GENERAL, \ + .info2 = 0x0, \ + .elem_size = sizeof( constdata_t ), \ +\ + .buffer = buffer0, \ + .rs = 1, \ + .cs = 1, \ + .is = 1 \ +} + +#define bli_obj_init_constdata( val ) \ +{ \ + .s = ( float )val, \ + .d = ( double )val, \ + .c = { .real = ( float )val, .imag = 0.0f }, \ + .z = { .real = ( double )val, .imag = 0.0 }, \ + .i = ( gint_t )val, \ } @@ -1216,6 +1434,12 @@ typedef struct cntx_s func_t l3_nat_ukrs[ BLIS_NUM_LEVEL3_UKRS ]; mbool_t l3_nat_ukrs_prefs[ BLIS_NUM_LEVEL3_UKRS ]; + blksz_t l3_sup_thresh[ BLIS_NUM_THRESH ]; + void* l3_sup_handlers[ BLIS_NUM_LEVEL3_OPS ]; + blksz_t l3_sup_blkszs[ BLIS_NUM_BLKSZS ]; + func_t l3_sup_kers[ BLIS_NUM_3OP_RC_COMBOS ]; + mbool_t l3_sup_kers_prefs[ BLIS_NUM_3OP_RC_COMBOS ]; + func_t l1f_kers[ BLIS_NUM_LEVEL1F_KERS ]; func_t l1v_kers[ BLIS_NUM_LEVEL1V_KERS ]; @@ -1223,20 +1447,25 @@ typedef struct cntx_s func_t unpackm_kers[ BLIS_NUM_UNPACKM_KERS ]; ind_t method; - pack_t schema_a_block; - pack_t schema_b_panel; - pack_t schema_c_panel; } cntx_t; // -- Runtime type -- +// NOTE: The order of these fields must be kept consistent with the definition +// of the BLIS_RNTM_INITIALIZER macro in bli_rntm.h. + typedef struct rntm_s { // "External" fields: these may be queried by the end-user. + bool auto_factor; + dim_t num_threads; dim_t thrloop[ BLIS_NUM_LOOPS ]; + bool pack_a; // enable/disable packing of left-hand matrix A. + bool pack_b; // enable/disable packing of right-hand matrix B. + bool l3_sup; // enable/disable small matrix handling in level-3 ops. // "Internal" fields: these should not be exposed to the end-user. @@ -1244,7 +1473,7 @@ typedef struct rntm_s pool_t* sba_pool; // The packing block allocator, which is attached in the l3 thread decorator. - membrk_t* membrk; + pba_t* pba; } rntm_t; @@ -1309,13 +1538,13 @@ typedef enum BLIS_INVALID_COL_STRIDE = ( -51), BLIS_INVALID_DIM_STRIDE_COMBINATION = ( -52), - // Structure-specific errors + // Structure-specific errors BLIS_EXPECTED_GENERAL_OBJECT = ( -60), BLIS_EXPECTED_HERMITIAN_OBJECT = ( -61), BLIS_EXPECTED_SYMMETRIC_OBJECT = ( -62), BLIS_EXPECTED_TRIANGULAR_OBJECT = ( -63), - // Storage-specific errors + // Storage-specific errors BLIS_EXPECTED_UPPER_OR_LOWER_OBJECT = ( -70), // Partitioning-specific errors @@ -1329,7 +1558,7 @@ typedef enum // Packing-specific errors BLIS_PACK_SCHEMA_NOT_SUPPORTED_FOR_UNPACK = (-100), - // Buffer-specific errors + // Buffer-specific errors BLIS_EXPECTED_NONNULL_OBJECT_BUFFER = (-110), // Memory errors @@ -1347,6 +1576,7 @@ typedef enum // Architecture-related errors BLIS_INVALID_ARCH_ID = (-150), + BLIS_UNINITIALIZED_GKS_CNTX = (-151), // Blocksize-related errors BLIS_MC_DEF_NONMULTIPLE_OF_MR = (-160), diff --git a/frame/include/bli_x86_asm_macros.h b/frame/include/bli_x86_asm_macros.h index d329a2c3a2..b470d320d9 100644 --- a/frame/include/bli_x86_asm_macros.h +++ b/frame/include/bli_x86_asm_macros.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -643,6 +644,7 @@ #define XOR(_0, _1) INSTR_(xor, _0, _1) #define ADD(_0, _1) INSTR_(add, _0, _1) #define SUB(_0, _1) INSTR_(sub, _0, _1) +#define IMUL(_0, _1) INSTR_(imul, _0, _1) #define SAL(...) INSTR_(sal, __VA_ARGS__) #define SAR(...) INSTR_(sar, __VA_ARGS__) #define SHLX(_0, _1, _2) INSTR_(shlx, _0, _1, _2) @@ -656,6 +658,7 @@ #define xor(_0, _1) XOR(_0, _1) #define add(_0, _1) ADD(_0, _1) #define sub(_0, _1) SUB(_0, _1) +#define imul(_0, _1) IMUL(_0, _1) #define sal(...) SAL(__VA_ARGS__) #define sar(...) SAR(__VA_ARGS__) #define shlx(_0, _1, _2) SHLX(_0, _1, _2) @@ -780,13 +783,13 @@ #define VPBROADCASTQ(_0, _1) INSTR_(vpbroadcastq, _0, _1) #define VBROADCASTF128(_0, _1) INSTR_(vbroadcastf128, _0, _1) #define VBROADCASTF64X4(_0, _1) INSTR_(vbroadcastf64x4, _0, _1) -#define VGATHERDPS(_0, _1) INSTR_(vgatherdps, _0, _1) +#define VGATHERDPS(...) INSTR_(vgatherdps, __VA_ARGS__) #define VSCATTERDPS(_0, _1) INSTR_(vscatterdps, _0, _1) -#define VGATHERDPD(_0, _1) INSTR_(vgatherdpd, _0, _1) +#define VGATHERDPD(...) INSTR_(vgatherdpd, __VA_ARGS__) #define VSCATTERDPD(_0, _1) INSTR_(vscatterdpd, _0, _1) -#define VGATHERQPS(_0, _1) INSTR_(vgatherqps, _0, _1) +#define VGATHERQPS(...) INSTR_(vgatherqps, __VA_ARGS__) #define VSCATTERQPS(_0, _1) INSTR_(vscatterqps, _0, _1) -#define VGATHERQPD(_0, _1) INSTR_(vgatherqpd, _0, _1) +#define VGATHERQPD(...) INSTR_(vgatherqpd, __VA_ARGS__) #define VSCATTERQPD(_0, _1) INSTR_(vscatterqpd, _0, _1) #define vmovddup(_0, _1) VMOVDDUP(_0, _1) @@ -809,19 +812,41 @@ #define vmovdqa64(_0, _1) VMOVDQA64(_0, _1) #define vbroadcastss(_0, _1) VBROADCASTSS(_0, _1) #define vbroadcastsd(_0, _1) VBROADCASTSD(_0, _1) -#define vpbraodcastd(_0, _1) VPBROADCASTD(_0, _1) +#define vpbroadcastd(_0, _1) VPBROADCASTD(_0, _1) #define vpbroadcastq(_0, _1) VPBROADCASTQ(_0, _1) #define vbroadcastf128(_0, _1) VBROADCASTF128(_0, _1) #define vbroadcastf64x4(_0, _1) VBROADCASTF64X4(_0, _1) -#define vgatherdps(_0, _1) VGATHERDPS(_0, _1) +#define vgatherdps(...) VGATHERDPS(__VA_ARGS__) #define vscatterdps(_0, _1) VSCATTERDPS(_0, _1) -#define vgatherdpd(_0, _1) VGATHERDPD(_0, _1) +#define vgatherdpd(...) VGATHERDPD(__VA_ARGS__) #define vscatterdpd(_0, _1) VSCATTERDPD(_0, _1) -#define vgatherqps(_0, _1) VGATHERQPS(_0, _1) +#define vgatherqps(...) VGATHERQPS(__VA_ARGS__) #define vscatterqps(_0, _1) VSCATTERQPS(_0, _1) -#define vgatherqpd(_0, _1) VGATHERQPD(_0, _1) +#define vgatherqpd(...) VGATHERQPD(__VA_ARGS__) #define vscatterqpd(_0, _1) VSCATTERQPD(_0, _1) +// Vector comparisons + +#define VPCMPEQB(_0, _1, _2) INSTR_(vpcmpeqb, _0, _1, _2) +#define VPCMPEQW(_0, _1, _2) INSTR_(vpcmpeqw, _0, _1, _2) +#define VPCMPEQD(_0, _1, _2) INSTR_(vpcmpeqd, _0, _1, _2) + +#define vpcmpeqb(_0, _1, _2) VPCMPEQB(_0, _1, _2) +#define vpcmpeqw(_0, _1, _2) VPCMPEQW(_0, _1, _2) +#define vpcmpeqd(_0, _1, _2) VPCMPEQD(_0, _1, _2) + +// Vector integer math + +#define VPADDB(_0, _1, _2) INSTR_(vpaddb, _0, _1, _2) +#define VPADDW(_0, _1, _2) INSTR_(vpaddw, _0, _1, _2) +#define VPADDD(_0, _1, _2) INSTR_(vpaddd, _0, _1, _2) +#define VPADDQ(_0, _1, _2) INSTR_(vpaddq, _0, _1, _2) + +#define vpaddb(_0, _1, _2) VPADDB(_0, _1, _2) +#define vpaddw(_0, _1, _2) VPADDW(_0, _1, _2) +#define vpaddd(_0, _1, _2) VPADDD(_0, _1, _2) +#define vpaddq(_0, _1, _2) VPADDQ(_0, _1, _2) + // Vector math #define ADDPS(_0, _1) INSTR_(addps, _0, _1) @@ -830,8 +855,11 @@ #define SUBPD(_0, _1) INSTR_(subpd, _0, _1) #define MULPS(_0, _1) INSTR_(mulps, _0, _1) #define MULPD(_0, _1) INSTR_(mulpd, _0, _1) +#define DIVPS(_0, _1) INSTR_(divps, _0, _1) +#define DIVPD(_0, _1) INSTR_(divpd, _0, _1) #define XORPS(_0, _1) INSTR_(xorps, _0, _1) #define XORPD(_0, _1) INSTR_(xorpd, _0, _1) + #define UCOMISS(_0, _1) INSTR_(ucomiss, _0, _1) #define UCOMISD(_0, _1) INSTR_(ucomisd, _0, _1) #define COMISS(_0, _1) INSTR_(comiss, _0, _1) @@ -843,8 +871,11 @@ #define subpd(_0, _1) SUBPD(_0, _1) #define mulps(_0, _1) MULPS(_0, _1) #define mulpd(_0, _1) MULPD(_0, _1) +#define divps(_0, _1) DIVPS(_0, _1) +#define divpd(_0, _1) DIVPD(_0, _1) #define xorps(_0, _1) XORPS(_0, _1) #define xorpd(_0, _1) XORPD(_0, _1) + #define ucomiss(_0, _1) UCOMISS(_0, _1) #define ucomisd(_0, _1) UCOMISD(_0, _1) #define cmoiss(_0, _1) COMISS(_0, _1) @@ -852,10 +883,10 @@ #define VADDSUBPS(_0, _1, _2) INSTR_(vaddsubps, _0, _1, _2) #define VADDSUBPD(_0, _1, _2) INSTR_(vaddsubpd, _0, _1, _2) -#define VUCOMISS(_0, _1) INSTR_(vucomiss, _0, _1) -#define VUCOMISD(_0, _1) INSTR_(vucomisd, _0, _1) -#define VCOMISS(_0, _1) INSTR_(vcomiss, _0, _1) -#define VCOMISD(_0, _1) INSTR_(vcomisd, _0, _1) +#define VHADDPD(_0, _1, _2) INSTR_(vhaddpd, _0, _1, _2) +#define VHADDPS(_0, _1, _2) INSTR_(vhaddps, _0, _1, _2) +#define VHSUBPD(_0, _1, _2) INSTR_(vhsubpd, _0, _1, _2) +#define VHSUBPS(_0, _1, _2) INSTR_(vhsubps, _0, _1, _2) #define VADDPS(_0, _1, _2) INSTR_(vaddps, _0, _1, _2) #define VADDPD(_0, _1, _2) INSTR_(vaddpd, _0, _1, _2) #define VSUBPS(_0, _1, _2) INSTR_(vsubps, _0, _1, _2) @@ -864,6 +895,10 @@ #define VMULSD(_0, _1, _2) INSTR_(vmulsd, _0, _1, _2) #define VMULPS(_0, _1, _2) INSTR_(vmulps, _0, _1, _2) #define VMULPD(_0, _1, _2) INSTR_(vmulpd, _0, _1, _2) +#define VDIVSS(_0, _1, _2) INSTR_(vdivss, _0, _1, _2) +#define VDIVSD(_0, _1, _2) INSTR_(vdivsd, _0, _1, _2) +#define VDIVPS(_0, _1, _2) INSTR_(vdivps, _0, _1, _2) +#define VDIVPD(_0, _1, _2) INSTR_(vdivpd, _0, _1, _2) #define VPMULLD(_0, _1, _2) INSTR_(vpmulld, _0, _1, _2) #define VPMULLQ(_0, _1, _2) INSTR_(vpmullq, _0, _1, _2) #define VPADDD(_0, _1, _2) INSTR_(vpaddd, _0, _1, _2) @@ -871,6 +906,12 @@ #define VXORPS(_0, _1, _2) INSTR_(vxorps, _0, _1, _2) #define VXORPD(_0, _1, _2) INSTR_(vxorpd, _0, _1, _2) #define VPXORD(_0, _1, _2) INSTR_(vpxord, _0, _1, _2) + +#define VUCOMISS(_0, _1) INSTR_(vucomiss, _0, _1) +#define VUCOMISD(_0, _1) INSTR_(vucomisd, _0, _1) +#define VCOMISS(_0, _1) INSTR_(vcomiss, _0, _1) +#define VCOMISD(_0, _1) INSTR_(vcomisd, _0, _1) + #define VFMADD132SS(_0, _1, _2) INSTR_(vfmadd132ss, _0, _1, _2) #define VFMADD213SS(_0, _1, _2) INSTR_(vfmadd213ss, _0, _1, _2) #define VFMADD231SS(_0, _1, _2) INSTR_(vfmadd231ss, _0, _1, _2) @@ -974,10 +1015,10 @@ #define vaddsubps(_0, _1, _2) VADDSUBPS(_0, _1, _2) #define vaddsubpd(_0, _1, _2) VADDSUBPD(_0, _1, _2) -#define vucomiss(_0, _1) VUCOMISS(_0, _1) -#define vucomisd(_0, _1) VUCOMISD(_0, _1) -#define vcomiss(_0, _1) VCOMISS(_0, _1) -#define vcomisd(_0, _1) VCOMISD(_0, _1) +#define vhaddpd(_0, _1, _2) VHADDPD(_0, _1, _2) +#define vhaddps(_0, _1, _2) VHADDPS(_0, _1, _2) +#define vhsubpd(_0, _1, _2) VHSUBPD(_0, _1, _2) +#define vhsubps(_0, _1, _2) VHSUBPS(_0, _1, _2) #define vaddps(_0, _1, _2) VADDPS(_0, _1, _2) #define vaddpd(_0, _1, _2) VADDPD(_0, _1, _2) #define vsubps(_0, _1, _2) VSUBPS(_0, _1, _2) @@ -986,6 +1027,10 @@ #define vmulps(_0, _1, _2) VMULPS(_0, _1, _2) #define vmulsd(_0, _1, _2) VMULSD(_0, _1, _2) #define vmulpd(_0, _1, _2) VMULPD(_0, _1, _2) +#define vdivss(_0, _1, _2) VDIVSS(_0, _1, _2) +#define vdivps(_0, _1, _2) VDIVPS(_0, _1, _2) +#define vdivsd(_0, _1, _2) VDIVSD(_0, _1, _2) +#define vdivpd(_0, _1, _2) VDIVPD(_0, _1, _2) #define vpmulld(_0, _1, _2) VPMULLD(_0, _1, _2) #define vpmullq(_0, _1, _2) VPMULLQ(_0, _1, _2) #define vpaddd(_0, _1, _2) VPADDD(_0, _1, _2) @@ -993,6 +1038,12 @@ #define vxorps(_0, _1, _2) VXORPS(_0, _1, _2) #define vxorpd(_0, _1, _2) VXORPD(_0, _1, _2) #define vpxord(_0, _1, _2) VPXORD(_0, _1, _2) + +#define vucomiss(_0, _1) VUCOMISS(_0, _1) +#define vucomisd(_0, _1) VUCOMISD(_0, _1) +#define vcomiss(_0, _1) VCOMISS(_0, _1) +#define vcomisd(_0, _1) VCOMISD(_0, _1) + #define vfmadd132ss(_0, _1, _2) VFMADD132SS(_0, _1, _2) #define vfmadd213ss(_0, _1, _2) VFMADD213SS(_0, _1, _2) #define vfmadd231ss(_0, _1, _2) VFMADD231SS(_0, _1, _2) diff --git a/frame/include/bli_xapi_undef.h b/frame/include/bli_xapi_undef.h new file mode 100644 index 0000000000..3d13051e51 --- /dev/null +++ b/frame/include/bli_xapi_undef.h @@ -0,0 +1,57 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// This file un-defines macros used to allow the _oapi.c and _tapi.c files to +// produce object and typed APIs that omit or contain expert parameters. + +// Un-define all macros that allow the source code to determine which interface +// (basic or expert) we are compiling. +#undef BLIS_OAPI_BASIC +#undef BLIS_OAPI_EXPERT +#undef BLIS_TAPI_BASIC +#undef BLIS_TAPI_EXPERT + +// Un-define the macro to omit or add the function name suffix (in function +// definitions). +#undef EX_SUF + +// Un-define the macro to omit or add expert arguments from function signatures +// and prototypes. +#undef BLIS_OAPI_EX_PARAMS +#undef BLIS_TAPI_EX_PARAMS + +// Un-define the macro to omit or add local expert variables. +#undef BLIS_OAPI_EX_DECLS +#undef BLIS_TAPI_EX_DECLS + diff --git a/frame/include/blis.h b/frame/include/blis.h index 02539bea9e..98ebee878d 100644 --- a/frame/include/blis.h +++ b/frame/include/blis.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -48,15 +48,24 @@ extern "C" { // NOTE: PLEASE DON'T CHANGE THE ORDER IN WHICH HEADERS ARE INCLUDED UNLESS // YOU ARE SURE THAT IT DOESN'T BREAK INTER-HEADER MACRO DEPENDENCIES. -// -- System headers -- -// NOTE: This header must be included before bli_config_macro_defs.h. +// -- configure definitions -- + +// NOTE: bli_config.h header must be included before any BLIS header. +// It is bootstrapped by ./configure and does not depend on later +// headers. Moreover, these configuration variables are necessary to change +// some default behaviors (e.g. disable OS-detection in bli_system.h in case +// of --disable-system). +#include "bli_config.h" +// -- System and language-related headers -- + +// NOTE: bli_system.h header must be included before bli_config_macro_defs.h. #include "bli_system.h" +#include "bli_lang_defs.h" -// -- configure definitions -- +// -- configure default definitions -- -#include "bli_config.h" #include "bli_config_macro_defs.h" @@ -88,6 +97,7 @@ extern "C" { #include "bli_l1f_ker_prot.h" #include "bli_l1m_ker_prot.h" #include "bli_l3_ukr_prot.h" +#include "bli_l3_sup_ker_prot.h" #include "bli_arch_config_pre.h" #include "bli_arch_config.h" @@ -98,6 +108,7 @@ extern "C" { // -- Base operation prototypes -- #include "bli_init.h" +#include "bli_malloc.h" #include "bli_const.h" #include "bli_obj.h" #include "bli_obj_scalar.h" @@ -108,7 +119,7 @@ extern "C" { #include "bli_rntm.h" #include "bli_gks.h" #include "bli_ind.h" -#include "bli_membrk.h" +#include "bli_pba.h" #include "bli_pool.h" #include "bli_array.h" #include "bli_apool.h" @@ -128,11 +139,14 @@ extern "C" { #include "bli_getopt.h" #include "bli_opid.h" #include "bli_cntl.h" +#include "bli_env.h" +#include "bli_pack.h" #include "bli_info.h" #include "bli_arch.h" #include "bli_cpuid.h" #include "bli_string.h" -#include "bli_setgetij.h" +#include "bli_setgetijm.h" +#include "bli_setgetijv.h" #include "bli_setri.h" #include "bli_castm.h" @@ -182,6 +196,14 @@ extern "C" { #include "bli_util.h" +// -- addon definitions -- + +// NOTE: These definitions should not be included much earlier since an addon +// may wish to utilize other types and definitions provided by BLIS. + +#include "bli_addon.h" + + // -- sandbox implementation -- #include "bli_sbox.h" diff --git a/frame/include/level0/1m/bli_scal21ms_mxn.h b/frame/include/level0/1m/bli_scal21ms_mxn.h index 4bb60279da..9a824fbd5f 100644 --- a/frame/include/level0/1m/bli_scal21ms_mxn.h +++ b/frame/include/level0/1m/bli_scal21ms_mxn.h @@ -37,7 +37,7 @@ // scal21ms_mxn -static void bli_cscal21ms_mxn +BLIS_INLINE void bli_cscal21ms_mxn ( const pack_t schema, const conj_t conjx, @@ -118,7 +118,7 @@ static void bli_cscal21ms_mxn } } -static void bli_zscal21ms_mxn +BLIS_INLINE void bli_zscal21ms_mxn ( const pack_t schema, const conj_t conjx, diff --git a/frame/include/level0/1m/bli_set1ms_mxn.h b/frame/include/level0/1m/bli_set1ms_mxn.h index a0c85a5921..f7d492c234 100644 --- a/frame/include/level0/1m/bli_set1ms_mxn.h +++ b/frame/include/level0/1m/bli_set1ms_mxn.h @@ -49,7 +49,7 @@ components of packm. */ \ } -static void bli_cset1ms_mxn +BLIS_INLINE void bli_cset1ms_mxn ( const pack_t schema, const dim_t offm, @@ -120,7 +120,7 @@ static void bli_cset1ms_mxn } } -static void bli_zset1ms_mxn +BLIS_INLINE void bli_zset1ms_mxn ( const pack_t schema, const dim_t offm, diff --git a/frame/include/level0/bb/bli_bcastbbs_mxn.h b/frame/include/level0/bb/bli_bcastbbs_mxn.h new file mode 100644 index 0000000000..84ca4fdc13 --- /dev/null +++ b/frame/include/level0/bb/bli_bcastbbs_mxn.h @@ -0,0 +1,74 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_BCASTBBS_MXN_H +#define BLIS_BCASTBBS_MXN_H + +// bcastbbs_mxn + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +BLIS_INLINE void PASTEMAC(ch,opname) \ + ( \ + const dim_t m, \ + const dim_t n, \ + ctype* restrict y, const inc_t incy, const inc_t ldy \ + ) \ +{ \ + /* Assume that the duplication factor is the column stride of y. */ \ + const dim_t d = ldy; \ + const dim_t ds_y = 1; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict yi = y + i*incy; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict yij = yi + j*ldy; \ +\ + for ( dim_t p = 1; p < d; ++p ) \ + { \ + ctype* restrict yijd = yij + p*ds_y; \ +\ + PASTEMAC(ch,copys)( *yij, *yijd ); \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC0( bcastbbs_mxn ) + +#endif diff --git a/frame/include/level0/bb/bli_scal2bbs_mxn.h b/frame/include/level0/bb/bli_scal2bbs_mxn.h new file mode 100644 index 0000000000..9d0325b5e3 --- /dev/null +++ b/frame/include/level0/bb/bli_scal2bbs_mxn.h @@ -0,0 +1,204 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SCAL2BBS_MXN_H +#define BLIS_SCAL2BBS_MXN_H + +// scal2bbs_mxn + +#undef GENTFUNCRO +#define GENTFUNCRO( ctype, ch, opname ) \ +\ +BLIS_INLINE void PASTEMAC(ch,opname) \ + ( \ + const conj_t conjx, \ + const dim_t m, \ + const dim_t n, \ + ctype* restrict alpha, \ + ctype* restrict x, const inc_t incx, const inc_t ldx, \ + ctype* restrict y, const inc_t incy, const inc_t ldy \ + ) \ +{ \ + /* Assume that the duplication factor is the row stride of y. */ \ + const dim_t d = incy; \ + const dim_t ds_y = 1; \ +\ + if ( bli_is_conj( conjx ) ) \ + { \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict xj = x + j*ldx; \ + ctype* restrict yj = y + j*ldy; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict xij = xj + i*incx; \ + ctype* restrict yij = yj + i*incy; \ +\ + PASTEMAC(ch,scal2js)( *alpha, *xij, *yij ); \ +\ + for ( dim_t p = 1; p < d; ++p ) \ + { \ + ctype* restrict yijd = yij + p*ds_y; \ +\ + PASTEMAC(ch,copys)( *yij, *yijd ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_noconj( conjx ) ) */ \ + { \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict xj = x + j*ldx; \ + ctype* restrict yj = y + j*ldy; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict xij = xj + i*incx; \ + ctype* restrict yij = yj + i*incy; \ +\ + PASTEMAC(ch,scal2s)( *alpha, *xij, *yij ); \ +\ + for ( dim_t p = 1; p < d; ++p ) \ + { \ + ctype* restrict yijd = yij + p*ds_y; \ +\ + PASTEMAC(ch,copys)( *yij, *yijd ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNCRO_BASIC0( scal2bbs_mxn ) + + +#undef GENTFUNCCO +#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname ) \ +\ +BLIS_INLINE void PASTEMAC(ch,opname) \ + ( \ + const conj_t conjx, \ + const dim_t m, \ + const dim_t n, \ + ctype* restrict alpha, \ + ctype* restrict x, const inc_t incx, const inc_t ldx, \ + ctype* restrict y, const inc_t incy, const inc_t ldy \ + ) \ +{ \ + /* Assume that the duplication factor is the row stride of y. */ \ + const dim_t d = incy; \ + const dim_t ds_y = 1; \ +\ + const inc_t incx2 = 2 * incx; \ + const inc_t ldx2 = 2 * ldx; \ +\ + const inc_t incy2 = 2 * incy; \ + const inc_t ldy2 = 2 * ldy; \ +\ + ctype_r* restrict alpha_r = ( ctype_r* )alpha; \ + ctype_r* restrict alpha_i = ( ctype_r* )alpha + 1; \ + ctype_r* restrict chi_r = ( ctype_r* )x; \ + ctype_r* restrict chi_i = ( ctype_r* )x + 1; \ + ctype_r* restrict psi_r = ( ctype_r* )y; \ + ctype_r* restrict psi_i = ( ctype_r* )y + 1*d; \ +\ + if ( bli_is_conj( conjx ) ) \ + { \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype_r* restrict chij_r = chi_r + j*ldx2; \ + ctype_r* restrict chij_i = chi_i + j*ldx2; \ + ctype_r* restrict psij_r = psi_r + j*ldy2; \ + ctype_r* restrict psij_i = psi_i + j*ldy2; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype_r* restrict chiij_r = chij_r + i*incx2; \ + ctype_r* restrict chiij_i = chij_i + i*incx2; \ + ctype_r* restrict psiij_r = psij_r + i*incy2; \ + ctype_r* restrict psiij_i = psij_i + i*incy2; \ +\ + PASTEMAC(ch,scal2jris)( *alpha_r, *alpha_i, \ + *chiij_r, *chiij_i, \ + *psiij_r, *psiij_i ); \ +\ + for ( dim_t p = 1; p < d; ++p ) \ + { \ + ctype_r* restrict psiijd_r = psiij_r + p*ds_y; \ + ctype_r* restrict psiijd_i = psiij_i + p*ds_y; \ +\ + PASTEMAC(ch,copyris)( *psiij_r, *psiij_i, \ + *psiijd_r, *psiijd_i ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_noconj( conjx ) ) */ \ + { \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype_r* restrict chij_r = chi_r + j*ldx2; \ + ctype_r* restrict chij_i = chi_i + j*ldx2; \ + ctype_r* restrict psij_r = psi_r + j*ldy2; \ + ctype_r* restrict psij_i = psi_i + j*ldy2; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype_r* restrict chiij_r = chij_r + i*incx2; \ + ctype_r* restrict chiij_i = chij_i + i*incx2; \ + ctype_r* restrict psiij_r = psij_r + i*incy2; \ + ctype_r* restrict psiij_i = psij_i + i*incy2; \ +\ + PASTEMAC(ch,scal2ris)( *alpha_r, *alpha_i, \ + *chiij_r, *chiij_i, \ + *psiij_r, *psiij_i ); \ +\ + for ( dim_t p = 1; p < d; ++p ) \ + { \ + ctype_r* restrict psiijd_r = psiij_r + p*ds_y; \ + ctype_r* restrict psiijd_i = psiij_i + p*ds_y; \ +\ + PASTEMAC(ch,copyris)( *psiij_r, *psiij_i, \ + *psiijd_r, *psiijd_i ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNCCO_BASIC0( scal2bbs_mxn ) + +#endif diff --git a/frame/1m/packm/bli_packm_blk_var1_md.h b/frame/include/level0/bb/bli_set0bbs_mxn.h similarity index 68% rename from frame/1m/packm/bli_packm_blk_var1_md.h rename to frame/include/level0/bb/bli_set0bbs_mxn.h index e6bf151d07..3a44883f42 100644 --- a/frame/1m/packm/bli_packm_blk_var1_md.h +++ b/frame/include/level0/bb/bli_set0bbs_mxn.h @@ -32,36 +32,43 @@ */ -void bli_packm_blk_var1_md - ( - obj_t* c, - obj_t* p, - cntx_t* cntx, - cntl_t* cntl, - thrinfo_t* t - ); +#ifndef BLIS_SET0BBS_MXN_H +#define BLIS_SET0BBS_MXN_H +// set0bbs_mxn -#undef GENTPROT2 -#define GENTPROT2( ctype_c, ctype_p, chc, chp, varname ) \ +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC2(chc,chp,varname) \ +BLIS_INLINE void PASTEMAC(ch,opname) \ ( \ - trans_t transc, \ - pack_t schema, \ - dim_t m, \ - dim_t n, \ - dim_t m_max, \ - dim_t n_max, \ - void* kappa, \ - void* c, inc_t rs_c, inc_t cs_c, \ - void* p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - dim_t pd_p, inc_t ps_p, \ - cntx_t* cntx, \ - thrinfo_t* thread \ - ); + const dim_t m, \ + const dim_t n, \ + ctype* restrict y, const inc_t incy, const inc_t ldy \ + ) \ +{ \ + /* Assume that the duplication factor is the row stride of y. */ \ + const dim_t d = incy; \ + const dim_t ds_y = 1; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict yj = y + j*ldy; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict yij = yj + i*incy; \ +\ + for ( dim_t p = 0; p < d; ++p ) \ + { \ + ctype* restrict yijd = yij + p*ds_y; \ +\ + PASTEMAC(ch,set0s)( *yijd ); \ + } \ + } \ + } \ +} -INSERT_GENTPROT2_BASIC0( packm_blk_var1_md ) -INSERT_GENTPROT2_MIXDP0( packm_blk_var1_md ) +INSERT_GENTFUNC_BASIC0( set0bbs_mxn ) +#endif diff --git a/frame/include/level0/bli_adds_mxn.h b/frame/include/level0/bli_adds_mxn.h index d0cc1805bc..8a92a17a63 100644 --- a/frame/include/level0/bli_adds_mxn.h +++ b/frame/include/level0/bli_adds_mxn.h @@ -44,7 +44,7 @@ // xy = ?s -static void bli_ssadds_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_ssadds_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -71,7 +71,7 @@ static void bli_ssadds_mxn( const dim_t m, const dim_t n, float* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_dsadds_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_dsadds_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -98,7 +98,7 @@ static void bli_dsadds_mxn( const dim_t m, const dim_t n, double* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_csadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_csadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -125,7 +125,7 @@ static void bli_csadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_zsadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zsadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -155,7 +155,7 @@ static void bli_zsadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, // xy = ?d -static void bli_sdadds_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_sdadds_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -182,7 +182,7 @@ static void bli_sdadds_mxn( const dim_t m, const dim_t n, float* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_ddadds_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_ddadds_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -209,7 +209,7 @@ static void bli_ddadds_mxn( const dim_t m, const dim_t n, double* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_cdadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_cdadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -236,7 +236,7 @@ static void bli_cdadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_zdadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zdadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -266,7 +266,7 @@ static void bli_zdadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, // xy = ?c -static void bli_scadds_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_scadds_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -293,7 +293,7 @@ static void bli_scadds_mxn( const dim_t m, const dim_t n, float* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_dcadds_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_dcadds_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -320,7 +320,7 @@ static void bli_dcadds_mxn( const dim_t m, const dim_t n, double* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_ccadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_ccadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -347,7 +347,7 @@ static void bli_ccadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_zcadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zcadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -377,7 +377,7 @@ static void bli_zcadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, // xy = ?z -static void bli_szadds_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_szadds_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -404,7 +404,7 @@ static void bli_szadds_mxn( const dim_t m, const dim_t n, float* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_dzadds_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_dzadds_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -431,7 +431,7 @@ static void bli_dzadds_mxn( const dim_t m, const dim_t n, double* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_czadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_czadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -458,7 +458,7 @@ static void bli_czadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_zzadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zzadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -488,22 +488,22 @@ static void bli_zzadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, -static void bli_sadds_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_sadds_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { bli_ssadds_mxn( m, n, x, rs_x, cs_x, y, rs_y, cs_y ); } -static void bli_dadds_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_dadds_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { bli_ddadds_mxn( m, n, x, rs_x, cs_x, y, rs_y, cs_y ); } -static void bli_cadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_cadds_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { bli_ccadds_mxn( m, n, x, rs_x, cs_x, y, rs_y, cs_y ); } -static void bli_zadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zadds_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { bli_zzadds_mxn( m, n, x, rs_x, cs_x, y, rs_y, cs_y ); diff --git a/frame/include/level0/bli_copys_mxn.h b/frame/include/level0/bli_copys_mxn.h index dd50eb119e..a8ead1c307 100644 --- a/frame/include/level0/bli_copys_mxn.h +++ b/frame/include/level0/bli_copys_mxn.h @@ -43,7 +43,7 @@ // xy = ?s -static void bli_sscopys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_sscopys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -70,7 +70,7 @@ static void bli_sscopys_mxn( const dim_t m, const dim_t n, float* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_dscopys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_dscopys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -97,7 +97,7 @@ static void bli_dscopys_mxn( const dim_t m, const dim_t n, double* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_cscopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_cscopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -124,7 +124,7 @@ static void bli_cscopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_zscopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zscopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -154,7 +154,7 @@ static void bli_zscopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, // xy = ?d -static void bli_sdcopys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_sdcopys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -181,7 +181,7 @@ static void bli_sdcopys_mxn( const dim_t m, const dim_t n, float* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_ddcopys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_ddcopys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -208,7 +208,7 @@ static void bli_ddcopys_mxn( const dim_t m, const dim_t n, double* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_cdcopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_cdcopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -235,7 +235,7 @@ static void bli_cdcopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_zdcopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zdcopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -265,7 +265,7 @@ static void bli_zdcopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, // xy = ?c -static void bli_sccopys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_sccopys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -292,7 +292,7 @@ static void bli_sccopys_mxn( const dim_t m, const dim_t n, float* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_dccopys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_dccopys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -319,7 +319,7 @@ static void bli_dccopys_mxn( const dim_t m, const dim_t n, double* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_cccopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_cccopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -346,7 +346,7 @@ static void bli_cccopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_zccopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zccopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -376,7 +376,7 @@ static void bli_zccopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, // xy = ?c -static void bli_szcopys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_szcopys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -403,7 +403,7 @@ static void bli_szcopys_mxn( const dim_t m, const dim_t n, float* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_dzcopys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_dzcopys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -430,7 +430,7 @@ static void bli_dzcopys_mxn( const dim_t m, const dim_t n, double* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_czcopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_czcopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -457,7 +457,7 @@ static void bli_czcopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_zzcopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zzcopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { #ifdef BLIS_ENABLE_CR_CASES @@ -486,22 +486,22 @@ static void bli_zzcopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, } -static void bli_scopys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_scopys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { bli_sscopys_mxn( m, n, x, rs_x, cs_x, y, rs_y, cs_y ); } -static void bli_dcopys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_dcopys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { bli_ddcopys_mxn( m, n, x, rs_x, cs_x, y, rs_y, cs_y ); } -static void bli_ccopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_ccopys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { bli_cccopys_mxn( m, n, x, rs_x, cs_x, y, rs_y, cs_y ); } -static void bli_zcopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zcopys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { bli_zzcopys_mxn( m, n, x, rs_x, cs_x, y, rs_y, cs_y ); diff --git a/frame/include/level0/bli_randnp2s.h b/frame/include/level0/bli_randnp2s.h index 68d89ad966..7904f72aa5 100644 --- a/frame/include/level0/bli_randnp2s.h +++ b/frame/include/level0/bli_randnp2s.h @@ -42,6 +42,8 @@ { \ bli_drandnp2s( a ); \ } + +#if 0 #define bli_drandnp2s_prev( a ) \ { \ const double m_max = 3.0; \ @@ -95,6 +97,8 @@ down to float. */ \ a = r_val; \ } +#endif + #define bli_drandnp2s( a ) \ { \ const double m_max = 6.0; \ @@ -108,15 +112,19 @@ represents the largest power of two we will use to generate the random numbers. */ \ \ - /* Generate a random real number t on the interval: [0.0, 6.0]. */ \ - t = ( ( double ) rand() / ( double ) RAND_MAX ) * m_max2; \ -\ - /* Modify t to guarantee that is never equal to the upper bound of - the interval (in this case, 6.0). */ \ - if ( t == m_max2 ) t = t - 1.0; \ + do \ + { \ + /* Generate a random real number t on the interval: [0.0, 6.0]. */ \ + t = ( ( double ) rand() / ( double ) RAND_MAX ) * m_max2; \ \ - /* Transform the interval into the set of integers, {0,1,2,3,4,5}. */ \ - t = floor( t ); \ + /* Transform the interval into the set of integers, {0,1,2,3,4,5}. + Note that 6 is prohibited by the loop guard below. */ \ + t = floor( t ); \ + } \ + /* If t is ever equal to m_max2, we re-randomize. The guard against + m_max2 < t is for sanity and shouldn't happen, unless perhaps there + is weirdness in the typecasting to double when computing t above. */ \ + while ( m_max2 <= t ); \ \ /* Map values of t == 0 to a final value of 0. */ \ if ( t == 0.0 ) r_val = 0.0; \ @@ -126,7 +134,7 @@ \ double s_val; \ \ - /* Compute r_val = 2^s where s = +/-(t-1) = {-4,-3,-2,-1,0}. */ \ + /* Compute r_val = 2^s where s = -(t-1) = {-4,-3,-2,-1,0}. */ \ r_val = pow( 2.0, -(t - 1.0) ); \ \ /* Compute a random number to determine the sign of the final diff --git a/frame/include/level0/bli_scal2s_mxn.h b/frame/include/level0/bli_scal2s_mxn.h new file mode 100644 index 0000000000..db17eee4cb --- /dev/null +++ b/frame/include/level0/bli_scal2s_mxn.h @@ -0,0 +1,89 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SCAL2S_MXN_H +#define BLIS_SCAL2S_MXN_H + +// scal2s_mxn + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +BLIS_INLINE void PASTEMAC(ch,opname) \ + ( \ + const conj_t conjx, \ + const dim_t m, \ + const dim_t n, \ + ctype* restrict alpha, \ + ctype* restrict x, const inc_t rs_x, const inc_t cs_x, \ + ctype* restrict y, const inc_t rs_y, const inc_t cs_y \ + ) \ +{ \ + if ( bli_is_conj( conjx ) ) \ + { \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict xj = x + j*cs_x; \ + ctype* restrict yj = y + j*cs_y; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict xij = xj + i*rs_x; \ + ctype* restrict yij = yj + i*rs_y; \ +\ + PASTEMAC(ch,scal2js)( *alpha, *xij, *yij ); \ + } \ + } \ + } \ + else /* if ( bli_is_noconj( conjx ) ) */ \ + { \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict xj = x + j*cs_x; \ + ctype* restrict yj = y + j*cs_y; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict xij = xj + i*rs_x; \ + ctype* restrict yij = yj + i*rs_y; \ +\ + PASTEMAC(ch,scal2s)( *alpha, *xij, *yij ); \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC0( scal2s_mxn ) + +#endif diff --git a/frame/include/level0/bli_set0s_mxn.h b/frame/include/level0/bli_set0s_mxn.h index 699d6042a6..ed2f9b159f 100644 --- a/frame/include/level0/bli_set0s_mxn.h +++ b/frame/include/level0/bli_set0s_mxn.h @@ -41,7 +41,7 @@ // - The first char encodes the type of x. // - The second char encodes the type of y. -static void bli_sset0s_mxn( const dim_t m, const dim_t n, +BLIS_INLINE void bli_sset0s_mxn( const dim_t m, const dim_t n, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { for ( dim_t j = 0; j < n; ++j ) @@ -49,7 +49,7 @@ static void bli_sset0s_mxn( const dim_t m, const dim_t n, bli_sset0s( *(y + i*rs_y + j*cs_y) ); } -static void bli_dset0s_mxn( const dim_t m, const dim_t n, +BLIS_INLINE void bli_dset0s_mxn( const dim_t m, const dim_t n, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { for ( dim_t j = 0; j < n; ++j ) @@ -57,7 +57,7 @@ static void bli_dset0s_mxn( const dim_t m, const dim_t n, bli_dset0s( *(y + i*rs_y + j*cs_y) ); } -static void bli_cset0s_mxn( const dim_t m, const dim_t n, +BLIS_INLINE void bli_cset0s_mxn( const dim_t m, const dim_t n, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { for ( dim_t j = 0; j < n; ++j ) @@ -65,7 +65,7 @@ static void bli_cset0s_mxn( const dim_t m, const dim_t n, bli_cset0s( *(y + i*rs_y + j*cs_y) ); } -static void bli_zset0s_mxn( const dim_t m, const dim_t n, +BLIS_INLINE void bli_zset0s_mxn( const dim_t m, const dim_t n, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { for ( dim_t j = 0; j < n; ++j ) diff --git a/frame/include/level0/bli_xpbys_mxn.h b/frame/include/level0/bli_xpbys_mxn.h index 8b6b82d44a..f23df17a2c 100644 --- a/frame/include/level0/bli_xpbys_mxn.h +++ b/frame/include/level0/bli_xpbys_mxn.h @@ -45,7 +45,7 @@ // -- (xby) = (?ss) ------------------------------------------------------------ -static void bli_sssxpbys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_sssxpbys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict beta, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -80,7 +80,7 @@ static void bli_sssxpbys_mxn( const dim_t m, const dim_t n, float* restrict x *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_dssxpbys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_dssxpbys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict beta, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -115,7 +115,7 @@ static void bli_dssxpbys_mxn( const dim_t m, const dim_t n, double* restrict x *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_cssxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_cssxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict beta, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -150,7 +150,7 @@ static void bli_cssxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_zssxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zssxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict beta, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -188,7 +188,7 @@ static void bli_zssxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x // -- (xby) = (?dd) ------------------------------------------------------------ -static void bli_sddxpbys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_sddxpbys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict beta, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -223,7 +223,7 @@ static void bli_sddxpbys_mxn( const dim_t m, const dim_t n, float* restrict x *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_dddxpbys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_dddxpbys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict beta, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -258,7 +258,7 @@ static void bli_dddxpbys_mxn( const dim_t m, const dim_t n, double* restrict x *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_cddxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_cddxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict beta, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -293,7 +293,7 @@ static void bli_cddxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_zddxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zddxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict beta, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -331,7 +331,7 @@ static void bli_zddxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x // -- (xby) = (?cc) ------------------------------------------------------------ -static void bli_sccxpbys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_sccxpbys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict beta, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -366,7 +366,7 @@ static void bli_sccxpbys_mxn( const dim_t m, const dim_t n, float* restrict x *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_dccxpbys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_dccxpbys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict beta, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -401,7 +401,7 @@ static void bli_dccxpbys_mxn( const dim_t m, const dim_t n, double* restrict x *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_cccxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_cccxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict beta, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -436,7 +436,7 @@ static void bli_cccxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_zccxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zccxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict beta, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -474,7 +474,7 @@ static void bli_zccxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x // -- (xby) = (?zz) ------------------------------------------------------------ -static void bli_szzxpbys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_szzxpbys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict beta, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -509,7 +509,7 @@ static void bli_szzxpbys_mxn( const dim_t m, const dim_t n, float* restrict x *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_dzzxpbys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_dzzxpbys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict beta, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -544,7 +544,7 @@ static void bli_dzzxpbys_mxn( const dim_t m, const dim_t n, double* restrict x *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_czzxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_czzxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict beta, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -579,7 +579,7 @@ static void bli_czzxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x *(y + ii*rs_y + jj*cs_y) ); } } -static void bli_zzzxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zzzxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict beta, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { @@ -617,25 +617,25 @@ static void bli_zzzxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x -static void bli_sxpbys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_sxpbys_mxn( const dim_t m, const dim_t n, float* restrict x, const inc_t rs_x, const inc_t cs_x, float* restrict beta, float* restrict y, const inc_t rs_y, const inc_t cs_y ) { bli_sssxpbys_mxn( m, n, x, rs_x, cs_x, beta, y, rs_y, cs_y ); } -static void bli_dxpbys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_dxpbys_mxn( const dim_t m, const dim_t n, double* restrict x, const inc_t rs_x, const inc_t cs_x, double* restrict beta, double* restrict y, const inc_t rs_y, const inc_t cs_y ) { bli_dddxpbys_mxn( m, n, x, rs_x, cs_x, beta, y, rs_y, cs_y ); } -static void bli_cxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_cxpbys_mxn( const dim_t m, const dim_t n, scomplex* restrict x, const inc_t rs_x, const inc_t cs_x, scomplex* restrict beta, scomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { bli_cccxpbys_mxn( m, n, x, rs_x, cs_x, beta, y, rs_y, cs_y ); } -static void bli_zxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, +BLIS_INLINE void bli_zxpbys_mxn( const dim_t m, const dim_t n, dcomplex* restrict x, const inc_t rs_x, const inc_t cs_x, dcomplex* restrict beta, dcomplex* restrict y, const inc_t rs_y, const inc_t cs_y ) { diff --git a/frame/include/level0/io/bli_scal2ios.h b/frame/include/level0/old/io/bli_scal2ios.h similarity index 100% rename from frame/include/level0/io/bli_scal2ios.h rename to frame/include/level0/old/io/bli_scal2ios.h diff --git a/frame/include/level0/io/bli_scal2jios.h b/frame/include/level0/old/io/bli_scal2jios.h similarity index 100% rename from frame/include/level0/io/bli_scal2jios.h rename to frame/include/level0/old/io/bli_scal2jios.h diff --git a/frame/include/level0/ri3/bli_copyjri3s.h b/frame/include/level0/old/ri3/bli_copyjri3s.h similarity index 100% rename from frame/include/level0/ri3/bli_copyjri3s.h rename to frame/include/level0/old/ri3/bli_copyjri3s.h diff --git a/frame/include/level0/ri3/bli_copyri3s.h b/frame/include/level0/old/ri3/bli_copyri3s.h similarity index 100% rename from frame/include/level0/ri3/bli_copyri3s.h rename to frame/include/level0/old/ri3/bli_copyri3s.h diff --git a/frame/include/level0/ri3/bli_scal2jri3s.h b/frame/include/level0/old/ri3/bli_scal2jri3s.h similarity index 100% rename from frame/include/level0/ri3/bli_scal2jri3s.h rename to frame/include/level0/old/ri3/bli_scal2jri3s.h diff --git a/frame/include/level0/ri3/bli_scal2ri3s.h b/frame/include/level0/old/ri3/bli_scal2ri3s.h similarity index 100% rename from frame/include/level0/ri3/bli_scal2ri3s.h rename to frame/include/level0/old/ri3/bli_scal2ri3s.h diff --git a/frame/include/level0/ri3/bli_scal2ri3s_mxn.h b/frame/include/level0/old/ri3/bli_scal2ri3s_mxn.h similarity index 98% rename from frame/include/level0/ri3/bli_scal2ri3s_mxn.h rename to frame/include/level0/old/ri3/bli_scal2ri3s_mxn.h index 1e1d3d8bac..2316f0738c 100644 --- a/frame/include/level0/ri3/bli_scal2ri3s_mxn.h +++ b/frame/include/level0/old/ri3/bli_scal2ri3s_mxn.h @@ -37,7 +37,7 @@ // scal2ri3s_mxn -static void bli_cscal2ri3s_mxn +BLIS_INLINE void bli_cscal2ri3s_mxn ( const conj_t conjx, const dim_t m, @@ -108,7 +108,7 @@ static void bli_cscal2ri3s_mxn } } -static void bli_zscal2ri3s_mxn +BLIS_INLINE void bli_zscal2ri3s_mxn ( const conj_t conjx, const dim_t m, diff --git a/frame/include/level0/rih/bli_scal2rihs_mxn.h b/frame/include/level0/old/rih/bli_scal2rihs_mxn.h similarity index 98% rename from frame/include/level0/rih/bli_scal2rihs_mxn.h rename to frame/include/level0/old/rih/bli_scal2rihs_mxn.h index 8485e40a97..ca117b85d9 100644 --- a/frame/include/level0/rih/bli_scal2rihs_mxn.h +++ b/frame/include/level0/old/rih/bli_scal2rihs_mxn.h @@ -37,7 +37,7 @@ // scal2rihs_mxn -static void bli_cscal2rihs_mxn +BLIS_INLINE void bli_cscal2rihs_mxn ( const pack_t schema, const conj_t conjx, @@ -158,7 +158,7 @@ static void bli_cscal2rihs_mxn } } -static void bli_zscal2rihs_mxn +BLIS_INLINE void bli_zscal2rihs_mxn ( const pack_t schema, const conj_t conjx, diff --git a/frame/include/level0/rih/bli_scal2rihs_mxn_diag.h b/frame/include/level0/old/rih/bli_scal2rihs_mxn_diag.h similarity index 100% rename from frame/include/level0/rih/bli_scal2rihs_mxn_diag.h rename to frame/include/level0/old/rih/bli_scal2rihs_mxn_diag.h diff --git a/frame/include/level0/rih/bli_scal2rihs_mxn_uplo.h b/frame/include/level0/old/rih/bli_scal2rihs_mxn_uplo.h similarity index 100% rename from frame/include/level0/rih/bli_scal2rihs_mxn_uplo.h rename to frame/include/level0/old/rih/bli_scal2rihs_mxn_uplo.h diff --git a/frame/include/level0/rih/bli_setrihs_mxn_diag.h b/frame/include/level0/old/rih/bli_setrihs_mxn_diag.h similarity index 100% rename from frame/include/level0/rih/bli_setrihs_mxn_diag.h rename to frame/include/level0/old/rih/bli_setrihs_mxn_diag.h diff --git a/frame/include/level0/ro/bli_scal2jros.h b/frame/include/level0/old/ro/bli_scal2jros.h similarity index 100% rename from frame/include/level0/ro/bli_scal2jros.h rename to frame/include/level0/old/ro/bli_scal2jros.h diff --git a/frame/include/level0/ro/bli_scal2ros.h b/frame/include/level0/old/ro/bli_scal2ros.h similarity index 100% rename from frame/include/level0/ro/bli_scal2ros.h rename to frame/include/level0/old/ro/bli_scal2ros.h diff --git a/frame/include/level0/rpi/bli_scal2jrpis.h b/frame/include/level0/old/rpi/bli_scal2jrpis.h similarity index 100% rename from frame/include/level0/rpi/bli_scal2jrpis.h rename to frame/include/level0/old/rpi/bli_scal2jrpis.h diff --git a/frame/include/level0/rpi/bli_scal2rpis.h b/frame/include/level0/old/rpi/bli_scal2rpis.h similarity index 100% rename from frame/include/level0/rpi/bli_scal2rpis.h rename to frame/include/level0/old/rpi/bli_scal2rpis.h diff --git a/frame/include/level0/ri/bli_scal2ris_mxn.h b/frame/include/level0/ri/bli_scal2ris_mxn.h index bd7422e81b..85b242146b 100644 --- a/frame/include/level0/ri/bli_scal2ris_mxn.h +++ b/frame/include/level0/ri/bli_scal2ris_mxn.h @@ -37,7 +37,7 @@ // scal2ris_mxn -static void bli_cscal2ris_mxn +BLIS_INLINE void bli_cscal2ris_mxn ( const conj_t conjx, const dim_t m, @@ -103,7 +103,7 @@ static void bli_cscal2ris_mxn } } -static void bli_zscal2ris_mxn +BLIS_INLINE void bli_zscal2ris_mxn ( const conj_t conjx, const dim_t m, diff --git a/frame/ind/cntx/bli_cntx_ind_stage.c b/frame/ind/cntx/bli_cntx_ind_stage.c deleted file mode 100644 index 671be681d7..0000000000 --- a/frame/ind/cntx/bli_cntx_ind_stage.c +++ /dev/null @@ -1,148 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -typedef void (*cntx_stage_ft)( dim_t stage, cntx_t* cntx ); - -static void* bli_cntx_ind_stage_fp[BLIS_NUM_IND_METHODS] = -{ -/* 3mh */ bli_cntx_3mh_stage, -/* 3m1 */ bli_cntx_3m1_stage, -/* 4mh */ bli_cntx_4mh_stage, -/* 4mb */ bli_cntx_4mb_stage, -/* 4m1 */ bli_cntx_4m1_stage, -/* 1m */ bli_cntx_1m_stage, -/* nat */ bli_cntx_nat_stage -}; - - -// ----------------------------------------------------------------------------- - -// Execute the context initialization/finalization function associated -// with a given induced method. - -void bli_cntx_ind_stage( ind_t method, dim_t stage, cntx_t* cntx ) -{ - cntx_stage_ft func = bli_cntx_ind_stage_fp[ method ]; - - func( stage, cntx ); -} - -// ----------------------------------------------------------------------------- - -// These functions modify a context, if needed, for the particular "stage" of -// the induced method execution. Some induced methods do not make use of this -// feature. NOTE: ANY INDUCED METHOD THAT HAS A NON-EMPTY _stage() FUNCTION -// IS NOT THREAT-SAFE FOR APPLICATION-LEVEL THREADING. - -// ----------------------------------------------------------------------------- - -void bli_cntx_3mh_stage( dim_t stage, cntx_t* cntx ) -{ - // Set the pack_t schemas as a function of the stage of execution. - if ( stage == 0 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_RO, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_RO, cntx ); - } - else if ( stage == 1 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_IO, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_IO, cntx ); - } - else // if ( stage == 2 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_RPI, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_RPI, cntx ); - } -} - -// ----------------------------------------------------------------------------- - -void bli_cntx_3m1_stage( dim_t stage, cntx_t* cntx ) -{ -} - -// ----------------------------------------------------------------------------- - -void bli_cntx_4mh_stage( dim_t stage, cntx_t* cntx ) -{ - // Set the pack_t schemas as a function of the stage of execution. - if ( stage == 0 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_RO, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_RO, cntx ); - } - else if ( stage == 1 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_IO, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_IO, cntx ); - } - else if ( stage == 2 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_RO, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_IO, cntx ); - } - else // if ( stage == 3 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_IO, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_RO, cntx ); - } -} - -// ----------------------------------------------------------------------------- - -void bli_cntx_4mb_stage( dim_t stage, cntx_t* cntx ) -{ -} - -// ----------------------------------------------------------------------------- - -void bli_cntx_4m1_stage( dim_t stage, cntx_t* cntx ) -{ -} - -// ----------------------------------------------------------------------------- - -void bli_cntx_1m_stage( dim_t stage, cntx_t* cntx ) -{ -} - -// ----------------------------------------------------------------------------- - -void bli_cntx_nat_stage( dim_t stage, cntx_t* cntx ) -{ -} - diff --git a/frame/ind/cntx/bli_cntx_ind_stage.h b/frame/ind/cntx/bli_cntx_ind_stage.h deleted file mode 100644 index 124421665a..0000000000 --- a/frame/ind/cntx/bli_cntx_ind_stage.h +++ /dev/null @@ -1,44 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -void bli_cntx_ind_stage( ind_t method, dim_t stage, cntx_t* cntx ); - -void bli_cntx_3mh_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_3m1_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_4mh_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_4mb_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_4m1_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_1m_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_nat_stage( dim_t stage, cntx_t* cntx ); - diff --git a/frame/ind/oapi/bli_l3_3m4m1m_oapi.c b/frame/ind/oapi/bli_l3_3m4m1m_oapi.c deleted file mode 100644 index 087e1beef1..0000000000 --- a/frame/ind/oapi/bli_l3_3m4m1m_oapi.c +++ /dev/null @@ -1,443 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -// -- gemm/her2k/syr2k --------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth, nstage ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - ind_t ind = PASTEMAC0(imeth); \ - num_t dt = bli_obj_dt( c ); \ - obj_t* beta_use = beta; \ -\ - dim_t i; \ -\ - /* If the objects are in the real domain, execute the native - implementation. */ \ - if ( bli_obj_is_real( c ) ) \ - { \ - PASTEMAC(opname,nat)( alpha, a, b, beta, c, cntx, rntm ); \ - return; \ - } \ -\ - /* A temporary hack to easily specify the 1m algorithm (block-panel or - panel-block). */ \ -/* - if ( PASTEMAC(opname,imeth) == bli_gemm1m ) \ - { \ - bli_gemm1mbp( alpha, a, b, beta, c ); \ - return; \ - } \ - else if ( PASTEMAC(opname,imeth) == bli_gemm3m1 ) \ - { \ - bli_gemm1mpb( alpha, a, b, beta, c ); \ - return; \ - } \ -*/ \ -\ - /* Query a context for the current induced method. This context is - managed and cached by the gks and should not be freed by the caller. - Note that the datatype argument is needed because it will be passed - in when bli_gks_query_ind_cntx() eventually calls the induced method's - _cntx_init() function. */ \ - cntx = bli_gks_query_ind_cntx( ind, dt ); \ -\ - /* 3mh and 4mh change the context for each stage, and so in order to - remain thread-safe, we must make a local copy of the context for - those induced methods. */ \ - cntx_t cntx_l; \ - if ( ind == BLIS_3MH || ind == BLIS_4MH ) { cntx_l = *cntx; cntx = &cntx_l; } \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Some induced methods execute in multiple "stages". */ \ - for ( i = 0; i < nstage; ++i ) \ - { \ - /* Prepare the context for the ith stage of computation. */ \ - bli_cntx_ind_stage( ind, i, cntx ); \ -\ - /* For multi-stage methods, use BLIS_ONE as beta after the first - stage. */ \ - if ( i > 0 ) beta_use = &BLIS_ONE; \ -\ - /* Invoke the operation's front end and request the default control - tree. */ \ - PASTEMAC(opname,_front)( alpha, a, b, beta_use, c, cntx, rntm, NULL ); \ - } \ -} - -// gemm -GENFRONT( gemm, gemm, 3mh, 3 ) -GENFRONT( gemm, gemm, 3m1, 1 ) -GENFRONT( gemm, gemm, 4mh, 4 ) -GENFRONT( gemm, gemm, 4mb, 1 ) -GENFRONT( gemm, gemm, 4m1, 1 ) -GENFRONT( gemm, gemm, 1m, 1 ) - -// her2k -GENFRONT( her2k, gemm, 3mh, 3 ) -GENFRONT( her2k, gemm, 3m1, 1 ) -GENFRONT( her2k, gemm, 4mh, 4 ) -//GENFRONT( her2k, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( her2k, gemm, 4m1, 1 ) -GENFRONT( her2k, gemm, 1m, 1 ) - -// syr2k -GENFRONT( syr2k, gemm, 3mh, 3 ) -GENFRONT( syr2k, gemm, 3m1, 1 ) -GENFRONT( syr2k, gemm, 4mh, 4 ) -//GENFRONT( syr2k, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( syr2k, gemm, 4m1, 1 ) -GENFRONT( syr2k, gemm, 1m, 1 ) - - -// -- hemm/symm/trmm3 ---------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth, nstage ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - ind_t ind = PASTEMAC0(imeth); \ - num_t dt = bli_obj_dt( c ); \ - obj_t* beta_use = beta; \ -\ - dim_t i; \ -\ - /* If the objects are in the real domain, execute the native - implementation. */ \ - if ( bli_obj_is_real( c ) ) \ - { \ - PASTEMAC(opname,nat)( side, alpha, a, b, beta, c, cntx, rntm ); \ - return; \ - } \ -\ - /* Query a context for the current induced method. This context is - managed and cached by the gks and should not be freed by the caller. - Note that the datatype argument is needed because it will be passed - in when bli_gks_query_ind_cntx() eventually calls the induced method's - _cntx_init() function. */ \ - cntx = bli_gks_query_ind_cntx( ind, dt ); \ -\ - /* 3mh and 4mh change the context for each stage, and so in order to - remain thread-safe, we must make a local copy of the context for - those induced methods. */ \ - cntx_t cntx_l; \ - if ( ind == BLIS_3MH || ind == BLIS_4MH ) { cntx_l = *cntx; cntx = &cntx_l; } \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Some induced methods execute in multiple "stages". */ \ - for ( i = 0; i < nstage; ++i ) \ - { \ - /* Prepare the context for the ith stage of computation. */ \ - bli_cntx_ind_stage( ind, i, cntx ); \ -\ - /* For multi-stage methods, use BLIS_ONE as beta after the first - stage. */ \ - if ( i > 0 ) beta_use = &BLIS_ONE; \ -\ - /* Invoke the operation's front end and request the default control - tree. */ \ - PASTEMAC(opname,_front)( side, alpha, a, b, beta_use, c, cntx, rntm, NULL ); \ - } \ -} - -// hemm -GENFRONT( hemm, gemm, 3mh, 3 ) -GENFRONT( hemm, gemm, 3m1, 1 ) -GENFRONT( hemm, gemm, 4mh, 4 ) -//GENFRONT( hemm, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( hemm, gemm, 4m1, 1 ) -GENFRONT( hemm, gemm, 1m, 1 ) - -// symm -GENFRONT( symm, gemm, 3mh, 3 ) -GENFRONT( symm, gemm, 3m1, 1 ) -GENFRONT( symm, gemm, 4mh, 4 ) -//GENFRONT( symm, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( symm, gemm, 4m1, 1 ) -GENFRONT( symm, gemm, 1m, 1 ) - -// trmm3 -GENFRONT( trmm3, gemm, 3mh, 3 ) -GENFRONT( trmm3, gemm, 3m1, 1 ) -GENFRONT( trmm3, gemm, 4mh, 4 ) -//GENFRONT( trmm3, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( trmm3, gemm, 4m1, 1 ) -GENFRONT( trmm3, gemm, 1m, 1 ) - - -// -- herk/syrk ---------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth, nstage ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - ind_t ind = PASTEMAC0(imeth); \ - num_t dt = bli_obj_dt( c ); \ - obj_t* beta_use = beta; \ -\ - dim_t i; \ -\ - /* If the objects are in the real domain, execute the native - implementation. */ \ - if ( bli_obj_is_real( c ) ) \ - { \ - PASTEMAC(opname,nat)( alpha, a, beta, c, cntx, rntm ); \ - return; \ - } \ -\ - /* Query a context for the current induced method. This context is - managed and cached by the gks and should not be freed by the caller. - Note that the datatype argument is needed because it will be passed - in when bli_gks_query_ind_cntx() eventually calls the induced method's - _cntx_init() function. */ \ - cntx = bli_gks_query_ind_cntx( ind, dt ); \ -\ - /* 3mh and 4mh change the context for each stage, and so in order to - remain thread-safe, we must make a local copy of the context for - those induced methods. */ \ - cntx_t cntx_l; \ - if ( ind == BLIS_3MH || ind == BLIS_4MH ) { cntx_l = *cntx; cntx = &cntx_l; } \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Some induced methods execute in multiple "stages". */ \ - for ( i = 0; i < nstage; ++i ) \ - { \ - /* Prepare the context for the ith stage of computation. */ \ - bli_cntx_ind_stage( ind, i, cntx ); \ -\ - /* For multi-stage methods, use BLIS_ONE as beta after the first - stage. */ \ - if ( i > 0 ) beta_use = &BLIS_ONE; \ -\ - /* Invoke the operation's front end and request the default control - tree. */ \ - PASTEMAC(opname,_front)( alpha, a, beta_use, c, cntx, rntm, NULL ); \ - } \ -} - -// herk -GENFRONT( herk, gemm, 3mh, 3 ) -GENFRONT( herk, gemm, 3m1, 1 ) -GENFRONT( herk, gemm, 4mh, 4 ) -//GENFRONT( herk, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( herk, gemm, 4m1, 1 ) -GENFRONT( herk, gemm, 1m, 1 ) - -// syrk -GENFRONT( syrk, gemm, 3mh, 3 ) -GENFRONT( syrk, gemm, 3m1, 1 ) -GENFRONT( syrk, gemm, 4mh, 4 ) -//GENFRONT( syrk, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( syrk, gemm, 4m1, 1 ) -GENFRONT( syrk, gemm, 1m, 1 ) - - -// -- trmm --------------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth, nstage ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - ind_t ind = PASTEMAC0(imeth); \ - num_t dt = bli_obj_dt( b ); \ -\ - dim_t i; \ -\ - /* If the objects are in the real domain, execute the native - implementation. */ \ - if ( bli_obj_is_real( b ) ) \ - { \ - PASTEMAC(opname,nat)( side, alpha, a, b, cntx, rntm ); \ - return; \ - } \ -\ - /* Query a context for the current induced method. This context is - managed and cached by the gks and should not be freed by the caller. - Note that the datatype argument is needed because it will be passed - in when bli_gks_query_ind_cntx() eventually calls the induced method's - _cntx_init() function. */ \ - cntx = bli_gks_query_ind_cntx( ind, dt ); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Some induced methods execute in multiple "stages". */ \ - for ( i = 0; i < nstage; ++i ) \ - { \ - /* Prepare the context for the ith stage of computation. */ \ - bli_cntx_ind_stage( ind, i, cntx ); \ -\ - /* Invoke the operation's front end and request the default control - tree. */ \ - PASTEMAC(opname,_front)( side, alpha, a, b, cntx, rntm, NULL ); \ - } \ -} - -// trmm -//GENFRONT( trmm, gemm, 3mh, 3 ) // Unimplementable. -GENFRONT( trmm, gemm, 3m1, 1 ) -//GENFRONT( trmm, gemm, 4mh, 4 ) // Unimplementable. -//GENFRONT( trmm, gemm, 4mb, 1 ) // Unimplementable. -GENFRONT( trmm, gemm, 4m1, 1 ) -GENFRONT( trmm, gemm, 1m, 1 ) - - -// -- trsm --------------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth, nstage ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - ind_t ind = PASTEMAC0(imeth); \ - num_t dt = bli_obj_dt( b ); \ -\ - /* If the objects are in the real domain, execute the native - implementation. */ \ - if ( bli_obj_is_real( b ) ) \ - { \ - PASTEMAC(opname,nat)( side, alpha, a, b, cntx, rntm ); \ - return; \ - } \ -\ - /* Query a context for the current induced method. This context is - managed and cached by the gks and should not be freed by the caller. - Note that the datatype argument is needed because it will be passed - in when bli_gks_query_ind_cntx() eventually calls the induced method's - _cntx_init() function. */ \ - cntx = bli_gks_query_ind_cntx( ind, dt ); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - { \ - /* NOTE: trsm cannot be implemented via any induced method that - needs to execute in stages (e.g. 3mh, 4mh). */ \ -\ - /* Invoke the operation's front end and request the default control - tree. */ \ - PASTEMAC(opname,_front)( side, alpha, a, b, cntx, rntm, NULL ); \ - } \ -} - -// trsm -//GENFRONT( trmm, trsm, 3mh, 3 ) // Unimplementable. -GENFRONT( trsm, trsm, 3m1, 1 ) -//GENFRONT( trmm, trsm, 4mh, 4 ) // Unimplementable. -//GENFRONT( trmm, trsm, 4mb, 1 ) // Unimplementable. -GENFRONT( trsm, trsm, 4m1, 1 ) -GENFRONT( trsm, trsm, 1m, 1 ) - diff --git a/frame/ind/oapi/bli_l3_ind_oapi.c b/frame/ind/oapi/bli_l3_ind_oapi.c deleted file mode 100644 index 2137530196..0000000000 --- a/frame/ind/oapi/bli_l3_ind_oapi.c +++ /dev/null @@ -1,174 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - - -// -- gemm/her2k/syr2k --------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - num_t dt = bli_obj_dt( c ); \ - PASTECH(opname,_oft) func = PASTEMAC(opname,ind_get_avail)( dt ); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - func( alpha, a, b, beta, c, cntx, rntm ); \ -} - -GENFRONT( gemm, ind ) -GENFRONT( her2k, ind ) -GENFRONT( syr2k, ind ) - - -// -- hemm/symm/trmm3 ---------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - num_t dt = bli_obj_dt( c ); \ - PASTECH(opname,_oft) func = PASTEMAC(opname,ind_get_avail)( dt ); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - func( side, alpha, a, b, beta, c, cntx, rntm ); \ -} - -GENFRONT( hemm, ind ) -GENFRONT( symm, ind ) -GENFRONT( trmm3, ind ) - - -// -- herk/syrk ---------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - num_t dt = bli_obj_dt( c ); \ - PASTECH(opname,_oft) func = PASTEMAC(opname,ind_get_avail)( dt ); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - func( alpha, a, beta, c, cntx, rntm ); \ -} - -GENFRONT( herk, ind ) -GENFRONT( syrk, ind ) - - -// -- trmm/trsm ---------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - num_t dt = bli_obj_dt( b ); \ - PASTECH(opname,_oft) func = PASTEMAC(opname,ind_get_avail)( dt ); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - func( side, alpha, a, b, cntx, rntm ); \ -} - -GENFRONT( trmm, ind ) -GENFRONT( trsm, ind ) - diff --git a/frame/ind/oapi/bli_l3_ind_oapi.h b/frame/ind/oapi/bli_l3_ind_oapi.h deleted file mode 100644 index d4767925de..0000000000 --- a/frame/ind/oapi/bli_l3_ind_oapi.h +++ /dev/null @@ -1,96 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - - -// -// Generate object-based prototypes for induced methods that work for -// trmm and trsm (ie: two-operand operations). -// -#undef GENPROT -#define GENPROT( imeth ) \ -\ -BLIS_EXPORT_BLIS void PASTEMAC(gemm,imeth) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(hemm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(herk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(her2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(symm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(syrk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(syr2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(trmm3,imeth)( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(trmm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(trsm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, rntm_t* rntm ); - -GENPROT( nat ) -GENPROT( ind ) -GENPROT( 3m1 ) -GENPROT( 4m1 ) -GENPROT( 1m ) - - -// -// Generate object-based prototypes for induced methods that do NOT work -// for trmm and trsm (ie: two-operand operations). -// -#undef GENPROT_NO2OP -#define GENPROT_NO2OP( imeth ) \ -\ -BLIS_EXPORT_BLIS void PASTEMAC(gemm,imeth) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(hemm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(herk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(her2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(symm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(syrk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(syr2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(trmm3,imeth)( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); - -GENPROT_NO2OP( 3mh ) -GENPROT_NO2OP( 4mh ) -GENPROT_NO2OP( 4mb ) - - -// -// Generate object-based prototypes for 1m methods that specify an algorithm -// (e.g., block-panel or panel-block). -// - -/* -#undef GENPROT -#define GENPROT( imeth, alg ) \ -\ -BLIS_EXPORT_BLIS void PASTEMAC2(gemm,imeth,alg) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c ); \ -*/ - -//GENPROT( 1m, bp ) -//GENPROT( 1m, pb ) - diff --git a/frame/ind/oapi/bli_l3_nat_oapi.c b/frame/ind/oapi/bli_l3_nat_oapi.c deleted file mode 100644 index 52b7e98ad6..0000000000 --- a/frame/ind/oapi/bli_l3_nat_oapi.c +++ /dev/null @@ -1,234 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -// NOTE: The function definitions in this file can be consolidated with the -// definitions for the other induced methods. The only advantage of keeping -// them separate is that it allows us to avoid the very small loop overhead -// of executing one iteration of a for loop, plus the overhead of calling a -// function that does nothing (ie: the _cntx_init_stage() function). - -// -- gemm/her2k/syr2k --------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - /* Obtain a valid (native) context from the gks if necessary. */ \ - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Invoke the operation's front end. */ \ - PASTEMAC(opname,_front) \ - ( \ - alpha, a, b, beta, c, cntx, rntm, NULL \ - ); \ -} - -// If a sandbox was enabled, do not define bli_gemmnat() since it will be -// defined in the sandbox environment. -#ifndef BLIS_ENABLE_SANDBOX -GENFRONT( gemm, gemm, nat ) -#endif -GENFRONT( her2k, gemm, nat ) -GENFRONT( syr2k, gemm, nat ) - - -// -- hemm/symm/trmm3 ---------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - /* Obtain a valid (native) context from the gks if necessary. */ \ - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Invoke the operation's front end. */ \ - PASTEMAC(opname,_front) \ - ( \ - side, alpha, a, b, beta, c, cntx, rntm, NULL \ - ); \ -} - -GENFRONT( hemm, gemm, nat ) -GENFRONT( symm, gemm, nat ) -GENFRONT( trmm3, gemm, nat ) - - -// -- herk/syrk ---------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - /* Obtain a valid (native) context from the gks if necessary. */ \ - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Invoke the operation's front end. */ \ - PASTEMAC(opname,_front) \ - ( \ - alpha, a, beta, c, cntx, rntm, NULL \ - ); \ -} - -GENFRONT( herk, gemm, nat ) -GENFRONT( syrk, gemm, nat ) - - -// -- trmm --------------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - /* Obtain a valid (native) context from the gks if necessary. */ \ - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Invoke the operation's front end. */ \ - PASTEMAC(opname,_front) \ - ( \ - side, alpha, a, b, cntx, rntm, NULL \ - ); \ -} - -GENFRONT( trmm, gemm, nat ) - - -// -- trsm --------------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - /* Obtain a valid (native) context from the gks if necessary. */ \ - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Invoke the operation's front end. */ \ - PASTEMAC(opname,_front) \ - ( \ - side, alpha, a, b, cntx, rntm, NULL \ - ); \ -} - -GENFRONT( trsm, trsm, nat ) - diff --git a/frame/ind/tapi/bli_l3_ind_tapi.c b/frame/ind/tapi/bli_l3_ind_tapi.c deleted file mode 100644 index 9ca7746bc0..0000000000 --- a/frame/ind/tapi/bli_l3_ind_tapi.c +++ /dev/null @@ -1,664 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - - -// -- gemm --------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - trans_t transa, \ - trans_t transb, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t m_a, n_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ - bli_set_dims_with_trans( transb, k, n, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - PASTEMAC0(opname) \ - ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( gemm3mh ) -INSERT_GENTFUNC_BASIC0( gemm3m1 ) -INSERT_GENTFUNC_BASIC0( gemm4mh ) -INSERT_GENTFUNC_BASIC0( gemm4mb ) -INSERT_GENTFUNC_BASIC0( gemm4m1 ) -INSERT_GENTFUNC_BASIC0( gemm1m ) - - -// -- hemm --------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - conj_t conja, \ - trans_t transb, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t mn_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ - bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_conj( conja, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_HERMITIAN, &ao ); \ -\ - PASTEMAC0(opname) \ - ( \ - side, \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( hemm3mh ) -INSERT_GENTFUNC_BASIC0( hemm3m1 ) -INSERT_GENTFUNC_BASIC0( hemm4mh ) -INSERT_GENTFUNC_BASIC0( hemm4m1 ) -INSERT_GENTFUNC_BASIC0( hemm1m ) - - -// -- herk --------------------------------------------------------------------- - -#undef GENTFUNCR -#define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - uplo_t uploc, \ - trans_t transa, \ - dim_t m, \ - dim_t k, \ - ctype_r* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype_r* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt_r = PASTEMAC(chr,type); \ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, betao, co; \ -\ - dim_t m_a, n_a; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt_r, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt_r, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_HERMITIAN, &co ); \ -\ - PASTEMAC0(opname) \ - ( \ - &alphao, \ - &ao, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNCR_BASIC0( herk3mh ) -INSERT_GENTFUNCR_BASIC0( herk3m1 ) -INSERT_GENTFUNCR_BASIC0( herk4mh ) -INSERT_GENTFUNCR_BASIC0( herk4m1 ) -INSERT_GENTFUNCR_BASIC0( herk1m ) - - -// -- her2k -------------------------------------------------------------------- - -#undef GENTFUNCR -#define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - uplo_t uploc, \ - trans_t transa, \ - trans_t transb, \ - dim_t m, \ - dim_t k, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype_r* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt_r = PASTEMAC(chr,type); \ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t m_a, n_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ - bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt_r, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_HERMITIAN, &co ); \ -\ - PASTEMAC0(opname) \ - ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNCR_BASIC0( her2k3mh ) -INSERT_GENTFUNCR_BASIC0( her2k3m1 ) -INSERT_GENTFUNCR_BASIC0( her2k4mh ) -INSERT_GENTFUNCR_BASIC0( her2k4m1 ) -INSERT_GENTFUNCR_BASIC0( her2k1m ) - - -// -- symm --------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - conj_t conja, \ - trans_t transb, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t mn_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ - bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_conj( conja, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_SYMMETRIC, &ao ); \ -\ - PASTEMAC0(opname) \ - ( \ - side, \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( symm3mh ) -INSERT_GENTFUNC_BASIC0( symm3m1 ) -INSERT_GENTFUNC_BASIC0( symm4mh ) -INSERT_GENTFUNC_BASIC0( symm4m1 ) -INSERT_GENTFUNC_BASIC0( symm1m ) - - -// -- syrk --------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - uplo_t uploc, \ - trans_t transa, \ - dim_t m, \ - dim_t k, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, betao, co; \ -\ - dim_t m_a, n_a; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_SYMMETRIC, &co ); \ -\ - PASTEMAC0(opname) \ - ( \ - &alphao, \ - &ao, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( syrk3mh ) -INSERT_GENTFUNC_BASIC0( syrk3m1 ) -INSERT_GENTFUNC_BASIC0( syrk4mh ) -INSERT_GENTFUNC_BASIC0( syrk4m1 ) -INSERT_GENTFUNC_BASIC0( syrk1m ) - - -// -- syr2k -------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - uplo_t uploc, \ - trans_t transa, \ - trans_t transb, \ - dim_t m, \ - dim_t k, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t m_a, n_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ - bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_SYMMETRIC, &co ); \ -\ - PASTEMAC0(opname) \ - ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( syr2k3mh ) -INSERT_GENTFUNC_BASIC0( syr2k3m1 ) -INSERT_GENTFUNC_BASIC0( syr2k4mh ) -INSERT_GENTFUNC_BASIC0( syr2k4m1 ) -INSERT_GENTFUNC_BASIC0( syr2k1m ) - - -// -- trmm3 -------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - trans_t transa, \ - diag_t diaga, \ - trans_t transb, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t mn_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ - bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_diag( diaga, &ao ); \ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ -\ - PASTEMAC0(opname) \ - ( \ - side, \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( trmm33mh ) -INSERT_GENTFUNC_BASIC0( trmm33m1 ) -INSERT_GENTFUNC_BASIC0( trmm34mh ) -INSERT_GENTFUNC_BASIC0( trmm34m1 ) -INSERT_GENTFUNC_BASIC0( trmm31m ) - - -// -- trmm --------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - trans_t transa, \ - diag_t diaga, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo; \ -\ - dim_t mn_a; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ -\ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, n, b, rs_b, cs_b, &bo ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_diag( diaga, &ao ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ -\ - PASTEMAC0(opname) \ - ( \ - side, \ - &alphao, \ - &ao, \ - &bo, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( trmm3m1 ) -INSERT_GENTFUNC_BASIC0( trmm4m1 ) -INSERT_GENTFUNC_BASIC0( trmm1m ) - - -// -- trsm --------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - trans_t transa, \ - diag_t diaga, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo; \ -\ - dim_t mn_a; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ -\ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, n, b, rs_b, cs_b, &bo ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_diag( diaga, &ao ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ -\ - PASTEMAC0(opname) \ - ( \ - side, \ - &alphao, \ - &ao, \ - &bo, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( trsm3m1 ) -INSERT_GENTFUNC_BASIC0( trsm4m1 ) -INSERT_GENTFUNC_BASIC0( trsm1m ) - diff --git a/frame/thread/bli_l3_decor.h b/frame/thread/bli_l3_decor.h new file mode 100644 index 0000000000..0b09189a69 --- /dev/null +++ b/frame/thread/bli_l3_decor.h @@ -0,0 +1,77 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_L3_DECOR_H +#define BLIS_L3_DECOR_H + +// -- conventional definitions ------------------------------------------------- + +// Level-3 internal function type. +typedef void (*l3int_t) + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ); + +// Level-3 thread decorator prototype. +void bli_l3_thread_decorator + ( + l3int_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl + ); + +// Include definitions specific to the method of multithreading for the +// conventional code path. +#include "bli_l3_decor_single.h" +#include "bli_l3_decor_openmp.h" +#include "bli_l3_decor_pthreads.h" + +#endif + diff --git a/frame/thread/bli_l3_decor_openmp.c b/frame/thread/bli_l3_decor_openmp.c new file mode 100644 index 0000000000..5b40d06143 --- /dev/null +++ b/frame/thread/bli_l3_decor_openmp.c @@ -0,0 +1,249 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_OPENMP + +// Define a dummy function bli_l3_thread_entry(), which is needed in the +// pthreads version, so that when building Windows DLLs (with OpenMP enabled +// or no multithreading) we don't risk having an unresolved symbol. +void* bli_l3_thread_entry( void* data_void ) { return NULL; } + +//#define PRINT_THRINFO + +void bli_l3_thread_decorator + ( + l3int_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl + ) +{ + // This is part of a hack to support mixed domain in bli_gemm_front(). + // Sometimes we need to specify a non-standard schema for A and B, and + // we decided to transmit them via the schema field in the obj_t's + // rather than pass them in as function parameters. Once the values + // have been read, we immediately reset them back to their expected + // values for unpacked objects. + pack_t schema_a = bli_obj_pack_schema( a ); + pack_t schema_b = bli_obj_pack_schema( b ); + bli_obj_set_pack_schema( BLIS_NOT_PACKED, a ); + bli_obj_set_pack_schema( BLIS_NOT_PACKED, b ); + + // Query the total number of threads from the rntm_t object. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + #ifdef PRINT_THRINFO + err_t r_val; + thrinfo_t** threads = bli_malloc_intl( n_threads * sizeof( thrinfo_t* ), &r_val ); + #endif + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_pba_rntm_set_pba( rntm ); + + // Allocate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Query the thread's id from OpenMP. + const dim_t tid = omp_get_thread_num(); + + // Check for a somewhat obscure OpenMP thread-mistmatch issue. + bli_l3_thread_decorator_thread_check( n_threads, tid, gl_comm, rntm_p ); + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + + obj_t a_t, b_t, c_t; + cntl_t* cntl_use; + thrinfo_t* thread; + + // Alias thread-local copies of A, B, and C. These will be the objects + // we pass down the algorithmic function stack. Making thread-local + // aliases is highly recommended in case a thread needs to change any + // of the properties of an object without affecting other threads' + // objects. + bli_obj_alias_to( a, &a_t ); + bli_obj_alias_to( b, &b_t ); + bli_obj_alias_to( c, &c_t ); + + // Create a default control tree for the operation, if needed. + bli_l3_cntl_create_if( family, schema_a, schema_b, + &a_t, &b_t, &c_t, rntm_p, cntl, &cntl_use ); + + // Create the root node of the current thread's thrinfo_t structure. + bli_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread ); + +#if 1 + func + ( + alpha, + &a_t, + &b_t, + beta, + &c_t, + cntx, + rntm_p, + cntl_use, + thread + ); +#else + bli_thrinfo_grow_tree + ( + rntm_p, + cntl_use, + thread + ); +#endif + + // Free the thread's local control tree. + bli_l3_cntl_free( rntm_p, cntl_use, thread ); + + #ifdef PRINT_THRINFO + threads[tid] = thread; + #else + // Free the current thread's thrinfo_t structure. + bli_l3_thrinfo_free( rntm_p, thread ); + #endif + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called above). + + #ifdef PRINT_THRINFO + if ( family != BLIS_TRSM ) bli_l3_thrinfo_print_gemm_paths( threads ); + else bli_l3_thrinfo_print_trsm_paths( threads ); + exit(1); + #endif + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + +// ----------------------------------------------------------------------------- + +void bli_l3_thread_decorator_thread_check + ( + dim_t n_threads, + dim_t tid, + thrcomm_t* gl_comm, + rntm_t* rntm + ) +{ + dim_t n_threads_real = omp_get_num_threads(); + + // Check if the number of OpenMP threads created within this parallel + // region is different from the number of threads that were requested + // of BLIS. This inequality may trigger when, for example, the + // following conditions are satisfied: + // - an application is executing an OpenMP parallel region in which + // BLIS is invoked, + // - BLIS is configured for multithreading via OpenMP, + // - OMP_NUM_THREADS = t > 1, + // - the number of threads requested of BLIS (regardless of method) + // is p <= t, + // - OpenMP nesting is disabled. + // In this situation, the application spawns t threads. Each application + // thread calls gemm (for example). Each gemm will attempt to spawn p + // threads via OpenMP. However, since nesting is disabled, the OpenMP + // implementation finds that t >= p threads are already spawned, and + // thus it doesn't spawn *any* additional threads for each gemm. + if ( n_threads_real != n_threads ) + { + // If the number of threads active in the current region is not + // equal to the number requested of BLIS, we then only continue + // if the number of threads in the current region is 1. If, for + // example, BLIS requested 4 threads but only got 3, then we + // abort(). + //if ( tid == 0 ) + //{ + if ( n_threads_real != 1 ) + { + bli_print_msg( "A different number of threads was " + "created than was requested.", + __FILE__, __LINE__ ); + bli_abort(); + } + + //n_threads = 1; // not needed since it has no effect? + bli_thrcomm_init( 1, gl_comm ); + bli_rntm_set_num_threads_only( 1, rntm ); + bli_rntm_set_ways_only( 1, 1, 1, 1, 1, rntm ); + //} + + // Synchronize all threads and continue. + _Pragma( "omp barrier" ) + } +} + +#endif + diff --git a/frame/thread/bli_l3_decor_openmp.h b/frame/thread/bli_l3_decor_openmp.h new file mode 100644 index 0000000000..80dbe5374e --- /dev/null +++ b/frame/thread/bli_l3_decor_openmp.h @@ -0,0 +1,53 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_L3_DECOR_OPENMP_H +#define BLIS_L3_DECOR_OPENMP_H + +// Definitions specific to situations when OpenMP multithreading is enabled. +#ifdef BLIS_ENABLE_OPENMP + +void bli_l3_thread_decorator_thread_check + ( + dim_t n_threads, + dim_t tid, + thrcomm_t* gl_comm, + rntm_t* rntm + ); + +#endif + +#endif + diff --git a/frame/thread/bli_l3_decor_pthreads.c b/frame/thread/bli_l3_decor_pthreads.c new file mode 100644 index 0000000000..89b6ea1187 --- /dev/null +++ b/frame/thread/bli_l3_decor_pthreads.c @@ -0,0 +1,254 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_PTHREADS + +// A data structure to assist in passing operands to additional threads. +typedef struct thread_data +{ + l3int_t func; + opid_t family; + pack_t schema_a; + pack_t schema_b; + obj_t* alpha; + obj_t* a; + obj_t* b; + obj_t* beta; + obj_t* c; + cntx_t* cntx; + rntm_t* rntm; + cntl_t* cntl; + dim_t tid; + thrcomm_t* gl_comm; + array_t* array; +} thread_data_t; + +// Entry point for additional threads +void* bli_l3_thread_entry( void* data_void ) +{ + thread_data_t* data = data_void; + + l3int_t func = data->func; + opid_t family = data->family; + pack_t schema_a = data->schema_a; + pack_t schema_b = data->schema_b; + obj_t* alpha = data->alpha; + obj_t* a = data->a; + obj_t* b = data->b; + obj_t* beta = data->beta; + obj_t* c = data->c; + cntx_t* cntx = data->cntx; + rntm_t* rntm = data->rntm; + cntl_t* cntl = data->cntl; + dim_t tid = data->tid; + array_t* array = data->array; + thrcomm_t* gl_comm = data->gl_comm; + + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + obj_t a_t, b_t, c_t; + cntl_t* cntl_use; + thrinfo_t* thread; + + // Alias thread-local copies of A, B, and C. These will be the objects + // we pass down the algorithmic function stack. Making thread-local + // aliases is highly recommended in case a thread needs to change any + // of the properties of an object without affecting other threads' + // objects. + bli_obj_alias_to( a, &a_t ); + bli_obj_alias_to( b, &b_t ); + bli_obj_alias_to( c, &c_t ); + + // Create a default control tree for the operation, if needed. + bli_l3_cntl_create_if( family, schema_a, schema_b, + &a_t, &b_t, &c_t, rntm_p, cntl, &cntl_use ); + + // Create the root node of the current thread's thrinfo_t structure. + bli_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread ); + + func + ( + alpha, + &a_t, + &b_t, + beta, + &c_t, + cntx, + rntm_p, + cntl_use, + thread + ); + + // Free the thread's local control tree. + bli_l3_cntl_free( rntm_p, cntl_use, thread ); + + // Free the current thread's thrinfo_t structure. + bli_l3_thrinfo_free( rntm_p, thread ); + + return NULL; +} + +void bli_l3_thread_decorator + ( + l3int_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl + ) +{ + err_t r_val; + + // This is part of a hack to support mixed domain in bli_gemm_front(). + // Sometimes we need to specify a non-standard schema for A and B, and + // we decided to transmit them via the schema field in the obj_t's + // rather than pass them in as function parameters. Once the values + // have been read, we immediately reset them back to their expected + // values for unpacked objects. + pack_t schema_a = bli_obj_pack_schema( a ); + pack_t schema_b = bli_obj_pack_schema( b ); + bli_obj_set_pack_schema( BLIS_NOT_PACKED, a ); + bli_obj_set_pack_schema( BLIS_NOT_PACKED, b ); + + // Query the total number of threads from the context. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_pba_rntm_set_pba( rntm ); + + // Allocate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + // Allocate an array of pthread objects and auxiliary data structs to pass + // to the thread entry functions. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_pthread_t* pthreads = bli_malloc_intl( sizeof( bli_pthread_t ) * n_threads, &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads, &r_val ); + + // NOTE: We must iterate backwards so that the chief thread (thread id 0) + // can spawn all other threads before proceeding with its own computation. + for ( dim_t tid = n_threads - 1; 0 <= tid; tid-- ) + { + // Set up thread data for additional threads (beyond thread 0). + datas[tid].func = func; + datas[tid].family = family; + datas[tid].schema_a = schema_a; + datas[tid].schema_b = schema_b; + datas[tid].alpha = alpha; + datas[tid].a = a; + datas[tid].b = b; + datas[tid].beta = beta; + datas[tid].c = c; + datas[tid].cntx = cntx; + datas[tid].rntm = rntm; + datas[tid].cntl = cntl; + datas[tid].tid = tid; + datas[tid].gl_comm = gl_comm; + datas[tid].array = array; + + // Spawn additional threads for ids greater than 1. + if ( tid != 0 ) + bli_pthread_create( &pthreads[tid], NULL, &bli_l3_thread_entry, &datas[tid] ); + else + bli_l3_thread_entry( ( void* )(&datas[0]) ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Thread 0 waits for additional threads to finish. + for ( dim_t tid = 1; tid < n_threads; tid++ ) + { + bli_pthread_join( pthreads[tid], NULL ); + } + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( pthreads ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( datas ); +} + +#endif + diff --git a/frame/thread/bli_l3_decor_pthreads.h b/frame/thread/bli_l3_decor_pthreads.h new file mode 100644 index 0000000000..772e05ca78 --- /dev/null +++ b/frame/thread/bli_l3_decor_pthreads.h @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_L3_DECOR_PTHREADS_H +#define BLIS_L3_DECOR_PTHREADS_H + +// Definitions specific to situations when POSIX multithreading is enabled. +#ifdef BLIS_ENABLE_PTHREADS + +// Thread entry point prototype. +void* bli_l3_thread_entry( void* data_void ); + +#endif + +#endif + diff --git a/frame/thread/bli_l3_decor_single.c b/frame/thread/bli_l3_decor_single.c new file mode 100644 index 0000000000..51474f0eee --- /dev/null +++ b/frame/thread/bli_l3_decor_single.c @@ -0,0 +1,150 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifndef BLIS_ENABLE_MULTITHREADING + +void bli_l3_thread_decorator + ( + l3int_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl + ) +{ + // This is part of a hack to support mixed domain in bli_gemm_front(). + // Sometimes we need to specify a non-standard schema for A and B, and + // we decided to transmit them via the schema field in the obj_t's + // rather than pass them in as function parameters. Once the values + // have been read, we immediately reset them back to their expected + // values for unpacked objects. + pack_t schema_a = bli_obj_pack_schema( a ); + pack_t schema_b = bli_obj_pack_schema( b ); + bli_obj_set_pack_schema( BLIS_NOT_PACKED, a ); + bli_obj_set_pack_schema( BLIS_NOT_PACKED, b ); + + // For sequential execution, we use only one thread. + const dim_t n_threads = 1; + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we can create the global comm below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. + bli_pba_rntm_set_pba( rntm ); + + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + + { + // NOTE: We don't need to create another copy of the rntm_t since + // it was already copied in one of the high-level oapi functions. + rntm_t* restrict rntm_p = rntm; + + cntl_t* cntl_use; + thrinfo_t* thread; + + const dim_t tid = 0; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + // NOTE: This is commented out because, in the single-threaded case, + // this is redundant since it's already been done above. + //bli_sba_rntm_set_pool( tid, array, rntm_p ); + + // NOTE: Unlike with the _openmp.c and _pthreads.c variants, we don't + // need to alias objects for A, B, and C since they were already aliased + // in bli_*_front(). However, we may add aliasing here in the future so + // that, with all three (_single.c, _openmp.c, _pthreads.c) implementations + // consistently providing local aliases, we can then eliminate aliasing + // elsewhere. + + // Create a default control tree for the operation, if needed. + bli_l3_cntl_create_if( family, schema_a, schema_b, + a, b, c, rntm_p, cntl, &cntl_use ); + + // Create the root node of the thread's thrinfo_t structure. + bli_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread ); + + func + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm_p, + cntl_use, + thread + ); + + // Free the thread's local control tree. + bli_l3_cntl_free( rntm_p, cntl_use, thread ); + + // Free the current thread's thrinfo_t structure. + bli_l3_thrinfo_free( rntm_p, thread ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called above). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + +#endif + diff --git a/frame/thread/bli_l3_decor_single.h b/frame/thread/bli_l3_decor_single.h new file mode 100644 index 0000000000..481763a908 --- /dev/null +++ b/frame/thread/bli_l3_decor_single.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_L3_DECOR_SINGLE_H +#define BLIS_L3_DECOR_SINGLE_H + +// Definitions specific to situations when multithreading is disabled. +#ifndef BLIS_ENABLE_MULTITHREADING + +#endif + +#endif + diff --git a/frame/thread/bli_l3_sup_decor.h b/frame/thread/bli_l3_sup_decor.h new file mode 100644 index 0000000000..a001e5b743 --- /dev/null +++ b/frame/thread/bli_l3_sup_decor.h @@ -0,0 +1,75 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_L3_SUP_DECOR_H +#define BLIS_L3_SUP_DECOR_H + +// -- sup definitions ---------------------------------------------------------- + +// Level-3 sup internal function type. +typedef err_t (*l3supint_t) + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +// Level-3 sup thread decorator prototype. +err_t bli_l3_sup_thread_decorator + ( + l3supint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +// Include definitions specific to the method of multithreading for the +// sup code path. +#include "bli_l3_sup_decor_single.h" +#include "bli_l3_sup_decor_openmp.h" +#include "bli_l3_sup_decor_pthreads.h" + +#endif + diff --git a/frame/thread/bli_l3_sup_decor_openmp.c b/frame/thread/bli_l3_sup_decor_openmp.c new file mode 100644 index 0000000000..1db9514fd4 --- /dev/null +++ b/frame/thread/bli_l3_sup_decor_openmp.c @@ -0,0 +1,141 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_OPENMP + +// Define a dummy function bli_l3_sup_thread_entry(), which is needed in the +// pthreads version, so that when building Windows DLLs (with OpenMP enabled +// or no multithreading) we don't risk having an unresolved symbol. +void* bli_l3_sup_thread_entry( void* data_void ) { return NULL; } + +//#define PRINT_THRINFO + +err_t bli_l3_sup_thread_decorator + ( + l3supint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // Query the total number of threads from the rntm_t object. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_pba_rntm_set_pba( rntm ); + + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Query the thread's id from OpenMP. + const dim_t tid = omp_get_thread_num(); + + // Check for a somewhat obscure OpenMP thread-mistmatch issue. + // NOTE: This calls the same function used for the conventional/large + // code path. + bli_l3_thread_decorator_thread_check( n_threads, tid, gl_comm, rntm_p ); + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + thrinfo_t* thread = NULL; + + // Create the root node of the thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); + + func + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); + + return BLIS_SUCCESS; +} + +#endif + diff --git a/frame/thread/bli_l3_sup_decor_openmp.h b/frame/thread/bli_l3_sup_decor_openmp.h new file mode 100644 index 0000000000..1d1097a822 --- /dev/null +++ b/frame/thread/bli_l3_sup_decor_openmp.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_L3_SUP_DECOR_OPENMP_H +#define BLIS_L3_SUP_DECOR_OPENMP_H + +// Definitions specific to situations when OpenMP multithreading is enabled. +#ifdef BLIS_ENABLE_OPENMP + +#endif + +#endif + diff --git a/frame/thread/bli_l3_sup_decor_pthreads.c b/frame/thread/bli_l3_sup_decor_pthreads.c new file mode 100644 index 0000000000..dade71a035 --- /dev/null +++ b/frame/thread/bli_l3_sup_decor_pthreads.c @@ -0,0 +1,218 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_PTHREADS + +// A data structure to assist in passing operands to additional threads. +typedef struct thread_data +{ + l3supint_t func; + opid_t family; + obj_t* alpha; + obj_t* a; + obj_t* b; + obj_t* beta; + obj_t* c; + cntx_t* cntx; + rntm_t* rntm; + dim_t tid; + thrcomm_t* gl_comm; + array_t* array; +} thread_data_t; + +// Entry point for additional threads +void* bli_l3_sup_thread_entry( void* data_void ) +{ + thread_data_t* data = data_void; + + l3supint_t func = data->func; + opid_t family = data->family; + obj_t* alpha = data->alpha; + obj_t* a = data->a; + obj_t* b = data->b; + obj_t* beta = data->beta; + obj_t* c = data->c; + cntx_t* cntx = data->cntx; + rntm_t* rntm = data->rntm; + dim_t tid = data->tid; + array_t* array = data->array; + thrcomm_t* gl_comm = data->gl_comm; + + ( void )family; + + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + thrinfo_t* thread = NULL; + + // Create the root node of the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); + + func + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); + + return NULL; +} + +err_t bli_l3_sup_thread_decorator + ( + l3supint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t r_val; + + // Query the total number of threads from the context. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_pba_rntm_set_pba( rntm ); + + // Allocate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + // Allocate an array of pthread objects and auxiliary data structs to pass + // to the thread entry functions. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_pthread_t* pthreads = bli_malloc_intl( sizeof( bli_pthread_t ) * n_threads, &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads, &r_val ); + + // NOTE: We must iterate backwards so that the chief thread (thread id 0) + // can spawn all other threads before proceeding with its own computation. + for ( dim_t tid = n_threads - 1; 0 <= tid; tid-- ) + { + // Set up thread data for additional threads (beyond thread 0). + datas[tid].func = func; + datas[tid].family = family; + datas[tid].alpha = alpha; + datas[tid].a = a; + datas[tid].b = b; + datas[tid].beta = beta; + datas[tid].c = c; + datas[tid].cntx = cntx; + datas[tid].rntm = rntm; + datas[tid].tid = tid; + datas[tid].gl_comm = gl_comm; + datas[tid].array = array; + + // Spawn additional threads for ids greater than 1. + if ( tid != 0 ) + bli_pthread_create( &pthreads[tid], NULL, &bli_l3_sup_thread_entry, &datas[tid] ); + else + bli_l3_sup_thread_entry( ( void* )(&datas[0]) ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Thread 0 waits for additional threads to finish. + for ( dim_t tid = 1; tid < n_threads; tid++ ) + { + bli_pthread_join( pthreads[tid], NULL ); + } + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( pthreads ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( datas ); + + return BLIS_SUCCESS; +} + +#endif + diff --git a/frame/thread/bli_l3_sup_decor_pthreads.h b/frame/thread/bli_l3_sup_decor_pthreads.h new file mode 100644 index 0000000000..1362b40355 --- /dev/null +++ b/frame/thread/bli_l3_sup_decor_pthreads.h @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_L3_SUP_DECOR_PTHREADS_H +#define BLIS_L3_SUP_DECOR_PTHREADS_H + +// Definitions specific to situations when POSIX multithreading is enabled. +#ifdef BLIS_ENABLE_PTHREADS + +// Thread entry point prototype. +void* bli_l3_sup_thread_entry( void* data_void ); + +#endif + +#endif + diff --git a/frame/thread/bli_l3_sup_decor_single.c b/frame/thread/bli_l3_sup_decor_single.c new file mode 100644 index 0000000000..a87af41032 --- /dev/null +++ b/frame/thread/bli_l3_sup_decor_single.c @@ -0,0 +1,145 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifndef BLIS_ENABLE_MULTITHREADING + +#define SKIP_THRINFO_TREE + +err_t bli_l3_sup_thread_decorator + ( + l3supint_t func, + opid_t family, + //pack_t schema_a, + //pack_t schema_b, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // For sequential execution, we use only one thread. + const dim_t n_threads = 1; + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. + bli_pba_rntm_set_pba( rntm ); + +#ifndef SKIP_THRINFO_TREE + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); +#endif + + + { + // NOTE: We don't need to create another copy of the rntm_t since + // it was already copied in one of the high-level oapi functions. + rntm_t* restrict rntm_p = rntm; + + // There is only one thread id (for the thief thread). + const dim_t tid = 0; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + // NOTE: This is commented out because, in the single-threaded case, + // this is redundant since it's already been done above. + //bli_sba_rntm_set_pool( tid, array, rntm_p ); + +#ifndef SKIP_THRINFO_TREE + thrinfo_t* thread = NULL; + + // Create the root node of the thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); +#else + // This optimization allows us to use one of the global thrinfo_t + // objects for single-threaded execution rather than grow one from + // scratch. The key is that bli_thrinfo_sup_grow(), which is called + // from within the variants, will immediately return if it detects + // that the thrinfo_t* passed into it is either + // &BLIS_GEMM_SINGLE_THREADED or &BLIS_PACKM_SINGLE_THREADED. + thrinfo_t* thread = &BLIS_GEMM_SINGLE_THREADED; + + ( void )tid; +#endif + + func + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + +#ifndef SKIP_THRINFO_TREE + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); +#endif + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called above). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); + + return BLIS_SUCCESS; + +} + +#endif + diff --git a/frame/thread/bli_l3_sup_decor_single.h b/frame/thread/bli_l3_sup_decor_single.h new file mode 100644 index 0000000000..418c3814c3 --- /dev/null +++ b/frame/thread/bli_l3_sup_decor_single.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_L3_SUP_DECOR_SINGLE_H +#define BLIS_L3_SUP_DECOR_SINGLE_H + +// Definitions specific to situations when multithreading is disabled. +#ifndef BLIS_ENABLE_MULTITHREADING + +#endif + +#endif + diff --git a/frame/thread/bli_pthread.c b/frame/thread/bli_pthread.c index 03b44a5851..a09935661b 100644 --- a/frame/thread/bli_pthread.c +++ b/frame/thread/bli_pthread.c @@ -6,7 +6,7 @@ Copyright (C) 2018, Southern Methodist University Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,89 +36,96 @@ #include "blis.h" -#include - -#if defined(_MSC_VER) +#if defined(BLIS_DISABLE_SYSTEM) // This branch defines a pthread-like API, bli_pthread_*(), and implements it -// in terms of Windows API calls. +// in terms of "dummy" code that doesn't depend on POSIX threads or any other +// threading mechanism. See issue #454 to see the use case that prompted this +// feature. +// NOTE: THIS CODE DOES NOT IMPLEMENT THREADING AND IS NOT THREAD-SAFE! -int bli_pthread_mutex_init +// -- pthread_create(), pthread_join() -- + +int bli_pthread_create ( - bli_pthread_mutex_t* mutex, - const bli_pthread_mutexattr_t* attr + bli_pthread_t* thread, + const bli_pthread_attr_t* attr, + void* (*start_routine)(void*), + void* arg ) { - if ( attr ) return EINVAL; - InitializeSRWLock( mutex ); + //return pthread_create( thread, attr, start_routine, arg ); + start_routine( arg ); return 0; } -int bli_pthread_mutex_destroy +int bli_pthread_join ( - bli_pthread_mutex_t* mutex + bli_pthread_t thread, + void** retval ) { + //return pthread_join( thread, retval ); return 0; } -int bli_pthread_mutex_lock +// -- pthread_mutex_*() -- + +int bli_pthread_mutex_init ( - bli_pthread_mutex_t* mutex + bli_pthread_mutex_t* mutex, + const bli_pthread_mutexattr_t* attr ) { - AcquireSRWLockExclusive( mutex ); + //return pthread_mutex_init( mutex, attr ); return 0; } -int bli_pthread_mutex_trylock +int bli_pthread_mutex_destroy ( bli_pthread_mutex_t* mutex ) { - return TryAcquireSRWLockExclusive( mutex ) ? 0 : EBUSY; + //return pthread_mutex_destroy( mutex ); + return 0; } -int bli_pthread_mutex_unlock +int bli_pthread_mutex_lock ( bli_pthread_mutex_t* mutex ) { - ReleaseSRWLockExclusive( mutex ); + //return pthread_mutex_lock( mutex ); return 0; } -static BOOL bli_init_once_wrapper +int bli_pthread_mutex_trylock ( - bli_pthread_once_t* once, - void* param, - void** context + bli_pthread_mutex_t* mutex ) { - ( void )once; - ( void )context; - typedef void (*callback)( void ); - ((callback)param)(); - return TRUE; + //return pthread_mutex_trylock( mutex ); + return 0; } -void bli_pthread_once +int bli_pthread_mutex_unlock ( - bli_pthread_once_t* once, - void (*init)(void) + bli_pthread_mutex_t* mutex ) { - InitOnceExecuteOnce( once, bli_init_once_wrapper, init, NULL ); + //return pthread_mutex_unlock( mutex ); + return 0; } +// -- pthread_cond_*() -- + int bli_pthread_cond_init ( bli_pthread_cond_t* cond, const bli_pthread_condattr_t* attr ) { - if ( attr ) return EINVAL; - InitializeConditionVariable( cond ); + //return pthread_cond_init( cond, attr ); return 0; } @@ -127,7 +134,7 @@ int bli_pthread_cond_destroy bli_pthread_cond_t* cond ) { - ( void )cond; + //return pthread_cond_destroy( cond ); return 0; } @@ -137,7 +144,7 @@ int bli_pthread_cond_wait bli_pthread_mutex_t* mutex ) { - if ( !SleepConditionVariableSRW( cond, mutex, INFINITE, 0 ) ) return EAGAIN; + //return pthread_cond_wait( cond, mutex ); return 0; } @@ -146,10 +153,64 @@ int bli_pthread_cond_broadcast bli_pthread_cond_t* cond ) { - WakeAllConditionVariable( cond ); + //return pthread_cond_broadcast( cond ); return 0; } +// -- pthread_once() -- + +void bli_pthread_once + ( + bli_pthread_once_t* once, + void (*init)(void) + ) +{ + //pthread_once( once, init ); + init(); +} + +#if 0 +// NOTE: This part of the API is disabled because (1) we don't actually need +// _self() or _equal() yet, and (2) when we do try to include these functions, +// AppVeyor for some reason fails on all the Windows/clang builds with the +// error: +// libblis.a(bli_pthread.o) : error LNK2019: unresolved external symbol +// __imp_CompareObjectHandles referenced in function bli_pthread_equal + +// -- pthread_self() -- + +bli_pthread_t bli_pthread_self + ( + void + ) +{ + return 0; +} + +// -- pthread_equal() -- + +int bli_pthread_equal + ( + bli_pthread_t t1, + bli_pthread_t t2 + ) +{ + // We don't bother comparing t1 and t2 since we must, by definition, be + // executing the same thread if there is not threading mechanism on the + // system. + return 1; +} +#endif + +#elif defined(_MSC_VER) // !defined(BLIS_DISABLE_SYSTEM) + +#include + +// This branch defines a pthread-like API, bli_pthread_*(), and implements it +// in terms of Windows API calls. + +// -- pthread_create(), pthread_join() -- + typedef struct { void* (*start_routine)( void* ); @@ -194,7 +255,158 @@ int bli_pthread_join return 0; } -#else // !defined(_MSC_VER) +// -- pthread_mutex_*() -- + +int bli_pthread_mutex_init + ( + bli_pthread_mutex_t* mutex, + const bli_pthread_mutexattr_t* attr + ) +{ + if ( attr ) return EINVAL; + InitializeSRWLock( mutex ); + return 0; +} + +int bli_pthread_mutex_destroy + ( + bli_pthread_mutex_t* mutex + ) +{ + return 0; +} + +int bli_pthread_mutex_lock + ( + bli_pthread_mutex_t* mutex + ) +{ + AcquireSRWLockExclusive( mutex ); + return 0; +} + +int bli_pthread_mutex_trylock + ( + bli_pthread_mutex_t* mutex + ) +{ + return TryAcquireSRWLockExclusive( mutex ) ? 0 : EBUSY; +} + +int bli_pthread_mutex_unlock + ( + bli_pthread_mutex_t* mutex + ) +{ + ReleaseSRWLockExclusive( mutex ); + return 0; +} + +// -- pthread_cond_*() -- + +int bli_pthread_cond_init + ( + bli_pthread_cond_t* cond, + const bli_pthread_condattr_t* attr + ) +{ + if ( attr ) return EINVAL; + InitializeConditionVariable( cond ); + return 0; +} + +int bli_pthread_cond_destroy + ( + bli_pthread_cond_t* cond + ) +{ + ( void )cond; + return 0; +} + +int bli_pthread_cond_wait + ( + bli_pthread_cond_t* cond, + bli_pthread_mutex_t* mutex + ) +{ + if ( !SleepConditionVariableSRW( cond, mutex, INFINITE, 0 ) ) return EAGAIN; + return 0; +} + +int bli_pthread_cond_broadcast + ( + bli_pthread_cond_t* cond + ) +{ + WakeAllConditionVariable( cond ); + return 0; +} + +// -- pthread_once() -- + +static BOOL bli_init_once_wrapper + ( + bli_pthread_once_t* once, + void* param, + void** context + ) +{ + ( void )once; + ( void )context; + typedef void (*callback)( void ); + ((callback)param)(); + return TRUE; +} + +void bli_pthread_once + ( + bli_pthread_once_t* once, + void (*init)(void) + ) +{ + InitOnceExecuteOnce( once, bli_init_once_wrapper, init, NULL ); +} + +#if 0 +// NOTE: This part of the API is disabled because (1) we don't actually need +// _self() or _equal() yet, and (2) when we do try to include these functions, +// AppVeyor for some reason fails on all the Windows/clang builds with the +// error: +// libblis.a(bli_pthread.o) : error LNK2019: unresolved external symbol +// __imp_CompareObjectHandles referenced in function bli_pthread_equal + +// -- pthread_self() -- + +bli_pthread_t bli_pthread_self + ( + void + ) +{ + bli_pthread_t t; + + // Note: BLIS will only ever use bli_pthread_self() in conjunction with + // bli_pthread_equal(), and thus setting the .retval field is unnecessary. + // Despite this, we set it to NULL anyway. + t.handle = GetCurrentThread(); + t.retval = NULL; + + return t; +} + +// -- pthread_equal() -- + +int bli_pthread_equal + ( + bli_pthread_t t1, + bli_pthread_t t2 + ) +{ + return ( int )CompareObjectHandles( t1.handle, t2.handle ); +} +#endif + +#else // !defined(BLIS_DISABLE_SYSTEM) && !defined(_MSC_VER) // This branch defines a pthreads-like API, bli_pthreads_*(), and implements it // in terms of the corresponding pthreads_*() types, macros, and function calls. @@ -314,12 +526,77 @@ void bli_pthread_once pthread_once( once, init ); } -#endif // _MSC_VER +#if 0 +// NOTE: This part of the API is disabled because (1) we don't actually need +// _self() or _equal() yet, and (2) when we do try to include these functions, +// AppVeyor for some reason fails on all the Windows/clang builds with the +// error: +// libblis.a(bli_pthread.o) : error LNK2019: unresolved external symbol +// __imp_CompareObjectHandles referenced in function bli_pthread_equal + +// -- pthread_self() -- + +bli_pthread_t bli_pthread_self + ( + void + ) +{ + return pthread_self(); +} + +// -- pthread_equal() -- + +int bli_pthread_equal + ( + bli_pthread_t t1, + bli_pthread_t t2 + ) +{ + return pthread_equal( t1, t2 ); +} +#endif + +#endif // !defined(BLIS_DISABLE_SYSTEM) && !defined(_MSC_VER) + + // -- pthread_barrier_*() -- -#if defined(__APPLE__) || defined(_MSC_VER) +#if defined(BLIS_DISABLE_SYSTEM) + +int bli_pthread_barrier_init + ( + bli_pthread_barrier_t* barrier, + const bli_pthread_barrierattr_t* attr, + unsigned int count + ) +{ + //return pthread_barrier_init( barrier, attr, count ); + return 0; +} + +int bli_pthread_barrier_destroy + ( + bli_pthread_barrier_t* barrier + ) +{ + //return pthread_barrier_destroy( barrier ); + return 0; +} + +int bli_pthread_barrier_wait + ( + bli_pthread_barrier_t* barrier + ) +{ + //return pthread_barrier_wait( barrier ); + return 0; +} + +#elif defined(__APPLE__) || defined(_MSC_VER) // !defined(BLIS_DISABLE_SYSTEM) + +#include // For OS X and Windows, we define barriers ourselves in terms of the rest // of the API, though for slightly different reasons: For Windows, we must @@ -382,7 +659,7 @@ int bli_pthread_barrier_wait } } -#else // !( defined(__APPLE__) || defined(_MSC_VER) ) +#else // !defined(BLIS_DISABLE_SYSTEM) && !defined(__APPLE__) && !defined(_MSC_VER) // Linux environments implement the pthread_barrier* sub-API. So, if we're // on Linux, we can simply call those functions, just as we did before for @@ -414,4 +691,5 @@ int bli_pthread_barrier_wait return pthread_barrier_wait( barrier ); } -#endif // defined(__APPLE__) || defined(_MSC_VER) +#endif + diff --git a/frame/thread/bli_pthread.h b/frame/thread/bli_pthread.h index 56ede89b50..be786aa39c 100644 --- a/frame/thread/bli_pthread.h +++ b/frame/thread/bli_pthread.h @@ -36,113 +36,53 @@ #ifndef BLIS_PTHREAD_H #define BLIS_PTHREAD_H -#if defined(_MSC_VER) +// -- Type and macro definitions ----------------------------------------------- -// This branch defines a pthread-like API, bli_pthread_*(), and implements it -// in terms of Windows API calls. - -// -- pthread_mutex_*() -- - -typedef SRWLOCK bli_pthread_mutex_t; -typedef void bli_pthread_mutexattr_t; - -#define BLIS_PTHREAD_MUTEX_INITIALIZER SRWLOCK_INIT - -BLIS_EXPORT_BLIS int bli_pthread_mutex_init - ( - bli_pthread_mutex_t* mutex, - const bli_pthread_mutexattr_t* attr - ); - -BLIS_EXPORT_BLIS int bli_pthread_mutex_destroy - ( - bli_pthread_mutex_t* mutex - ); +#if defined(BLIS_DISABLE_SYSTEM) -BLIS_EXPORT_BLIS int bli_pthread_mutex_lock - ( - bli_pthread_mutex_t* mutex - ); - -BLIS_EXPORT_BLIS int bli_pthread_mutex_trylock - ( - bli_pthread_mutex_t* mutex - ); - -BLIS_EXPORT_BLIS int bli_pthread_mutex_unlock - ( - bli_pthread_mutex_t* mutex - ); - -// -- pthread_once_*() -- - -typedef INIT_ONCE bli_pthread_once_t; - -#define BLIS_PTHREAD_ONCE_INIT INIT_ONCE_STATIC_INIT - -BLIS_EXPORT_BLIS void bli_pthread_once - ( - bli_pthread_once_t* once, - void (*init)(void) - ); +// This branch defines a pthread-like API, bli_pthread_*(), and implements it +// in terms of "dummy" code that doesn't depend on POSIX threads or any other +// threading mechanism. See issue #454 to see the use case that prompted this +// feature. +// NOTE: THIS CODE DOES NOT IMPLEMENT THREADING AND IS NOT THREAD-SAFE! -// -- pthread_cond_*() -- +// -- pthread types -- -typedef CONDITION_VARIABLE bli_pthread_cond_t; -typedef void bli_pthread_condattr_t; +typedef int bli_pthread_t; +typedef int bli_pthread_attr_t; +typedef int bli_pthread_mutex_t; +typedef int bli_pthread_mutexattr_t; +typedef int bli_pthread_cond_t; +typedef int bli_pthread_condattr_t; +typedef int bli_pthread_once_t; -#define BLIS_PTHREAD_COND_INITIALIZER CONDITION_VARIABLE_INIT +typedef int bli_pthread_barrier_t; +typedef int bli_pthread_barrierattr_t; -BLIS_EXPORT_BLIS int bli_pthread_cond_init - ( - bli_pthread_cond_t* cond, - const bli_pthread_condattr_t* attr - ); +// -- pthreads macros -- -BLIS_EXPORT_BLIS int bli_pthread_cond_destroy - ( - bli_pthread_cond_t* cond - ); +#define BLIS_PTHREAD_MUTEX_INITIALIZER 0 +#define BLIS_PTHREAD_COND_INITIALIZER 0 +#define BLIS_PTHREAD_ONCE_INIT 0 -BLIS_EXPORT_BLIS int bli_pthread_cond_wait - ( - bli_pthread_cond_t* cond, - bli_pthread_mutex_t* mutex - ); +#elif defined(_MSC_VER) // !defined(BLIS_DISABLE_SYSTEM) -BLIS_EXPORT_BLIS int bli_pthread_cond_broadcast - ( - bli_pthread_cond_t* cond - ); +// This branch defines a pthread-like API, bli_pthread_*(), and implements it +// in terms of Windows API calls. -// -- pthread_create(), pthread_join() -- +// -- pthread types -- typedef struct { HANDLE handle; void* retval; } bli_pthread_t; - typedef void bli_pthread_attr_t; - -BLIS_EXPORT_BLIS int bli_pthread_create - ( - bli_pthread_t* thread, - const bli_pthread_attr_t* attr, - void* (*start_routine)(void*), - void* arg - ); - -BLIS_EXPORT_BLIS int bli_pthread_join - ( - bli_pthread_t thread, - void** retval - ); - -// -- pthread_barrier_*() -- - -typedef void bli_pthread_barrierattr_t; - +typedef SRWLOCK bli_pthread_mutex_t; +typedef void bli_pthread_mutexattr_t; +typedef CONDITION_VARIABLE bli_pthread_cond_t; +typedef void bli_pthread_condattr_t; +typedef INIT_ONCE bli_pthread_once_t; typedef struct { bli_pthread_mutex_t mutex; @@ -150,25 +90,15 @@ typedef struct int count; int tripCount; } bli_pthread_barrier_t; +typedef void bli_pthread_barrierattr_t; -BLIS_EXPORT_BLIS int bli_pthread_barrier_init - ( - bli_pthread_barrier_t* barrier, - const bli_pthread_barrierattr_t* attr, - unsigned int count - ); - -BLIS_EXPORT_BLIS int bli_pthread_barrier_destroy - ( - bli_pthread_barrier_t* barrier - ); +// -- pthreads macros -- -BLIS_EXPORT_BLIS int bli_pthread_barrier_wait - ( - bli_pthread_barrier_t* barrier - ); +#define BLIS_PTHREAD_MUTEX_INITIALIZER SRWLOCK_INIT +#define BLIS_PTHREAD_ONCE_INIT INIT_ONCE_STATIC_INIT +#define BLIS_PTHREAD_COND_INITIALIZER CONDITION_VARIABLE_INIT -#else // !defined(_MSC_VER) +#else // !defined(BLIS_DISABLE_SYSTEM) && !defined(_MSC_VER) #include @@ -177,13 +107,13 @@ BLIS_EXPORT_BLIS int bli_pthread_barrier_wait // -- pthread types -- -typedef pthread_t bli_pthread_t; -typedef pthread_attr_t bli_pthread_attr_t; -typedef pthread_mutex_t bli_pthread_mutex_t; -typedef pthread_mutexattr_t bli_pthread_mutexattr_t; -typedef pthread_cond_t bli_pthread_cond_t; -typedef pthread_condattr_t bli_pthread_condattr_t; -typedef pthread_once_t bli_pthread_once_t; +typedef pthread_t bli_pthread_t; +typedef pthread_attr_t bli_pthread_attr_t; +typedef pthread_mutex_t bli_pthread_mutex_t; +typedef pthread_mutexattr_t bli_pthread_mutexattr_t; +typedef pthread_cond_t bli_pthread_cond_t; +typedef pthread_condattr_t bli_pthread_condattr_t; +typedef pthread_once_t bli_pthread_once_t; #if defined(__APPLE__) @@ -194,10 +124,10 @@ typedef void bli_pthread_barrierattr_t; typedef struct { - bli_pthread_mutex_t mutex; - bli_pthread_cond_t cond; - int count; - int tripCount; + bli_pthread_mutex_t mutex; + bli_pthread_cond_t cond; + int count; + int tripCount; } bli_pthread_barrier_t; #else @@ -217,6 +147,10 @@ typedef pthread_barrierattr_t bli_pthread_barrierattr_t; #define BLIS_PTHREAD_COND_INITIALIZER PTHREAD_COND_INITIALIZER #define BLIS_PTHREAD_ONCE_INIT PTHREAD_ONCE_INIT +#endif + +// -- Function definitions ----------------------------------------------------- + // -- pthread_create(), pthread_join() -- BLIS_EXPORT_BLIS int bli_pthread_create @@ -285,7 +219,7 @@ BLIS_EXPORT_BLIS int bli_pthread_cond_broadcast bli_pthread_cond_t* cond ); -// -- pthread_once_*() -- +// -- pthread_once() -- BLIS_EXPORT_BLIS void bli_pthread_once ( @@ -293,6 +227,30 @@ BLIS_EXPORT_BLIS void bli_pthread_once void (*init)(void) ); +#if 0 +// NOTE: This part of the API is disabled because (1) we don't actually need +// _self() or _equal() yet, and (2) when we do try to include these functions, +// AppVeyor for some reason fails on all the Windows/clang builds with the +// error: +// libblis.a(bli_pthread.o) : error LNK2019: unresolved external symbol +// __imp_CompareObjectHandles referenced in function bli_pthread_equal + +// -- pthread_self() -- + +BLIS_EXPORT_BLIS bli_pthread_t bli_pthread_self + ( + void + ); + +// -- pthread_equal() -- + +BLIS_EXPORT_BLIS int bli_pthread_equal + ( + bli_pthread_t t1, + bli_pthread_t t2 + ); +#endif + // -- pthread_barrier_*() -- BLIS_EXPORT_BLIS int bli_pthread_barrier_init @@ -312,6 +270,4 @@ BLIS_EXPORT_BLIS int bli_pthread_barrier_wait bli_pthread_barrier_t* barrier ); -#endif // _MSC_VER - #endif // BLIS_PTHREAD_H diff --git a/frame/thread/bli_thrcomm.c b/frame/thread/bli_thrcomm.c index c9698050c2..ef46a7ad43 100644 --- a/frame/thread/bli_thrcomm.c +++ b/frame/thread/bli_thrcomm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -86,7 +86,7 @@ void bli_thrcomm_barrier_atomic( dim_t t_id, thrcomm_t* comm ) // fact, if everything else is working, a binary variable is sufficient, // which is what we do here (i.e., 0 is incremented to 1, which is then // decremented back to 0, and so forth). - bool_t orig_sense = __atomic_load_n( &comm->barrier_sense, __ATOMIC_RELAXED ); + gint_t orig_sense = __atomic_load_n( &comm->barrier_sense, __ATOMIC_RELAXED ); // Register ourselves (the current thread) as having arrived by // incrementing the barrier_threads_arrived variable. We must perform diff --git a/frame/thread/bli_thrcomm.h b/frame/thread/bli_thrcomm.h index 04bceae2a5..d0ffb13461 100644 --- a/frame/thread/bli_thrcomm.h +++ b/frame/thread/bli_thrcomm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -43,13 +43,9 @@ #include "bli_thrcomm_pthreads.h" -// thread entry point prototype. -void* bli_l3_thread_entry( void* data_void ); - - // thrcomm_t query (field only) -static dim_t bli_thrcomm_num_threads( thrcomm_t* comm ) +BLIS_INLINE dim_t bli_thrcomm_num_threads( thrcomm_t* comm ) { return comm->n_threads; } @@ -60,8 +56,9 @@ thrcomm_t* bli_thrcomm_create( rntm_t* rntm, dim_t n_threads ); void bli_thrcomm_free( rntm_t* rntm, thrcomm_t* comm ); void bli_thrcomm_init( dim_t n_threads, thrcomm_t* comm ); void bli_thrcomm_cleanup( thrcomm_t* comm ); -void bli_thrcomm_barrier( dim_t thread_id, thrcomm_t* comm ); -void* bli_thrcomm_bcast( dim_t inside_id, void* to_send, thrcomm_t* comm ); + +BLIS_EXPORT_BLIS void bli_thrcomm_barrier( dim_t thread_id, thrcomm_t* comm ); +BLIS_EXPORT_BLIS void* bli_thrcomm_bcast( dim_t inside_id, void* to_send, thrcomm_t* comm ); void bli_thrcomm_barrier_atomic( dim_t thread_id, thrcomm_t* comm ); diff --git a/frame/thread/bli_thrcomm_openmp.c b/frame/thread/bli_thrcomm_openmp.c index 05cfa610a0..9bb35ea31a 100644 --- a/frame/thread/bli_thrcomm_openmp.c +++ b/frame/thread/bli_thrcomm_openmp.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -87,7 +87,7 @@ void bli_thrcomm_barrier( dim_t t_id, thrcomm_t* comm ) #if 0 if ( comm == NULL || comm->n_threads == 1 ) return; - bool_t my_sense = comm->barrier_sense; + gint_t my_sense = comm->barrier_sense; dim_t my_threads_arrived; _Pragma( "omp atomic capture" ) @@ -100,7 +100,7 @@ void bli_thrcomm_barrier( dim_t t_id, thrcomm_t* comm ) } else { - volatile bool_t* listener = &comm->barrier_sense; + volatile gint_t* listener = &comm->barrier_sense; while ( *listener == my_sense ) {} } #endif @@ -111,17 +111,21 @@ void bli_thrcomm_barrier( dim_t t_id, thrcomm_t* comm ) void bli_thrcomm_init( dim_t n_threads, thrcomm_t* comm ) { + err_t r_val; + if ( comm == NULL ) return; comm->sent_object = NULL; comm->n_threads = n_threads; - comm->barriers = bli_malloc_intl( sizeof( barrier_t* ) * n_threads ); + comm->barriers = bli_malloc_intl( sizeof( barrier_t* ) * n_threads, &r_val ); bli_thrcomm_tree_barrier_create( n_threads, BLIS_TREE_BARRIER_ARITY, comm->barriers, 0 ); } //Tree barrier used for Intel Xeon Phi barrier_t* bli_thrcomm_tree_barrier_create( int num_threads, int arity, barrier_t** leaves, int leaf_index ) { - barrier_t* me = bli_malloc_intl( sizeof(barrier_t) ); + err_t r_val; + + barrier_t* me = bli_malloc_intl( sizeof( barrier_t ), &r_val ); me->dad = NULL; me->signal = 0; @@ -214,212 +218,5 @@ void bli_thrcomm_tree_barrier( barrier_t* barack ) #endif - -// Define a dummy function bli_l3_thread_entry(), which is needed in the -// pthreads version, so that when building Windows DLLs (with OpenMP enabled -// or no multithreading) we don't risk having an unresolved symbol. -void* bli_l3_thread_entry( void* data_void ) { return NULL; } - -//#define PRINT_THRINFO - -void bli_l3_thread_decorator - ( - l3int_t func, - opid_t family, - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl - ) -{ - // This is part of a hack to support mixed domain in bli_gemm_front(). - // Sometimes we need to specify a non-standard schema for A and B, and - // we decided to transmit them via the schema field in the obj_t's - // rather than pass them in as function parameters. Once the values - // have been read, we immediately reset them back to their expected - // values for unpacked objects. - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - bli_obj_set_pack_schema( BLIS_NOT_PACKED, a ); - bli_obj_set_pack_schema( BLIS_NOT_PACKED, b ); - - // Query the total number of threads from the rntm_t object. - const dim_t n_threads = bli_rntm_num_threads( rntm ); - - #ifdef PRINT_THRINFO - thrinfo_t** threads = bli_malloc_intl( n_threads * sizeof( thrinfo_t* ) ); - #endif - - // NOTE: The sba was initialized in bli_init(). - - // Check out an array_t from the small block allocator. This is done - // with an internal lock to ensure only one application thread accesses - // the sba at a time. bli_sba_checkout_array() will also automatically - // resize the array_t, if necessary. - array_t* restrict array = bli_sba_checkout_array( n_threads ); - - // Access the pool_t* for thread 0 and embed it into the rntm. We do - // this up-front only so that we have the rntm_t.sba_pool field - // initialized and ready for the global communicator creation below. - bli_sba_rntm_set_pool( 0, array, rntm ); - - // Set the packing block allocator field of the rntm. This will be - // inherited by all of the child threads when they make local copies of - // the rntm below. - bli_membrk_rntm_set_membrk( rntm ); - - // Allocate a global communicator for the root thrinfo_t structures. - thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); - - - _Pragma( "omp parallel num_threads(n_threads)" ) - { - // Create a thread-local copy of the master thread's rntm_t. This is - // necessary since we want each thread to be able to track its own - // small block pool_t as it executes down the function stack. - rntm_t rntm_l = *rntm; - rntm_t* restrict rntm_p = &rntm_l; - - // Query the thread's id from OpenMP. - const dim_t tid = omp_get_thread_num(); - - // Check for a somewhat obscure OpenMP thread-mistmatch issue. - bli_l3_thread_decorator_thread_check( n_threads, tid, gl_comm, rntm_p ); - - // Use the thread id to access the appropriate pool_t* within the - // array_t, and use it to set the sba_pool field within the rntm_t. - // If the pool_t* element within the array_t is NULL, it will first - // be allocated/initialized. - bli_sba_rntm_set_pool( tid, array, rntm_p ); - - - obj_t a_t, b_t, c_t; - cntl_t* cntl_use; - thrinfo_t* thread; - - // Alias thread-local copies of A, B, and C. These will be the objects - // we pass down the algorithmic function stack. Making thread-local - // alaises is highly recommended in case a thread needs to change any - // of the properties of an object without affecting other threads' - // objects. - bli_obj_alias_to( a, &a_t ); - bli_obj_alias_to( b, &b_t ); - bli_obj_alias_to( c, &c_t ); - - // Create a default control tree for the operation, if needed. - bli_l3_cntl_create_if( family, schema_a, schema_b, - &a_t, &b_t, &c_t, rntm_p, cntl, &cntl_use ); - - // Create the root node of the current thread's thrinfo_t structure. - bli_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread ); - -#if 1 - func - ( - alpha, - &a_t, - &b_t, - beta, - &c_t, - cntx, - rntm_p, - cntl_use, - thread - ); -#else - bli_thrinfo_grow_tree - ( - rntm_p, - cntl_use, - thread - ); #endif - // Free the thread's local control tree. - bli_l3_cntl_free( rntm_p, cntl_use, thread ); - - #ifdef PRINT_THRINFO - threads[tid] = thread; - #else - // Free the current thread's thrinfo_t structure. - bli_l3_thrinfo_free( rntm_p, thread ); - #endif - } - - // We shouldn't free the global communicator since it was already freed - // by the global communicator's chief thread in bli_l3_thrinfo_free() - // (called above). - - #ifdef PRINT_THRINFO - if ( family != BLIS_TRSM ) bli_l3_thrinfo_print_gemm_paths( threads ); - else bli_l3_thrinfo_print_trsm_paths( threads ); - exit(1); - #endif - - // Check the array_t back into the small block allocator. Similar to the - // check-out, this is done using a lock embedded within the sba to ensure - // mutual exclusion. - bli_sba_checkin_array( array ); -} - -// ----------------------------------------------------------------------------- - -void bli_l3_thread_decorator_thread_check - ( - dim_t n_threads, - dim_t tid, - thrcomm_t* gl_comm, - rntm_t* rntm - ) -{ - dim_t n_threads_real = omp_get_num_threads(); - - // Check if the number of OpenMP threads created within this parallel - // region is different from the number of threads that were requested - // of BLIS. This inequality may trigger when, for example, the - // following conditions are satisfied: - // - an application is executing an OpenMP parallel region in which - // BLIS is invoked, - // - BLIS is configured for multithreading via OpenMP, - // - OMP_NUM_THREADS = t > 1, - // - the number of threads requested of BLIS (regardless of method) - // is p <= t, - // - OpenMP nesting is disabled. - // In this situation, the application spawns t threads. Each application - // thread calls gemm (for example). Each gemm will attempt to spawn p - // threads via OpenMP. However, since nesting is disabled, the OpenMP - // implementation finds that t >= p threads are already spawned, and - // thus it doesn't spawn *any* additional threads for each gemm. - if ( n_threads_real != n_threads ) - { - // If the number of threads active in the current region is not - // equal to the number requested of BLIS, we then only continue - // if the number of threads in the current region is 1. If, for - // example, BLIS requested 4 threads but only got 3, then we - // abort(). - //if ( tid == 0 ) - //{ - if ( n_threads_real != 1 ) - { - bli_print_msg( "A different number of threads was " - "created than was requested.", - __FILE__, __LINE__ ); - bli_abort(); - } - - //n_threads = 1; // not needed since it has no effect? - bli_thrcomm_init( 1, gl_comm ); - bli_rntm_set_num_threads_only( 1, rntm ); - bli_rntm_set_ways_only( 1, 1, 1, 1, 1, rntm ); - //} - - // Synchronize all threads and continue. - _Pragma( "omp barrier" ) - } -} - -#endif diff --git a/frame/thread/bli_thrcomm_openmp.h b/frame/thread/bli_thrcomm_openmp.h index 4b8956a14c..3abfd0a413 100644 --- a/frame/thread/bli_thrcomm_openmp.h +++ b/frame/thread/bli_thrcomm_openmp.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -64,8 +64,14 @@ struct thrcomm_s void* sent_object; dim_t n_threads; - //volatile bool_t barrier_sense; - bool_t barrier_sense; + // NOTE: barrier_sense was originally a gint_t-based bool_t, but upon + // redefining bool_t as bool we discovered that some gcc __atomic built-ins + // don't allow the use of bool for the variables being operated upon. + // (Specifically, this was observed of __atomic_fetch_xor(), but it likely + // applies to all other related built-ins.) Thus, we get around this by + // redefining barrier_sense as a gint_t. + //volatile gint_t barrier_sense; + gint_t barrier_sense; dim_t barrier_threads_arrived; }; #endif @@ -79,14 +85,6 @@ void bli_thrcomm_tree_barrier_free( barrier_t* barrier ); void bli_thrcomm_tree_barrier( barrier_t* barack ); #endif -void bli_l3_thread_decorator_thread_check - ( - dim_t n_threads, - dim_t tid, - thrcomm_t* gl_comm, - rntm_t* rntm - ); - #endif #endif diff --git a/frame/thread/bli_thrcomm_pthreads.c b/frame/thread/bli_thrcomm_pthreads.c index 975c5eb886..d0896f94df 100644 --- a/frame/thread/bli_thrcomm_pthreads.c +++ b/frame/thread/bli_thrcomm_pthreads.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -93,34 +93,20 @@ void bli_thrcomm_init( dim_t n_threads, thrcomm_t* comm ) comm->n_threads = n_threads; comm->barrier_sense = 0; comm->barrier_threads_arrived = 0; - -//#ifdef BLIS_USE_PTHREAD_MUTEX -// bli_pthread_mutex_init( &comm->mutex, NULL ); -//#endif } void bli_thrcomm_cleanup( thrcomm_t* comm ) { -//#ifdef BLIS_USE_PTHREAD_MUTEX -// if ( comm == NULL ) return; -// bli_pthread_mutex_destroy( &comm->mutex ); -//#endif } void bli_thrcomm_barrier( dim_t t_id, thrcomm_t* comm ) { #if 0 if ( comm == NULL || comm->n_threads == 1 ) return; - bool_t my_sense = comm->sense; + bool my_sense = comm->sense; dim_t my_threads_arrived; -#ifdef BLIS_USE_PTHREAD_MUTEX - bli_pthread_mutex_lock( &comm->mutex ); - my_threads_arrived = ++(comm->threads_arrived); - bli_pthread_mutex_unlock( &comm->mutex ); -#else my_threads_arrived = __sync_add_and_fetch(&(comm->threads_arrived), 1); -#endif if ( my_threads_arrived == comm->n_threads ) { @@ -129,7 +115,7 @@ void bli_thrcomm_barrier( dim_t t_id, thrcomm_t* comm ) } else { - volatile bool_t* listener = &comm->sense; + volatile bool* listener = &comm->sense; while( *listener == my_sense ) {} } #endif @@ -138,217 +124,5 @@ void bli_thrcomm_barrier( dim_t t_id, thrcomm_t* comm ) #endif - -// A data structure to assist in passing operands to additional threads. -typedef struct thread_data -{ - l3int_t func; - opid_t family; - pack_t schema_a; - pack_t schema_b; - obj_t* alpha; - obj_t* a; - obj_t* b; - obj_t* beta; - obj_t* c; - cntx_t* cntx; - rntm_t* rntm; - cntl_t* cntl; - dim_t tid; - thrcomm_t* gl_comm; - array_t* array; -} thread_data_t; - -// Entry point for additional threads -void* bli_l3_thread_entry( void* data_void ) -{ - thread_data_t* data = data_void; - - l3int_t func = data->func; - opid_t family = data->family; - pack_t schema_a = data->schema_a; - pack_t schema_b = data->schema_b; - obj_t* alpha = data->alpha; - obj_t* a = data->a; - obj_t* b = data->b; - obj_t* beta = data->beta; - obj_t* c = data->c; - cntx_t* cntx = data->cntx; - rntm_t* rntm = data->rntm; - cntl_t* cntl = data->cntl; - dim_t tid = data->tid; - array_t* array = data->array; - thrcomm_t* gl_comm = data->gl_comm; - - // Create a thread-local copy of the master thread's rntm_t. This is - // necessary since we want each thread to be able to track its own - // small block pool_t as it executes down the function stack. - rntm_t rntm_l = *rntm; - rntm_t* restrict rntm_p = &rntm_l; - - // Use the thread id to access the appropriate pool_t* within the - // array_t, and use it to set the sba_pool field within the rntm_t. - // If the pool_t* element within the array_t is NULL, it will first - // be allocated/initialized. - bli_sba_rntm_set_pool( tid, array, rntm_p ); - - obj_t a_t, b_t, c_t; - cntl_t* cntl_use; - thrinfo_t* thread; - - // Alias thread-local copies of A, B, and C. These will be the objects - // we pass down the algorithmic function stack. Making thread-local - // alaises is highly recommended in case a thread needs to change any - // of the properties of an object without affecting other threads' - // objects. - bli_obj_alias_to( a, &a_t ); - bli_obj_alias_to( b, &b_t ); - bli_obj_alias_to( c, &c_t ); - - // Create a default control tree for the operation, if needed. - bli_l3_cntl_create_if( family, schema_a, schema_b, - &a_t, &b_t, &c_t, rntm_p, cntl, &cntl_use ); - - // Create the root node of the current thread's thrinfo_t structure. - bli_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread ); - - func - ( - alpha, - &a_t, - &b_t, - beta, - &c_t, - cntx, - rntm_p, - cntl_use, - thread - ); - - // Free the thread's local control tree. - bli_l3_cntl_free( rntm_p, cntl_use, thread ); - - // Free the current thread's thrinfo_t structure. - bli_l3_thrinfo_free( rntm_p, thread ); - - return NULL; -} - -void bli_l3_thread_decorator - ( - l3int_t func, - opid_t family, - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl - ) -{ - // This is part of a hack to support mixed domain in bli_gemm_front(). - // Sometimes we need to specify a non-standard schema for A and B, and - // we decided to transmit them via the schema field in the obj_t's - // rather than pass them in as function parameters. Once the values - // have been read, we immediately reset them back to their expected - // values for unpacked objects. - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - bli_obj_set_pack_schema( BLIS_NOT_PACKED, a ); - bli_obj_set_pack_schema( BLIS_NOT_PACKED, b ); - - // Query the total number of threads from the context. - const dim_t n_threads = bli_rntm_num_threads( rntm ); - - // NOTE: The sba was initialized in bli_init(). - - // Check out an array_t from the small block allocator. This is done - // with an internal lock to ensure only one application thread accesses - // the sba at a time. bli_sba_checkout_array() will also automatically - // resize the array_t, if necessary. - array_t* restrict array = bli_sba_checkout_array( n_threads ); - - // Access the pool_t* for thread 0 and embed it into the rntm. We do - // this up-front only so that we have the rntm_t.sba_pool field - // initialized and ready for the global communicator creation below. - bli_sba_rntm_set_pool( 0, array, rntm ); - - // Set the packing block allocator field of the rntm. This will be - // inherited by all of the child threads when they make local copies of - // the rntm below. - bli_membrk_rntm_set_membrk( rntm ); - - // Allocate a global communicator for the root thrinfo_t structures. - thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); - - // Allocate an array of pthread objects and auxiliary data structs to pass - // to the thread entry functions. - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_l3_thread_decorator().pth: " ); - #endif - bli_pthread_t* pthreads = bli_malloc_intl( sizeof( bli_pthread_t ) * n_threads ); - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_l3_thread_decorator().pth: " ); - #endif - thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads ); - - // NOTE: We must iterate backwards so that the chief thread (thread id 0) - // can spawn all other threads before proceeding with its own computation. - for ( dim_t tid = n_threads - 1; 0 <= tid; tid-- ) - { - // Set up thread data for additional threads (beyond thread 0). - datas[tid].func = func; - datas[tid].family = family; - datas[tid].schema_a = schema_a; - datas[tid].schema_b = schema_b; - datas[tid].alpha = alpha; - datas[tid].a = a; - datas[tid].b = b; - datas[tid].beta = beta; - datas[tid].c = c; - datas[tid].cntx = cntx; - datas[tid].rntm = rntm; - datas[tid].cntl = cntl; - datas[tid].tid = tid; - datas[tid].gl_comm = gl_comm; - datas[tid].array = array; - - // Spawn additional threads for ids greater than 1. - if ( tid != 0 ) - bli_pthread_create( &pthreads[tid], NULL, &bli_l3_thread_entry, &datas[tid] ); - else - bli_l3_thread_entry( ( void* )(&datas[0]) ); - } - - // We shouldn't free the global communicator since it was already freed - // by the global communicator's chief thread in bli_l3_thrinfo_free() - // (called from the thread entry function). - - // Thread 0 waits for additional threads to finish. - for ( dim_t tid = 1; tid < n_threads; tid++ ) - { - bli_pthread_join( pthreads[tid], NULL ); - } - - // Check the array_t back into the small block allocator. Similar to the - // check-out, this is done using a lock embedded within the sba to ensure - // mutual exclusion. - bli_sba_checkin_array( array ); - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_l3_thread_decorator().pth: " ); - #endif - bli_free_intl( pthreads ); - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_l3_thread_decorator().pth: " ); - #endif - bli_free_intl( datas ); -} - #endif diff --git a/frame/thread/bli_thrcomm_pthreads.h b/frame/thread/bli_thrcomm_pthreads.h index f24849cb33..2c2e885515 100644 --- a/frame/thread/bli_thrcomm_pthreads.h +++ b/frame/thread/bli_thrcomm_pthreads.h @@ -52,12 +52,14 @@ struct thrcomm_s void* sent_object; dim_t n_threads; -//#ifdef BLIS_USE_PTHREAD_MUTEX -// bli_pthread_mutex_t mutex; -//#endif - - //volatile bool_t barrier_sense; - bool_t barrier_sense; + // NOTE: barrier_sense was originally a gint_t-based bool_t, but upon + // redefining bool_t as bool we discovered that some gcc __atomic built-ins + // don't allow the use of bool for the variables being operated upon. + // (Specifically, this was observed of __atomic_fetch_xor(), but it likely + // applies to all other related built-ins.) Thus, we get around this by + // redefining barrier_sense as a gint_t. + //volatile gint_t barrier_sense; + gint_t barrier_sense; dim_t barrier_threads_arrived; }; #endif diff --git a/frame/thread/bli_thrcomm_single.c b/frame/thread/bli_thrcomm_single.c index 969221e7c2..cedb3c5b6e 100644 --- a/frame/thread/bli_thrcomm_single.c +++ b/frame/thread/bli_thrcomm_single.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -84,119 +84,5 @@ void bli_thrcomm_barrier( dim_t t_id, thrcomm_t* comm ) return; } -// Define a dummy function bli_l3_thread_entry(), which is needed in the -// pthreads version, so that when building Windows DLLs (with OpenMP enabled -// or no multithreading) we don't risk having an unresolved symbol. -void* bli_l3_thread_entry( void* data_void ) { return NULL; } - -void bli_l3_thread_decorator - ( - l3int_t func, - opid_t family, - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl - ) -{ - // This is part of a hack to support mixed domain in bli_gemm_front(). - // Sometimes we need to specify a non-standard schema for A and B, and - // we decided to transmit them via the schema field in the obj_t's - // rather than pass them in as function parameters. Once the values - // have been read, we immediately reset them back to their expected - // values for unpacked objects. - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - bli_obj_set_pack_schema( BLIS_NOT_PACKED, a ); - bli_obj_set_pack_schema( BLIS_NOT_PACKED, b ); - - // For sequential execution, we use only one thread. - const dim_t n_threads = 1; - - // NOTE: The sba was initialized in bli_init(). - - // Check out an array_t from the small block allocator. This is done - // with an internal lock to ensure only one application thread accesses - // the sba at a time. bli_sba_checkout_array() will also automatically - // resize the array_t, if necessary. - array_t* restrict array = bli_sba_checkout_array( n_threads ); - - // Access the pool_t* for thread 0 and embed it into the rntm. We do - // this up-front only so that we can create the global comm below. - bli_sba_rntm_set_pool( 0, array, rntm ); - - // Set the packing block allocator field of the rntm. - bli_membrk_rntm_set_membrk( rntm ); - - // Allcoate a global communicator for the root thrinfo_t structures. - thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); - - - { - // NOTE: We don't need to create another copy of the rntm_t since - // it was already copied in one of the high-level oapi functions. - rntm_t* restrict rntm_p = rntm; - - cntl_t* cntl_use; - thrinfo_t* thread; - - const dim_t tid = 0; - - // Use the thread id to access the appropriate pool_t* within the - // array_t, and use it to set the sba_pool field within the rntm_t. - // If the pool_t* element within the array_t is NULL, it will first - // be allocated/initialized. - // NOTE: This is commented out because, in the single-threaded case, - // this is redundant since it's already been done above. - //bli_sba_rntm_set_pool( tid, array, rntm_p ); - - // NOTE: Unlike with the _openmp.c and _pthreads.c variants, we don't - // need to alias objects for A, B, and C since they were already aliased - // in bli_*_front(). However, we may add aliasing here in the future so - // that, with all three (_single.c, _openmp.c, _pthreads.c) implementations - // consistently providing local aliases, we can then eliminate aliasing - // elsewhere. - - // Create a default control tree for the operation, if needed. - bli_l3_cntl_create_if( family, schema_a, schema_b, - a, b, c, rntm_p, cntl, &cntl_use ); - - // Create the root node of the thread's thrinfo_t structure. - bli_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread ); - - func - ( - alpha, - a, - b, - beta, - c, - cntx, - rntm_p, - cntl_use, - thread - ); - - // Free the thread's local control tree. - bli_l3_cntl_free( rntm_p, cntl_use, thread ); - - // Free the current thread's thrinfo_t structure. - bli_l3_thrinfo_free( rntm_p, thread ); - } - - // We shouldn't free the global communicator since it was already freed - // by the global communicator's chief thread in bli_l3_thrinfo_free() - // (called above). - - // Check the array_t back into the small block allocator. Similar to the - // check-out, this is done using a lock embedded within the sba to ensure - // mutual exclusion. - bli_sba_checkin_array( array ); -} - #endif diff --git a/frame/thread/bli_thrcomm_single.h b/frame/thread/bli_thrcomm_single.h index f608ab37a5..c10727df22 100644 --- a/frame/thread/bli_thrcomm_single.h +++ b/frame/thread/bli_thrcomm_single.h @@ -60,8 +60,14 @@ struct thrcomm_s { void* sent_object; dim_t n_threads; - - bool_t barrier_sense; + + // NOTE: barrier_sense was originally a gint_t-based bool_t, but upon + // redefining bool_t as bool we discovered that some gcc __atomic built-ins + // don't allow the use of bool for the variables being operated upon. + // (Specifically, this was observed of __atomic_fetch_xor(), but it likely + // applies to all other related built-ins.) Thus, we get around this by + // redefining barrier_sense as a gint_t. + gint_t barrier_sense; dim_t barrier_threads_arrived; }; #endif diff --git a/frame/thread/bli_thread.c b/frame/thread/bli_thread.c index 58ba57e81c..6dc4f9141c 100644 --- a/frame/thread/bli_thread.c +++ b/frame/thread/bli_thread.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,8 +39,12 @@ thrinfo_t BLIS_PACKM_SINGLE_THREADED = {}; thrinfo_t BLIS_GEMM_SINGLE_THREADED = {}; thrcomm_t BLIS_SINGLE_COMM = {}; -// The global rntm_t structure, which holds the global thread settings. -static rntm_t global_rntm; +// The global rntm_t structure. (The definition resides in bli_rntm.c.) +extern rntm_t global_rntm; + +// A mutex to allow synchronous access to global_rntm. (The definition +// resides in bli_rntm.c.) +extern bli_pthread_mutex_t global_rntm_mutex; // ----------------------------------------------------------------------------- @@ -66,7 +70,7 @@ void bli_thread_range_sub thrinfo_t* thread, dim_t n, dim_t bf, - bool_t handle_edge_low, + bool handle_edge_low, dim_t* start, dim_t* end ) @@ -104,7 +108,7 @@ void bli_thread_range_sub // 13 >0 f 1 3 4 3 3 3+ // 14 >0 f 2 2 4 4 3 3+ // 15 >0 f 3 1 4 4 4 3+ - // 15 =0 f 3 1 4 4 4 3 + // 15 =0 f 3 1 4 4 4 3 // // 12 =0 t 4 0 3 3 3 3 // 12 >0 t 4 0 3+ 3 3 3 @@ -297,7 +301,7 @@ dim_t bli_thread_range_width_l dim_t bf, dim_t bf_left, double area_per_thr, - bool_t handle_edge_low + bool handle_edge_low ) { dim_t width; @@ -506,7 +510,7 @@ siz_t bli_thread_range_weighted_sub dim_t m, dim_t n, dim_t bf, - bool_t handle_edge_low, + bool handle_edge_low, dim_t* restrict j_start_thr, dim_t* restrict j_end_thr ) @@ -663,7 +667,7 @@ siz_t bli_thread_range_mdim blksz_t* bmult = bli_cntx_get_bmult( bszid, cntx ); obj_t* x; - bool_t use_weighted; + bool use_weighted; // Use the operation family to choose the one of the two matrices // being partitioned that potentially has structure, and also to @@ -674,7 +678,7 @@ siz_t bli_thread_range_mdim // structured matrix, even though they represent part of that matrix // that will be dense and full (after packing). if ( family == BLIS_GEMM ) { x = a; use_weighted = FALSE; } - else if ( family == BLIS_HERK ) { x = c; use_weighted = TRUE; } + else if ( family == BLIS_GEMMT ) { x = c; use_weighted = TRUE; } else if ( family == BLIS_TRMM ) { x = a; use_weighted = TRUE; } else /*family == BLIS_TRSM*/ { x = a; use_weighted = FALSE; } @@ -722,7 +726,7 @@ siz_t bli_thread_range_ndim blksz_t* bmult = bli_cntx_get_bmult( bszid, cntx ); obj_t* x; - bool_t use_weighted; + bool use_weighted; // Use the operation family to choose the one of the two matrices // being partitioned that potentially has structure, and also to @@ -733,7 +737,7 @@ siz_t bli_thread_range_ndim // structured matrix, even though they represent part of that matrix // that will be dense and full (after packing). if ( family == BLIS_GEMM ) { x = b; use_weighted = FALSE; } - else if ( family == BLIS_HERK ) { x = c; use_weighted = TRUE; } + else if ( family == BLIS_GEMMT ) { x = c; use_weighted = TRUE; } else if ( family == BLIS_TRMM ) { x = b; use_weighted = TRUE; } else /*family == BLIS_TRSM*/ { x = b; use_weighted = FALSE; } @@ -964,128 +968,433 @@ siz_t bli_thread_range_weighted_b2t void bli_prime_factorization( dim_t n, bli_prime_factors_t* factors ) { - factors->n = n; - factors->sqrt_n = (dim_t)sqrt(n); - factors->f = 2; + factors->n = n; + factors->sqrt_n = ( dim_t )sqrt( ( double )n ); + factors->f = 2; } dim_t bli_next_prime_factor( bli_prime_factors_t* factors ) { - // Return the prime factorization of the original number n one-by-one. - // Return 1 after all factors have been exhausted. + // Return the prime factorization of the original number n one-by-one. + // Return 1 after all factors have been exhausted. - // Looping over possible factors in increasing order assures we will - // only return prime factors (a la the Sieve of Eratosthenes). - while ( factors->f <= factors->sqrt_n ) - { - // Special cases for factors 2-7 handle all numbers not divisible by 11 - // or another larger prime. The slower loop version is used after that. - // If you use a number of threads with large prime factors you get - // what you deserve. - if ( factors->f == 2 ) - { - if ( factors->n % 2 == 0 ) - { - factors->n /= 2; - return 2; - } - factors->f = 3; - } - else if ( factors->f == 3 ) - { - if ( factors->n % 3 == 0 ) - { - factors->n /= 3; - return 3; - } - factors->f = 5; - } - else if ( factors->f == 5 ) - { - if ( factors->n % 5 == 0 ) - { - factors->n /= 5; - return 5; - } - factors->f = 7; - } - else if ( factors->f == 7 ) - { - if ( factors->n % 7 == 0 ) - { - factors->n /= 7; - return 7; - } - factors->f = 11; - } - else - { - if ( factors->n % factors->f == 0 ) - { - factors->n /= factors->f; - return factors->f; - } - factors->f++; - } - } + // Looping over possible factors in increasing order assures we will + // only return prime factors (a la the Sieve of Eratosthenes). + while ( factors->f <= factors->sqrt_n ) + { + // Special cases for factors 2-7 handle all numbers not divisible by 11 + // or another larger prime. The slower loop version is used after that. + // If you use a number of threads with large prime factors you get + // what you deserve. + if ( factors->f == 2 ) + { + if ( factors->n % 2 == 0 ) + { + factors->n /= 2; + return 2; + } + factors->f = 3; + } + else if ( factors->f == 3 ) + { + if ( factors->n % 3 == 0 ) + { + factors->n /= 3; + return 3; + } + factors->f = 5; + } + else if ( factors->f == 5 ) + { + if ( factors->n % 5 == 0 ) + { + factors->n /= 5; + return 5; + } + factors->f = 7; + } + else if ( factors->f == 7 ) + { + if ( factors->n % 7 == 0 ) + { + factors->n /= 7; + return 7; + } + factors->f = 11; + } + else + { + if ( factors->n % factors->f == 0 ) + { + factors->n /= factors->f; + return factors->f; + } + factors->f++; + } + } + + // To get here we must be out of prime factors, leaving only n (if it is + // prime) or an endless string of 1s. + dim_t tmp = factors->n; + factors->n = 1; + return tmp; +} + +bool bli_is_prime( dim_t n ) +{ + bli_prime_factors_t factors; + + bli_prime_factorization( n, &factors ); + + dim_t f = bli_next_prime_factor( &factors ); + + if ( f == n ) return TRUE; + else return FALSE; +} + +void bli_thread_partition_2x2 + ( + dim_t n_thread, + dim_t work1, + dim_t work2, + dim_t* restrict nt1, + dim_t* restrict nt2 + ) +{ + // Partition a number of threads into two factors nt1 and nt2 such that + // nt1/nt2 ~= work1/work2. There is a fast heuristic algorithm and a + // slower optimal algorithm (which minimizes |nt1*work2 - nt2*work1|). - // To get here we must be out of prime factors, leaving only n (if it is - // prime) or an endless string of 1s. - dim_t tmp = factors->n; - factors->n = 1; - return tmp; + // Return early small prime numbers of threads. + if ( n_thread < 4 ) + { + *nt1 = ( work1 >= work2 ? n_thread : 1 ); + *nt2 = ( work1 < work2 ? n_thread : 1 ); + + return; + } + +#if 1 + bli_thread_partition_2x2_fast( n_thread, work1, work2, nt1, nt2 ); +#else + bli_thread_partition_2x2_slow( n_thread, work1, work2, nt1, nt2 ); +#endif } -void bli_partition_2x2( dim_t nthread, dim_t work1, dim_t work2, - dim_t* nt1, dim_t* nt2 ) +//#define PRINT_FACTORS + +void bli_thread_partition_2x2_fast + ( + dim_t n_thread, + dim_t work1, + dim_t work2, + dim_t* restrict nt1, + dim_t* restrict nt2 + ) { + // Compute with these local variables until the end of the function, at + // which time we will save the values back to nt1 and nt2. + dim_t tn1 = 1; + dim_t tn2 = 1; + + // Both algorithms need the prime factorization of n_thread. + bli_prime_factors_t factors; + bli_prime_factorization( n_thread, &factors ); + + // Fast algorithm: assign prime factors in increasing order to whichever + // partition has more work to do. The work is divided by the number of + // threads assigned at each iteration. This algorithm is sub-optimal in + // some cases. We attempt to mitigate the cases that involve at least one + // factor of 2. For example, in the partitioning of 12 with equal work + // this algorithm tentatively finds 6x2. This factorization involves a + // factor of 2 that can be reallocated, allowing us to convert it to the + // optimal solution of 4x3. But some cases cannot be corrected this way + // because they do not contain a factor of 2. For example, this algorithm + // factors 105 (with equal work) into 21x5 whereas 7x15 would be optimal. + + #ifdef PRINT_FACTORS + printf( "w1 w2 = %d %d (initial)\n", (int)work1, (int)work2 ); + #endif + + dim_t f; + while ( ( f = bli_next_prime_factor( &factors ) ) > 1 ) + { + #ifdef PRINT_FACTORS + printf( "w1 w2 = %4d %4d nt1 nt2 = %d %d ... f = %d\n", + (int)work1, (int)work2, (int)tn1, (int)tn2, (int)f ); + #endif + + if ( work1 > work2 ) { work1 /= f; tn1 *= f; } + else { work2 /= f; tn2 *= f; } + } + + #ifdef PRINT_FACTORS + printf( "w1 w2 = %4d %4d nt1 nt2 = %d %d\n", + (int)work1, (int)work2, (int)tn1, (int)tn2 ); + #endif + + // Sometimes the last factor applied is prime. For example, on a square + // matrix, we tentatively arrive (from the logic above) at: + // - a 2x6 factorization when given 12 ways of parallelism + // - a 2x10 factorization when given 20 ways of parallelism + // - a 2x14 factorization when given 28 ways of parallelism + // These factorizations are suboptimal under the assumption that we want + // the parallelism to be as balanced as possible. Below, we make a final + // attempt at rebalancing nt1 and nt2 by checking to see if the gap between + // work1 and work2 is narrower if we reallocate a factor of 2. + if ( work1 > work2 ) + { + // Example: nt = 12 + // w1 w2 (initial) = 3600 3600; nt1 nt2 = 1 1 + // w1 w2 (tentative) = 1800 600; nt1 nt2 = 2 6 + // w1 w2 (ideal) = 900 1200; nt1 nt2 = 4 3 + if ( tn2 % 2 == 0 ) + { + dim_t diff = work1 - work2; + dim_t diff_mod = bli_abs( work1/2 - work2*2 ); + + if ( diff_mod < diff ) { tn1 *= 2; tn2 /= 2; } + } + } + else if ( work1 < work2 ) + { + // Example: nt = 40 + // w1 w2 (initial) = 3600 3600; nt1 nt2 = 1 1 + // w1 w2 (tentative) = 360 900; nt1 nt2 = 10 4 + // w1 w2 (ideal) = 720 450; nt1 nt2 = 5 8 + if ( tn1 % 2 == 0 ) + { + dim_t diff = work2 - work1; + dim_t diff_mod = bli_abs( work2/2 - work1*2 ); + + if ( diff_mod < diff ) { tn1 /= 2; tn2 *= 2; } + } + } + + #ifdef PRINT_FACTORS + printf( "w1 w2 = %4d %4d nt1 nt2 = %d %d (final)\n", + (int)work1, (int)work2, (int)tn1, (int)tn2 ); + #endif + + // Save the final result. + *nt1 = tn1; + *nt2 = tn2; +} + +#include "limits.h" + +void bli_thread_partition_2x2_slow + ( + dim_t n_thread, + dim_t work1, + dim_t work2, + dim_t* restrict nt1, + dim_t* restrict nt2 + ) +{ + // Slow algorithm: exhaustively constructs all factor pairs of n_thread and + // chooses the best one. + + // Compute with these local variables until the end of the function, at + // which time we will save the values back to nt1 and nt2. + dim_t tn1 = 1; + dim_t tn2 = 1; + + // Both algorithms need the prime factorization of n_thread. + bli_prime_factors_t factors; + bli_prime_factorization( n_thread, &factors ); + + // Eight prime factors handles n_thread up to 223092870. + dim_t fact[8]; + dim_t mult[8]; + + // There is always at least one prime factor, so use if for initialization. + dim_t nfact = 1; + fact[0] = bli_next_prime_factor( &factors ); + mult[0] = 1; + + // Collect the remaining prime factors, accounting for multiplicity of + // repeated factors. + dim_t f; + while ( ( f = bli_next_prime_factor( &factors ) ) > 1 ) + { + if ( f == fact[nfact-1] ) + { + mult[nfact-1]++; + } + else + { + nfact++; + fact[nfact-1] = f; + mult[nfact-1] = 1; + } + } + + // Now loop over all factor pairs. A single factor pair is denoted by how + // many of each prime factor are included in the first factor (ntaken). + dim_t ntake[8] = {0}; + dim_t min_diff = INT_MAX; + + // Loop over how many prime factors to assign to the first factor in the + // pair, for each prime factor. The total number of iterations is + // \Prod_{i=0}^{nfact-1} mult[i]. + bool done = FALSE; + while ( !done ) + { + dim_t x = 1; + dim_t y = 1; + + // Form the factors by integer exponentiation and accumulation. + for ( dim_t i = 0 ; i < nfact ; i++ ) + { + x *= bli_ipow( fact[i], ntake[i] ); + y *= bli_ipow( fact[i], mult[i]-ntake[i] ); + } + + // Check if this factor pair is optimal by checking + // |nt1*work2 - nt2*work1|. + dim_t diff = llabs( x*work2 - y*work1 ); + if ( diff < min_diff ) + { + min_diff = diff; + tn1 = x; + tn2 = y; + } + + // Go to the next factor pair by doing an "odometer loop". + for ( dim_t i = 0 ; i < nfact ; i++ ) + { + if ( ++ntake[i] > mult[i] ) + { + ntake[i] = 0; + if ( i == nfact-1 ) done = TRUE; + else continue; + } + break; + } + } + + // Save the final result. + *nt1 = tn1; + *nt2 = tn2; +} + +#if 0 +void bli_thread_partition_2x2_orig + ( + dim_t n_thread, + dim_t work1, + dim_t work2, + dim_t* restrict nt1, + dim_t* restrict nt2 + ) +{ + // Copy nt1 and nt2 to local variables and then compute with those local + // variables until the end of the function, at which time we will save the + // values back to nt1 and nt2. + dim_t tn1; // = *nt1; + dim_t tn2; // = *nt2; + // Partition a number of threads into two factors nt1 and nt2 such that // nt1/nt2 ~= work1/work2. There is a fast heuristic algorithm and a // slower optimal algorithm (which minimizes |nt1*work2 - nt2*work1|). // Return early small prime numbers of threads. - if (nthread < 4) + if ( n_thread < 4 ) { - *nt1 = ( work1 >= work2 ? nthread : 1 ); - *nt2 = ( work1 < work2 ? nthread : 1 ); + tn1 = ( work1 >= work2 ? n_thread : 1 ); + tn2 = ( work1 < work2 ? n_thread : 1 ); + + return; } - *nt1 = 1; - *nt2 = 1; + tn1 = 1; + tn2 = 1; - // Both algorithms need the prime factorization of nthread. + // Both algorithms need the prime factorization of n_thread. bli_prime_factors_t factors; - bli_prime_factorization( nthread, &factors ); + bli_prime_factorization( n_thread, &factors ); - #if 1 +#if 1 // Fast algorithm: assign prime factors in increasing order to whichever // partition has more work to do. The work is divided by the number of - // threads assigned at each iteration. This algorithm is sub-optimal, - // for example in the partitioning of 12 with equal work (optimal solution - // is 4x3, this algorithm finds 6x2). + // threads assigned at each iteration. This algorithm is sub-optimal in + // some cases. We attempt to mitigate the cases that involve at least one + // factor of 2. For example, in the partitioning of 12 with equal work + // this algorithm tentatively finds 6x2. This factorization involves a + // factor of 2 that can be reallocated, allowing us to convert it to the + // optimal solution of 4x3. But some cases cannot be corrected this way + // because they do not contain a factor of 2. For example, this algorithm + // factors 105 (with equal work) into 21x5 whereas 7x15 would be optimal. + + //printf( "w1 w2 = %d %d (initial)\n", (int)work1, (int)work2 ); dim_t f; while ( ( f = bli_next_prime_factor( &factors ) ) > 1 ) { + //printf( "w1 w2 = %4d %4d nt1 nt2 = %d %d ... f = %d\n", (int)work1, (int)work2, (int)tn1, (int)tn2, (int)f ); + if ( work1 > work2 ) { work1 /= f; - *nt1 *= f; + tn1 *= f; } else { work2 /= f; - *nt2 *= f; + tn2 *= f; } } - #else + //printf( "w1 w2 = %4d %4d nt1 nt2 = %d %d\n", (int)work1, (int)work2, (int)tn1, (int)tn2 ); + + // Sometimes the last factor applied is prime. For example, on a square + // matrix, we tentatively arrive (from the logic above) at: + // - a 2x6 factorization when given 12 ways of parallelism + // - a 2x10 factorization when given 20 ways of parallelism + // - a 2x14 factorization when given 28 ways of parallelism + // These factorizations are suboptimal under the assumption that we want + // the parallelism to be as balanced as possible. Below, we make a final + // attempt at rebalancing nt1 and nt2 by checking to see if the gap between + // work1 and work2 is narrower if we reallocate a factor of 2. + if ( work1 > work2 ) + { + // Example: nt = 12 + // w1 w2 (initial) = 3600 3600; nt1 nt2 = 1 1 + // w1 w2 (tentative) = 1800 600; nt1 nt2 = 2 6 + // w1 w2 (ideal) = 900 1200; nt1 nt2 = 4 3 + if ( tn2 % 2 == 0 ) + { + dim_t diff = work1 - work2; + dim_t diff_mod = bli_abs( work1/2 - work2*2 ); + + if ( diff_mod < diff ) { tn1 *= 2; tn2 /= 2; } + } + } + else if ( work1 < work2 ) + { + // Example: nt = 40 + // w1 w2 (initial) = 3600 3600; nt1 nt2 = 1 1 + // w1 w2 (tentative) = 360 900; nt1 nt2 = 10 4 + // w1 w2 (ideal) = 720 450; nt1 nt2 = 5 8 + if ( tn1 % 2 == 0 ) + { + dim_t diff = work2 - work1; + dim_t diff_mod = bli_abs( work2/2 - work1*2 ); + + if ( diff_mod < diff ) { tn1 /= 2; tn2 *= 2; } + } + } + + //printf( "w1 w2 = %4d %4d nt1 nt2 = %d %d (final)\n", (int)work1, (int)work2, (int)tn1, (int)tn2 ); + +#else - // Slow algorithm: exhaustively constructs all factor pairs of nthread and + // Slow algorithm: exhaustively constructs all factor pairs of n_thread and // chooses the best one. - // Eight prime factors handles nthread up to 223092870. + // Eight prime factors handles n_thread up to 223092870. dim_t fact[8]; dim_t mult[8]; @@ -1119,7 +1428,7 @@ void bli_partition_2x2( dim_t nthread, dim_t work1, dim_t work2, // Loop over how many prime factors to assign to the first factor in the // pair, for each prime factor. The total number of iterations is // \Prod_{i=0}^{nfact-1} mult[i]. - bool done = false; + bool done = FALSE; while ( !done ) { dim_t x = 1; @@ -1138,8 +1447,8 @@ void bli_partition_2x2( dim_t nthread, dim_t work1, dim_t work2, if ( diff < min_diff ) { min_diff = diff; - *nt1 = x; - *nt2 = y; + tn1 = x; + tn2 = y; } // Go to the next factor pair by doing an "odometer loop". @@ -1148,15 +1457,21 @@ void bli_partition_2x2( dim_t nthread, dim_t work1, dim_t work2, if ( ++ntake[i] > mult[i] ) { ntake[i] = 0; - if ( i == nfact-1 ) done = true; + if ( i == nfact-1 ) done = TRUE; else continue; } break; } } - #endif +#endif + + + // Save the final result. + *nt1 = tn1; + *nt2 = tn2; } +#endif // ----------------------------------------------------------------------------- @@ -1178,73 +1493,16 @@ dim_t bli_lcm( dim_t x, dim_t y) dim_t bli_ipow( dim_t base, dim_t power ) { - dim_t p = 1; - - for ( dim_t mask = 0x1 ; mask <= power ; mask <<= 1 ) - { - if ( power & mask ) p *= base; - base *= base; - } - - return p; -} -// ----------------------------------------------------------------------------- - -dim_t bli_thread_get_env( const char* env, dim_t fallback ) -{ - dim_t r_val; - char* str; - - // Query the environment variable and store the result in str. - str = getenv( env ); + dim_t p = 1; - // Set the return value based on the string obtained from getenv(). - if ( str != NULL ) + for ( dim_t mask = 0x1 ; mask <= power ; mask <<= 1 ) { - // If there was no error, convert the string to an integer and - // prepare to return that integer. - r_val = strtol( str, NULL, 10 ); + if ( power & mask ) p *= base; + base *= base; } - else - { - // If there was an error, use the "fallback" as the return value. - r_val = fallback; - } - - return r_val; -} - -#if 0 -void bli_thread_set_env( const char* env, dim_t value ) -{ - dim_t r_val; - char value_str[32]; - const char* fs_32 = "%u"; - const char* fs_64 = "%lu"; - - // Convert the string to an integer, but vary the format specifier - // depending on the integer type size. - if ( bli_info_get_int_type_size() == 32 ) sprintf( value_str, fs_32, value ); - else sprintf( value_str, fs_64, value ); - - // Set the environment variable using the string we just wrote to via - // sprintf(). (The 'TRUE' argument means we want to overwrite the current - // value if the environment variable already exists.) - r_val = bli_setenv( env, value_str, TRUE ); - - // Check the return value in case something went horribly wrong. - if ( r_val == -1 ) - { - char err_str[128]; - // Query the human-readable error string corresponding to errno. - strerror_r( errno, err_str, 128 ); - - // Print the error message. - bli_print_msg( err_str, __FILE__, __LINE__ ); - } + return p; } -#endif // ----------------------------------------------------------------------------- @@ -1298,9 +1556,6 @@ dim_t bli_thread_get_num_threads( void ) // ---------------------------------------------------------------------------- -// A mutex to allow synchronous access to global_rntm. -static bli_pthread_mutex_t global_rntm_mutex = BLIS_PTHREAD_MUTEX_INITIALIZER; - void bli_thread_set_ways( dim_t jc, dim_t pc, dim_t ic, dim_t jr, dim_t ir ) { // We must ensure that global_rntm has been initialized. @@ -1331,22 +1586,6 @@ void bli_thread_set_num_threads( dim_t n_threads ) // ---------------------------------------------------------------------------- -void bli_thread_init_rntm( rntm_t* rntm ) -{ - // We must ensure that global_rntm has been initialized. - bli_init_once(); - - // Acquire the mutex protecting global_rntm. - bli_pthread_mutex_lock( &global_rntm_mutex ); - - *rntm = global_rntm; - - // Release the mutex protecting global_rntm. - bli_pthread_mutex_unlock( &global_rntm_mutex ); -} - -// ---------------------------------------------------------------------------- - void bli_thread_init_rntm_from_env ( rntm_t* rntm @@ -1356,30 +1595,31 @@ void bli_thread_init_rntm_from_env // function is only called from bli_thread_init(), which is only called // by bli_init_once(). + bool auto_factor = FALSE; dim_t nt; dim_t jc, pc, ic, jr, ir; #ifdef BLIS_ENABLE_MULTITHREADING // Try to read BLIS_NUM_THREADS first. - nt = bli_thread_get_env( "BLIS_NUM_THREADS", -1 ); + nt = bli_env_get_var( "BLIS_NUM_THREADS", -1 ); // If BLIS_NUM_THREADS was not set, try to read OMP_NUM_THREADS. if ( nt == -1 ) - nt = bli_thread_get_env( "OMP_NUM_THREADS", -1 ); + nt = bli_env_get_var( "OMP_NUM_THREADS", -1 ); // Read the environment variables for the number of threads (ways // of parallelism) for each individual loop. - jc = bli_thread_get_env( "BLIS_JC_NT", -1 ); - pc = bli_thread_get_env( "BLIS_PC_NT", -1 ); - ic = bli_thread_get_env( "BLIS_IC_NT", -1 ); - jr = bli_thread_get_env( "BLIS_JR_NT", -1 ); - ir = bli_thread_get_env( "BLIS_IR_NT", -1 ); + jc = bli_env_get_var( "BLIS_JC_NT", -1 ); + pc = bli_env_get_var( "BLIS_PC_NT", -1 ); + ic = bli_env_get_var( "BLIS_IC_NT", -1 ); + jr = bli_env_get_var( "BLIS_JR_NT", -1 ); + ir = bli_env_get_var( "BLIS_IR_NT", -1 ); // If any BLIS_*_NT environment variable was set, then we ignore the // value of BLIS_NUM_THREADS or OMP_NUM_THREADS and use the - // BLIS_*_NT values instead (with unset variables being assumed to - // contain 1). + // BLIS_*_NT values instead (with unset variables being treated as if + // they contained 1). if ( jc != -1 || pc != -1 || ic != -1 || jr != -1 || ir != -1 ) { if ( jc == -1 ) jc = 1; @@ -1392,9 +1632,14 @@ void bli_thread_init_rntm_from_env nt = -1; } - // By this time, either nt is set and the ways for each loop - // are all unset, OR nt is unset and the ways for each loop - // are all set. + // By this time, one of the following conditions holds: + // - nt is -1 and the ways for each loop are -1. + // - nt is -1 and the ways for each loop are all set. + // - nt is set and the ways for each loop are -1. + + // If nt is set (ie: not -1), then we know we will perform an automatic + // thread factorization (later, in bli_rntm.c). + if ( nt != -1 ) auto_factor = TRUE; #else @@ -1406,6 +1651,7 @@ void bli_thread_init_rntm_from_env #endif // Save the results back in the runtime object. + bli_rntm_set_auto_factor_only( auto_factor, rntm ); bli_rntm_set_num_threads_only( nt, rntm ); bli_rntm_set_ways_only( jc, pc, ic, jr, ir, rntm ); diff --git a/frame/thread/bli_thread.h b/frame/thread/bli_thread.h index 6680f536ea..d4880c4c85 100644 --- a/frame/thread/bli_thread.h +++ b/frame/thread/bli_thread.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,28 +42,34 @@ // Include thread info (thrinfo_t) object definitions and prototypes. #include "bli_thrinfo.h" +#include "bli_thrinfo_sup.h" // Include some operation-specific thrinfo_t prototypes. // Note that the bli_packm_thrinfo.h must be included before the others! #include "bli_packm_thrinfo.h" #include "bli_l3_thrinfo.h" +// Include the level-3 thread decorator and related definitions and prototypes +// for the conventional code path. +#include "bli_l3_decor.h" + +// Include the level-3 thread decorator and related definitions and prototypes +// for the sup code path. +#include "bli_l3_sup_decor.h" + // Initialization-related prototypes. void bli_thread_init( void ); void bli_thread_finalize( void ); -#ifdef _MSC_VER -#define strerror_r(errno,buf,len) strerror_s(buf,len,errno) -#endif - // Thread range-related prototypes. +BLIS_EXPORT_BLIS void bli_thread_range_sub ( thrinfo_t* thread, dim_t n, dim_t bf, - bool_t handle_edge_low, + bool handle_edge_low, dim_t* start, dim_t* end ); @@ -120,7 +126,7 @@ dim_t bli_thread_range_width_l dim_t bf, dim_t bf_left, double area_per_thr, - bool_t handle_edge_low + bool handle_edge_low ); siz_t bli_find_area_trap_l ( @@ -136,42 +142,11 @@ siz_t bli_thread_range_weighted_sub dim_t m, dim_t n, dim_t bf, - bool_t handle_edge_low, + bool handle_edge_low, dim_t* restrict j_start_thr, dim_t* restrict j_end_thr ); - - -// Level-3 internal function type -typedef void (*l3int_t) - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ); - -// Level-3 thread decorator prototype -void bli_l3_thread_decorator - ( - l3int_t func, - opid_t family, - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl - ); - // ----------------------------------------------------------------------------- // Factorization and partitioning prototypes @@ -185,8 +160,32 @@ typedef struct void bli_prime_factorization(dim_t n, bli_prime_factors_t* factors); dim_t bli_next_prime_factor(bli_prime_factors_t* factors); +bool bli_is_prime( dim_t n ); -void bli_partition_2x2(dim_t nthread, dim_t work1, dim_t work2, dim_t* nt1, dim_t* nt2); +void bli_thread_partition_2x2 + ( + dim_t n_thread, + dim_t work1, + dim_t work2, + dim_t* restrict nt1, + dim_t* restrict nt2 + ); +void bli_thread_partition_2x2_slow + ( + dim_t n_thread, + dim_t work1, + dim_t work2, + dim_t* restrict nt1, + dim_t* restrict nt2 + ); +void bli_thread_partition_2x2_fast + ( + dim_t n_thread, + dim_t work1, + dim_t work2, + dim_t* restrict nt1, + dim_t* restrict nt2 + ); // ----------------------------------------------------------------------------- @@ -196,9 +195,6 @@ dim_t bli_ipow( dim_t base, dim_t power ); // ----------------------------------------------------------------------------- -BLIS_EXPORT_BLIS dim_t bli_thread_get_env( const char* env, dim_t fallback ); -//void bli_thread_set_env( const char* env, dim_t value ); - BLIS_EXPORT_BLIS dim_t bli_thread_get_jc_nt( void ); BLIS_EXPORT_BLIS dim_t bli_thread_get_pc_nt( void ); BLIS_EXPORT_BLIS dim_t bli_thread_get_ic_nt( void ); @@ -209,18 +205,16 @@ BLIS_EXPORT_BLIS dim_t bli_thread_get_num_threads( void ); BLIS_EXPORT_BLIS void bli_thread_set_ways( dim_t jc, dim_t pc, dim_t ic, dim_t jr, dim_t ir ); BLIS_EXPORT_BLIS void bli_thread_set_num_threads( dim_t value ); -BLIS_EXPORT_BLIS void bli_thread_init_rntm( rntm_t* rntm ); - void bli_thread_init_rntm_from_env( rntm_t* rntm ); // ----------------------------------------------------------------------------- -static void bli_thread_range_jrir_rr +BLIS_INLINE void bli_thread_range_jrir_rr ( thrinfo_t* thread, dim_t n, dim_t bf, - bool_t handle_edge_low, + bool handle_edge_low, dim_t* start, dim_t* end, dim_t* inc @@ -232,12 +226,12 @@ static void bli_thread_range_jrir_rr *end = n; } -static void bli_thread_range_jrir_sl +BLIS_INLINE void bli_thread_range_jrir_sl ( thrinfo_t* thread, dim_t n, dim_t bf, - bool_t handle_edge_low, + bool handle_edge_low, dim_t* start, dim_t* end, dim_t* inc @@ -248,12 +242,12 @@ static void bli_thread_range_jrir_sl *inc = 1; } -static void bli_thread_range_jrir +BLIS_INLINE void bli_thread_range_jrir ( thrinfo_t* thread, dim_t n, dim_t bf, - bool_t handle_edge_low, + bool handle_edge_low, dim_t* start, dim_t* end, dim_t* inc @@ -270,7 +264,7 @@ static void bli_thread_range_jrir } #if 0 -static void bli_thread_range_weighted_jrir +BLIS_INLINE void bli_thread_range_weighted_jrir ( thrinfo_t* thread, doff_t diagoff, @@ -278,7 +272,7 @@ static void bli_thread_range_weighted_jrir dim_t m, dim_t n, dim_t bf, - bool_t handle_edge_low, + bool handle_edge_low, dim_t* start, dim_t* end, dim_t* inc diff --git a/frame/thread/bli_thrinfo.c b/frame/thread/bli_thrinfo.c index fdcf31f1d0..0282be170f 100644 --- a/frame/thread/bli_thrinfo.c +++ b/frame/thread/bli_thrinfo.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ thrinfo_t* bli_thrinfo_create dim_t ocomm_id, dim_t n_way, dim_t work_id, - bool_t free_comm, + bool free_comm, bszid_t bszid, thrinfo_t* sub_node ) @@ -73,20 +73,20 @@ void bli_thrinfo_init dim_t ocomm_id, dim_t n_way, dim_t work_id, - bool_t free_comm, + bool free_comm, bszid_t bszid, thrinfo_t* sub_node ) { - thread->ocomm = ocomm; - thread->ocomm_id = ocomm_id; - thread->n_way = n_way; - thread->work_id = work_id; - thread->free_comm = free_comm; - thread->bszid = bszid; - - thread->sub_prenode = NULL; - thread->sub_node = sub_node; + bli_thrinfo_set_ocomm( ocomm, thread ); + bli_thrinfo_set_ocomm_id( ocomm_id, thread ); + bli_thrinfo_set_n_way( n_way, thread ); + bli_thrinfo_set_work_id( work_id, thread ); + bli_thrinfo_set_free_comm( free_comm, thread ); + bli_thrinfo_set_bszid( bszid, thread ); + + bli_thrinfo_set_sub_node( sub_node, thread ); + bli_thrinfo_set_sub_prenode( NULL, thread ); } void bli_thrinfo_init_single @@ -298,6 +298,24 @@ thrinfo_t* bli_thrinfo_create_for_cntl thrinfo_t* thread_par ) { + // If we are running with a single thread, all of the code can be reduced + // and simplified to this. + if ( bli_rntm_calc_num_threads( rntm ) == 1 ) + { + thrinfo_t* thread_chl = bli_thrinfo_create + ( + rntm, // rntm + &BLIS_SINGLE_COMM, // ocomm + 0, // ocomm_id + 1, // n_way + 0, // work_id + FALSE, // free_comm + BLIS_NO_PART, // bszid + NULL // sub_node + ); + return thread_chl; + } + thrcomm_t* static_comms[ BLIS_NUM_STATIC_COMMS ]; thrcomm_t** new_comms = NULL; @@ -332,15 +350,17 @@ thrinfo_t* bli_thrinfo_create_for_cntl // pointers. if ( bli_thread_am_ochief( thread_par ) ) { + err_t r_val; + if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) - new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ) ); + new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ), &r_val ); else new_comms = static_comms; } // Broadcast the temporary array to all threads in the parent's // communicator. - new_comms = bli_thread_obroadcast( thread_par, new_comms ); + new_comms = bli_thread_broadcast( thread_par, new_comms ); // Chiefs in the child communicator allocate the communicator // object and store it in the array element corresponding to the @@ -348,7 +368,7 @@ thrinfo_t* bli_thrinfo_create_for_cntl if ( child_comm_id == 0 ) new_comms[ parent_work_id ] = bli_thrcomm_create( rntm, child_nt_in ); - bli_thread_obarrier( thread_par ); + bli_thread_barrier( thread_par ); // All threads create a new thrinfo_t node using the communicator // that was created by their chief, as identified by parent_work_id. @@ -364,7 +384,7 @@ thrinfo_t* bli_thrinfo_create_for_cntl NULL // sub_node ); - bli_thread_obarrier( thread_par ); + bli_thread_barrier( thread_par ); // The parent's chief thread frees the temporary array of thrcomm_t // pointers. @@ -477,7 +497,7 @@ thrinfo_t* bli_thrinfo_create_for_cntl_prenode const dim_t child_comm_id = parent_comm_id % child_nt_in; const dim_t child_work_id = child_comm_id / ( child_nt_in / child_n_way ); - bli_thread_obarrier( thread_par ); + bli_thread_barrier( thread_par ); // NOTE: Recall that parent_comm_id == child_comm_id, so checking for the // parent's chief-ness is equivalent to checking for chief-ness in the new @@ -488,7 +508,7 @@ thrinfo_t* bli_thrinfo_create_for_cntl_prenode // Broadcast the new thrcomm_t address to the other threads in the // parent's group. - new_comm = bli_thread_obroadcast( thread_par, new_comm ); + new_comm = bli_thread_broadcast( thread_par, new_comm ); // All threads create a new thrinfo_t node using the communicator // that was created by their chief, as identified by parent_work_id. @@ -504,7 +524,7 @@ thrinfo_t* bli_thrinfo_create_for_cntl_prenode NULL // sub_node ); - bli_thread_obarrier( thread_par ); + bli_thread_barrier( thread_par ); return thread_chl; } diff --git a/frame/thread/bli_thrinfo.h b/frame/thread/bli_thrinfo.h index 2b3d2e8097..8e5a6da3b7 100644 --- a/frame/thread/bli_thrinfo.h +++ b/frame/thread/bli_thrinfo.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -56,7 +56,7 @@ struct thrinfo_s // this is field is true, but when nodes are created that share the same // communicators as other nodes (such as with packm nodes), this is set // to false. - bool_t free_comm; + bool free_comm; // The bszid_t to help identify the node. This is mostly only useful when // debugging or tracing the allocation and release of thrinfo_t nodes. @@ -75,78 +75,108 @@ typedef struct thrinfo_s thrinfo_t; // thrinfo_t query (field only) -static dim_t bli_thread_num_threads( thrinfo_t* t ) +BLIS_INLINE dim_t bli_thread_num_threads( thrinfo_t* t ) { return (t->ocomm)->n_threads; } -static dim_t bli_thread_ocomm_id( thrinfo_t* t ) +BLIS_INLINE dim_t bli_thread_ocomm_id( thrinfo_t* t ) { return t->ocomm_id; } -static dim_t bli_thread_n_way( thrinfo_t* t ) +BLIS_INLINE dim_t bli_thread_n_way( thrinfo_t* t ) { return t->n_way; } -static dim_t bli_thread_work_id( thrinfo_t* t ) +BLIS_INLINE dim_t bli_thread_work_id( thrinfo_t* t ) { return t->work_id; } -static thrcomm_t* bli_thrinfo_ocomm( thrinfo_t* t ) +BLIS_INLINE thrcomm_t* bli_thrinfo_ocomm( thrinfo_t* t ) { return t->ocomm; } -static bool_t bli_thrinfo_needs_free_comm( thrinfo_t* t ) +BLIS_INLINE bool bli_thrinfo_needs_free_comm( thrinfo_t* t ) { return t->free_comm; } -static dim_t bli_thread_bszid( thrinfo_t* t ) +BLIS_INLINE dim_t bli_thread_bszid( thrinfo_t* t ) { return t->bszid; } -static thrinfo_t* bli_thrinfo_sub_node( thrinfo_t* t ) +BLIS_INLINE thrinfo_t* bli_thrinfo_sub_node( thrinfo_t* t ) { return t->sub_node; } -static thrinfo_t* bli_thrinfo_sub_prenode( thrinfo_t* t ) +BLIS_INLINE thrinfo_t* bli_thrinfo_sub_prenode( thrinfo_t* t ) { return t->sub_prenode; } // thrinfo_t query (complex) -static bool_t bli_thread_am_ochief( thrinfo_t* t ) +BLIS_INLINE bool bli_thread_am_ochief( thrinfo_t* t ) { return t->ocomm_id == 0; } // thrinfo_t modification -static void bli_thrinfo_set_sub_node( thrinfo_t* sub_node, thrinfo_t* t ) +BLIS_INLINE void bli_thrinfo_set_ocomm( thrcomm_t* ocomm, thrinfo_t* t ) +{ + t->ocomm = ocomm; +} + +BLIS_INLINE void bli_thrinfo_set_ocomm_id( dim_t ocomm_id, thrinfo_t* t ) +{ + t->ocomm_id = ocomm_id; +} + +BLIS_INLINE void bli_thrinfo_set_n_way( dim_t n_way, thrinfo_t* t ) +{ + t->n_way = n_way; +} + +BLIS_INLINE void bli_thrinfo_set_work_id( dim_t work_id, thrinfo_t* t ) +{ + t->work_id = work_id; +} + +BLIS_INLINE void bli_thrinfo_set_free_comm( bool free_comm, thrinfo_t* t ) +{ + t->free_comm = free_comm; +} + +BLIS_INLINE void bli_thrinfo_set_bszid( bszid_t bszid, thrinfo_t* t ) +{ + t->bszid = bszid; +} + +BLIS_INLINE void bli_thrinfo_set_sub_node( thrinfo_t* sub_node, thrinfo_t* t ) { t->sub_node = sub_node; } -static void bli_thrinfo_set_sub_prenode( thrinfo_t* sub_prenode, thrinfo_t* t ) +BLIS_INLINE void bli_thrinfo_set_sub_prenode( thrinfo_t* sub_prenode, thrinfo_t* t ) { t->sub_prenode = sub_prenode; } // other thrinfo_t-related functions -static void* bli_thread_obroadcast( thrinfo_t* t, void* p ) +BLIS_INLINE void* bli_thread_broadcast( thrinfo_t* t, void* p ) { return bli_thrcomm_bcast( t->ocomm_id, p, t->ocomm ); } -static void bli_thread_obarrier( thrinfo_t* t ) +BLIS_INLINE void bli_thread_barrier( thrinfo_t* t ) { bli_thrcomm_barrier( t->ocomm_id, t->ocomm ); } @@ -163,7 +193,7 @@ thrinfo_t* bli_thrinfo_create dim_t ocomm_id, dim_t n_way, dim_t work_id, - bool_t free_comm, + bool free_comm, bszid_t bszid, thrinfo_t* sub_node ); @@ -175,7 +205,7 @@ void bli_thrinfo_init dim_t ocomm_id, dim_t n_way, dim_t work_id, - bool_t free_comm, + bool free_comm, bszid_t bszid, thrinfo_t* sub_node ); diff --git a/frame/thread/bli_thrinfo_sup.c b/frame/thread/bli_thrinfo_sup.c new file mode 100644 index 0000000000..881990f78d --- /dev/null +++ b/frame/thread/bli_thrinfo_sup.c @@ -0,0 +1,290 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_thrinfo_sup_grow + ( + rntm_t* rntm, + bszid_t* bszid_par, + thrinfo_t* thread + ) +{ + if ( thread == &BLIS_GEMM_SINGLE_THREADED || + thread == &BLIS_PACKM_SINGLE_THREADED ) return; + + // NOTE: If bli_thrinfo_sup_rgrow() is being called, the sub_node field will + // always be non-NULL, and so there's no need to check it. + //if ( bli_cntl_sub_node( cntl ) != NULL ) + { + // We only need to take action if the thrinfo_t sub-node is NULL; if it + // is non-NULL, then it has already been created and we'll use it as-is. + if ( bli_thrinfo_sub_node( thread ) == NULL ) + { + // Create a new node (or, if needed, multiple nodes) along the + // main sub-node branch of the tree and return the pointer to the + // (highest) child. + thrinfo_t* thread_child = bli_thrinfo_sup_rgrow + ( + rntm, + bszid_par, + &bszid_par[1], + thread + ); + + // Attach the child thrinfo_t node for the primary branch to its + // parent structure. + bli_thrinfo_set_sub_node( thread_child, thread ); + } + } +} + +// ----------------------------------------------------------------------------- + +thrinfo_t* bli_thrinfo_sup_rgrow + ( + rntm_t* rntm, + bszid_t* bszid_par, + bszid_t* bszid_cur, + thrinfo_t* thread_par + ) +{ + thrinfo_t* thread_cur; + + // We must handle two cases: those where the next node in the + // control tree is a partitioning node, and those where it is + // a non-partitioning (ie: packing) node. + if ( *bszid_cur != BLIS_NO_PART ) + { + // Create the child thrinfo_t node corresponding to cntl_cur, + // with cntl_par being the parent. + thread_cur = bli_thrinfo_sup_create_for_cntl + ( + rntm, + bszid_par, + bszid_cur, + thread_par + ); + } + else // if ( *bszid_cur == BLIS_NO_PART ) + { + // Recursively grow the thread structure and return the top-most + // thrinfo_t node of that segment. + thrinfo_t* thread_seg = bli_thrinfo_sup_rgrow + ( + rntm, + bszid_par, + &bszid_cur[1], + thread_par + ); + + // Create a thrinfo_t node corresponding to cntl_cur. Since the + // corresponding cntl node, cntl_cur, is a non-partitioning node + // (bszid = BLIS_NO_PART), this means it's a packing node. Packing + // thrinfo_t nodes are formed differently than those corresponding to + // partitioning nodes; specifically, their work_id's are set equal to + // the their comm_id's. Also, notice that the free_comm field is set + // to FALSE since cntl_cur is a non-partitioning node. The reason: + // the communicator used here will be freed when thread_seg, or one + // of its descendents, is freed. + thread_cur = bli_thrinfo_create + ( + rntm, // rntm + bli_thrinfo_ocomm( thread_seg ), // ocomm + bli_thread_ocomm_id( thread_seg ), // ocomm_id + bli_rntm_calc_num_threads_in( bszid_cur, rntm ), // n_way + bli_thread_ocomm_id( thread_seg ), // work_id + FALSE, // free_comm + BLIS_NO_PART, // bszid + thread_seg // sub_node + ); + } + + return thread_cur; +} + +#define BLIS_NUM_STATIC_COMMS 80 + +thrinfo_t* bli_thrinfo_sup_create_for_cntl + ( + rntm_t* rntm, + bszid_t* bszid_par, + bszid_t* bszid_chl, + thrinfo_t* thread_par + ) +{ + // If we are running with a single thread, all of the code can be reduced + // and simplified to this. + if ( bli_rntm_calc_num_threads( rntm ) == 1 ) + { + thrinfo_t* thread_chl = bli_thrinfo_create + ( + rntm, // rntm + &BLIS_SINGLE_COMM, // ocomm + 0, // ocomm_id + 1, // n_way + 0, // work_id + FALSE, // free_comm + BLIS_NO_PART, // bszid + NULL // sub_node + ); + + return thread_chl; + } + + // The remainder of this function handles the cases involving the use of + // multiple BLIS threads. + + if ( bli_rntm_pack_a( rntm ) == FALSE && + bli_rntm_pack_b( rntm ) == FALSE ) + { + // If we are packing neither A nor B, there are no broadcasts or barriers + // needed to synchronize threads (since all threads can work completely + // independently). In this special case situation, the thrinfo_t can be + // created with much simpler logic. + + const dim_t parent_comm_id = bli_thread_ocomm_id( thread_par ); + + // Compute: + // - the number of threads inside the new child comm, + // - the current thread's id within the new communicator, + // - the current thread's work id, given the ways of parallelism + // to be obtained within the next loop. + const dim_t child_nt_in = bli_rntm_calc_num_threads_in( bszid_chl, rntm ); + const dim_t child_n_way = bli_rntm_ways_for( *bszid_chl, rntm ); + const dim_t child_comm_id = parent_comm_id % child_nt_in; + const dim_t child_work_id = child_comm_id / ( child_nt_in / child_n_way ); + + // All threads create a new thrinfo_t node using the communicator + // that was created by their chief, as identified by parent_work_id. + thrinfo_t* thread_chl = bli_thrinfo_create + ( + rntm, // rntm + NULL, // ocomm + child_comm_id, // ocomm_id + child_n_way, // n_way + child_work_id, // work_id + TRUE, // free_comm + *bszid_chl, // bszid + NULL // sub_node + ); + + return thread_chl; + } + else + { + // If we are packing at least one of A or B, then we use the general + // approach that employs broadcasts and barriers. + + thrcomm_t* static_comms[ BLIS_NUM_STATIC_COMMS ]; + thrcomm_t** new_comms = NULL; + + const dim_t parent_nt_in = bli_thread_num_threads( thread_par ); + const dim_t parent_n_way = bli_thread_n_way( thread_par ); + const dim_t parent_comm_id = bli_thread_ocomm_id( thread_par ); + const dim_t parent_work_id = bli_thread_work_id( thread_par ); + + // Sanity check: make sure the number of threads in the parent's + // communicator is divisible by the number of new sub-groups. + if ( parent_nt_in % parent_n_way != 0 ) + { + printf( "Assertion failed: parent_nt_in parent_n_way != 0\n" ); + bli_abort(); + } + + // Compute: + // - the number of threads inside the new child comm, + // - the current thread's id within the new communicator, + // - the current thread's work id, given the ways of parallelism + // to be obtained within the next loop. + const dim_t child_nt_in = bli_rntm_calc_num_threads_in( bszid_chl, rntm ); + const dim_t child_n_way = bli_rntm_ways_for( *bszid_chl, rntm ); + const dim_t child_comm_id = parent_comm_id % child_nt_in; + const dim_t child_work_id = child_comm_id / ( child_nt_in / child_n_way ); + +//printf( "thread %d: child_n_way = %d child_nt_in = %d parent_n_way = %d (bszid = %d->%d)\n", (int)child_comm_id, (int)child_nt_in, (int)child_n_way, (int)parent_n_way, (int)bli_cntl_bszid( cntl_par ), (int)bszid_chl ); + + // The parent's chief thread creates a temporary array of thrcomm_t + // pointers. + if ( bli_thread_am_ochief( thread_par ) ) + { + err_t r_val; + + if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) + new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ), &r_val ); + else + new_comms = static_comms; + } + + // Broadcast the temporary array to all threads in the parent's + // communicator. + new_comms = bli_thread_broadcast( thread_par, new_comms ); + + // Chiefs in the child communicator allocate the communicator + // object and store it in the array element corresponding to the + // parent's work id. + if ( child_comm_id == 0 ) + new_comms[ parent_work_id ] = bli_thrcomm_create( rntm, child_nt_in ); + + bli_thread_barrier( thread_par ); + + // All threads create a new thrinfo_t node using the communicator + // that was created by their chief, as identified by parent_work_id. + thrinfo_t* thread_chl = bli_thrinfo_create + ( + rntm, // rntm + new_comms[ parent_work_id ], // ocomm + child_comm_id, // ocomm_id + child_n_way, // n_way + child_work_id, // work_id + TRUE, // free_comm + *bszid_chl, // bszid + NULL // sub_node + ); + + bli_thread_barrier( thread_par ); + + // The parent's chief thread frees the temporary array of thrcomm_t + // pointers. + if ( bli_thread_am_ochief( thread_par ) ) + { + if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) + bli_free_intl( new_comms ); + } + + return thread_chl; + } +} + diff --git a/frame/1m/packm/bli_packm_unb_var1.h b/frame/thread/bli_thrinfo_sup.h similarity index 74% rename from frame/1m/packm/bli_packm_unb_var1.h rename to frame/thread/bli_thrinfo_sup.h index 8960c8661a..0be035cf87 100644 --- a/frame/1m/packm/bli_packm_unb_var1.h +++ b/frame/thread/bli_thrinfo_sup.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,35 +33,34 @@ */ -void bli_packm_unb_var1 +#ifndef BLIS_THRINFO_SUP_H +#define BLIS_THRINFO_SUP_H + +// +// Prototypes for level-3 thrinfo sup functions. +// + +void bli_thrinfo_sup_grow ( - obj_t* c, - obj_t* p, - cntx_t* cntx, - cntl_t* cntl, + rntm_t* rntm, + bszid_t* bszid_par, thrinfo_t* thread ); - -#undef GENTPROT -#define GENTPROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - trans_t transc, \ - dim_t m, \ - dim_t n, \ - dim_t m_max, \ - dim_t n_max, \ - void* kappa, \ - void* c, inc_t rs_c, inc_t cs_c, \ - void* p, inc_t rs_p, inc_t cs_p, \ - cntx_t* cntx \ +thrinfo_t* bli_thrinfo_sup_rgrow + ( + rntm_t* rntm, + bszid_t* bszid_par, + bszid_t* bszid_cur, + thrinfo_t* thread_par ); -INSERT_GENTPROT_BASIC0( packm_unb_var1 ) +thrinfo_t* bli_thrinfo_sup_create_for_cntl + ( + rntm_t* rntm, + bszid_t* bszid_par, + bszid_t* bszid_chl, + thrinfo_t* thread_par + ); +#endif diff --git a/frame/thread/old/bli_mutex.h b/frame/thread/old/bli_mutex.h index 95d3356224..de9f720e80 100644 --- a/frame/thread/old/bli_mutex.h +++ b/frame/thread/old/bli_mutex.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/thread/old/bli_mutex_openmp.h b/frame/thread/old/bli_mutex_openmp.h index f092d73468..9aaa3c79fb 100644 --- a/frame/thread/old/bli_mutex_openmp.h +++ b/frame/thread/old/bli_mutex_openmp.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/thread/old/bli_mutex_pthreads.h b/frame/thread/old/bli_mutex_pthreads.h index 7c87dab47c..2053e61284 100644 --- a/frame/thread/old/bli_mutex_pthreads.h +++ b/frame/thread/old/bli_mutex_pthreads.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/thread/old/bli_mutex_single.h b/frame/thread/old/bli_mutex_single.h index 0c8db236be..b57d7bba32 100644 --- a/frame/thread/old/bli_mutex_single.h +++ b/frame/thread/old/bli_mutex_single.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/util/bli_util.h b/frame/util/bli_util.h index 6c34ebc676..d7e623a43a 100644 --- a/frame/util/bli_util.h +++ b/frame/util/bli_util.h @@ -37,18 +37,22 @@ // Prototype object APIs (expert and non-expert). #include "bli_oapi_ex.h" #include "bli_util_oapi.h" +#include "bli_xapi_undef.h" #include "bli_oapi_ba.h" #include "bli_util_oapi.h" +#include "bli_xapi_undef.h" // Prototype typed APIs (expert and non-expert). #include "bli_tapi_ex.h" #include "bli_util_tapi.h" #include "bli_util_ft.h" +#include "bli_xapi_undef.h" #include "bli_tapi_ba.h" #include "bli_util_tapi.h" #include "bli_util_ft.h" +#include "bli_xapi_undef.h" // Generate function pointer arrays for tapi functions (expert only). #include "bli_util_fpa.h" diff --git a/frame/util/bli_util_check.c b/frame/util/bli_util_check.c index ae4ebb4612..3693ea39c1 100644 --- a/frame/util/bli_util_check.c +++ b/frame/util/bli_util_check.c @@ -108,18 +108,16 @@ GENFRONT( normim ) \ void PASTEMAC(opname,_check) \ ( \ - FILE* file, \ - char* s1, \ - obj_t* x, \ - char* format, \ - char* s2 \ + obj_t* x \ ) \ { \ - bli_utilm_fprint_check( file, s1, x, format, s2 ); \ + bli_utilm_rand_check( x ); \ } -GENFRONT( fprintv ) -GENFRONT( fprintm ) +GENFRONT( randv ) +GENFRONT( randnv ) +GENFRONT( randm ) +GENFRONT( randnm ) #undef GENFRONT @@ -127,16 +125,32 @@ GENFRONT( fprintm ) \ void PASTEMAC(opname,_check) \ ( \ - obj_t* x \ + obj_t* x, \ + obj_t* scale, \ + obj_t* sumsq \ ) \ { \ - bli_utilm_rand_check( x ); \ + bli_utilv_sumsqv_check( x, scale, sumsq ); \ } -GENFRONT( randv ) -GENFRONT( randnv ) -GENFRONT( randm ) -GENFRONT( randnm ) +GENFRONT( sumsqv ) + +// ----------------------------------------------------------------------------- + +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC(opname,_check) \ + ( \ + obj_t* chi, \ + obj_t* psi, \ + bool* is_eq \ + ) \ +{ \ + bli_l0_xxbsc_check( chi, psi, is_eq ); \ +} + +GENFRONT( eqsc ) #undef GENFRONT @@ -145,15 +159,49 @@ GENFRONT( randnm ) void PASTEMAC(opname,_check) \ ( \ obj_t* x, \ - obj_t* scale, \ - obj_t* sumsq \ + obj_t* y, \ + bool* is_eq \ ) \ { \ - bli_utilv_sumsqv_check( x, scale, sumsq ); \ + bli_l1v_xy_check( x, y ); \ } -GENFRONT( sumsqv ) +GENFRONT( eqv ) + + +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC(opname,_check) \ + ( \ + obj_t* x, \ + obj_t* y, \ + bool* is_eq \ + ) \ +{ \ + bli_l1m_xy_check( x, y ); \ +} + +GENFRONT( eqm ) + +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC(opname,_check) \ + ( \ + FILE* file, \ + char* s1, \ + obj_t* x, \ + char* format, \ + char* s2 \ + ) \ +{ \ + bli_utilm_fprint_check( file, s1, x, format, s2 ); \ +} + +GENFRONT( fprintv ) +GENFRONT( fprintm ) // ----------------------------------------------------------------------------- diff --git a/frame/util/bli_util_check.h b/frame/util/bli_util_check.h index a789211c96..866a2cd895 100644 --- a/frame/util/bli_util_check.h +++ b/frame/util/bli_util_check.h @@ -90,22 +90,6 @@ GENPROT( normfm ) GENPROT( normim ) -#undef GENPROT -#define GENPROT( opname ) \ -\ -void PASTEMAC(opname,_check) \ - ( \ - FILE* file, \ - char* s1, \ - obj_t* x, \ - char* format, \ - char* s2 \ - ); - -GENPROT( fprintv ) -GENPROT( fprintm ) - - #undef GENPROT #define GENPROT( opname ) \ \ @@ -132,6 +116,49 @@ void PASTEMAC(opname,_check) \ GENPROT( sumsqv ) +// ----------------------------------------------------------------------------- + +#undef GENTPROT +#define GENTPROT( opname ) \ +\ +void PASTEMAC(opname,_check) \ + ( \ + obj_t* chi, \ + obj_t* psi, \ + bool* is_eq \ + ); + +GENTPROT( eqsc ) + + +#undef GENPROT +#define GENPROT( opname ) \ +\ +void PASTEMAC(opname,_check) \ + ( \ + obj_t* x, \ + obj_t* y, \ + bool* is_eq \ + ); + +GENPROT( eqv ) +GENPROT( eqm ) + + +#undef GENPROT +#define GENPROT( opname ) \ +\ +void PASTEMAC(opname,_check) \ + ( \ + FILE* file, \ + char* s1, \ + obj_t* x, \ + char* format, \ + char* s2 \ + ); + +GENPROT( fprintv ) +GENPROT( fprintm ) // ----------------------------------------------------------------------------- diff --git a/frame/util/bli_util_fpa.c b/frame/util/bli_util_fpa.c index b68f608eb5..fba513fae1 100644 --- a/frame/util/bli_util_fpa.c +++ b/frame/util/bli_util_fpa.c @@ -66,11 +66,19 @@ GENFRONT( randm ) GENFRONT( randnm ) GENFRONT( sumsqv ) +// ----------------------------------------------------------------------------- + +// Operations with only basic interfaces. #undef GENFRONT #define GENFRONT( opname ) \ \ -GENARRAY_FPA( void*, opname ); \ +/* +GENARRAY_FPA( void_fp, opname ); \ +*/ \ +\ +GENARRAY_FPA( PASTECH(opname,_vft), \ + PASTECH0(opname) ); \ \ PASTECH(opname,_vft) \ PASTEMAC(opname,_qfp)( num_t dt ) \ @@ -78,6 +86,9 @@ PASTEMAC(opname,_qfp)( num_t dt ) \ return PASTECH(opname,_fpa)[ dt ]; \ } +GENFRONT( eqsc ) +GENFRONT( eqv ) +GENFRONT( eqm ) GENFRONT( fprintv ) GENFRONT( fprintm ) //GENFRONT( printv ) diff --git a/frame/util/bli_util_fpa.h b/frame/util/bli_util_fpa.h index 3eb2c48682..9ed6a4cf71 100644 --- a/frame/util/bli_util_fpa.h +++ b/frame/util/bli_util_fpa.h @@ -52,16 +52,13 @@ GENPROT( normiv ) GENPROT( norm1m ) GENPROT( normfm ) GENPROT( normim ) -GENPROT( fprintv ) -GENPROT( fprintm ) -//GENPROT( printv ) -//GENPROT( printm ) GENPROT( randv ) GENPROT( randnv ) GENPROT( randm ) GENPROT( randnm ) GENPROT( sumsqv ) +// ----------------------------------------------------------------------------- #undef GENPROT #define GENPROT( opname ) \ @@ -69,6 +66,9 @@ GENPROT( sumsqv ) PASTECH(opname,_vft) \ PASTEMAC(opname,_qfp)( num_t dt ); +GENPROT( eqsc ) +GENPROT( eqv ) +GENPROT( eqm ) GENPROT( fprintv ) GENPROT( fprintm ) //GENPROT( printv ) diff --git a/frame/util/bli_util_ft.h b/frame/util/bli_util_ft.h index c4f4f73d0b..673f4782aa 100644 --- a/frame/util/bli_util_ft.h +++ b/frame/util/bli_util_ft.h @@ -191,3 +191,62 @@ typedef void (*PASTECH3(ch,opname,EX_SUF,tsuf)) \ INSERT_GENTDEFR( sumsqv ) +// ----------------------------------------------------------------------------- + +// Operations with only basic interfaces. + +#ifdef BLIS_TAPI_BASIC + +// eqsc + +#undef GENTDEF +#define GENTDEF( ctype, ch, opname, tsuf ) \ +\ +typedef void (*PASTECH2(ch,opname,tsuf)) \ + ( \ + conj_t conjchi, \ + ctype* chi, \ + ctype* psi, \ + bool* is_eq \ + ); + +INSERT_GENTDEF( eqsc ) + +// eqv + +#undef GENTDEF +#define GENTDEF( ctype, ch, opname, tsuf ) \ +\ +typedef void (*PASTECH2(ch,opname,tsuf)) \ + ( \ + conj_t conjx, \ + dim_t n, \ + ctype* x, inc_t incx, \ + ctype* y, inc_t incy, \ + bool* is_eq \ + ); + +INSERT_GENTDEF( eqv ) + +// eqm + +#undef GENTDEF +#define GENTDEF( ctype, ch, opname, tsuf ) \ +\ +typedef void (*PASTECH2(ch,opname,tsuf)) \ + ( \ + doff_t diagoffx, \ + diag_t diagx, \ + uplo_t uplox, \ + trans_t transx, \ + dim_t m, \ + dim_t n, \ + ctype* x, inc_t rs_x, inc_t cs_x, \ + ctype* y, inc_t rs_y, inc_t cs_y, \ + bool* is_eq \ + ); + +INSERT_GENTDEF( eqm ) + +#endif // #ifdef BLIS_OAPI_BASIC + diff --git a/frame/util/bli_util_oapi.c b/frame/util/bli_util_oapi.c index f9f9b4c93e..afd221a587 100644 --- a/frame/util/bli_util_oapi.c +++ b/frame/util/bli_util_oapi.c @@ -66,17 +66,17 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, asum ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ f \ ( \ - n, \ - buf_x, incx, \ - buf_asum, \ - cntx, \ - rntm \ + n, \ + buf_x, incx, \ + buf_asum, \ + cntx, \ + rntm \ ); \ } @@ -108,17 +108,17 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( a ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ f \ ( \ - uploa, \ - m, \ - buf_a, rs_a, cs_a, \ - cntx, \ - rntm \ + uploa, \ + m, \ + buf_a, rs_a, cs_a, \ + cntx, \ + rntm \ ); \ } @@ -152,17 +152,17 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, norm ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ f \ ( \ - n, \ - buf_x, incx, \ - buf_norm, \ - cntx, \ - rntm \ + n, \ + buf_x, incx, \ + buf_norm, \ + cntx, \ + rntm \ ); \ } @@ -201,21 +201,21 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, norm ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ f \ ( \ - diagoffx, \ - diagx, \ - uplox, \ - m, \ - n, \ - buf_x, rs_x, cs_x, \ - buf_norm, \ - cntx, \ - rntm \ + diagoffx, \ + diagx, \ + uplox, \ + m, \ + n, \ + buf_x, rs_x, cs_x, \ + buf_norm, \ + cntx, \ + rntm \ ); \ } @@ -229,11 +229,7 @@ GENFRONT( normim ) \ void PASTEMAC(opname,EX_SUF) \ ( \ - FILE* file, \ - char* s1, \ - obj_t* x, \ - char* format, \ - char* s2 \ + obj_t* x \ BLIS_OAPI_EX_PARAMS \ ) \ { \ @@ -248,31 +244,24 @@ void PASTEMAC(opname,EX_SUF) \ inc_t incx = bli_obj_vector_inc( x ); \ \ if ( bli_error_checking_is_enabled() ) \ - PASTEMAC(opname,_check)( file, s1, x, format, s2 ); \ -\ - /* Handle constants up front. */ \ - if ( dt == BLIS_CONSTANT ) \ - { \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ - } \ + PASTEMAC(opname,_check)( x ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ - PASTECH(opname,_vft) f = \ - PASTEMAC(opname,_qfp)( dt ); \ + void* for function arguments instead of typed pointers. */ \ + PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ + PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ f \ ( \ - file, \ - s1, \ - n, \ - buf_x, incx, \ - format, \ - s2 \ + n, \ + buf_x, incx, \ + cntx, \ + rntm \ ); \ } -GENFRONT( fprintv ) +GENFRONT( randv ) +GENFRONT( randnv ) #undef GENFRONT @@ -280,11 +269,7 @@ GENFRONT( fprintv ) \ void PASTEMAC(opname,EX_SUF) \ ( \ - FILE* file, \ - char* s1, \ - obj_t* x, \ - char* format, \ - char* s2 \ + obj_t* x \ BLIS_OAPI_EX_PARAMS \ ) \ { \ @@ -294,6 +279,8 @@ void PASTEMAC(opname,EX_SUF) \ \ num_t dt = bli_obj_dt( x ); \ \ + doff_t diagoffx = bli_obj_diag_offset( x ); \ + uplo_t uplox = bli_obj_uplo( x ); \ dim_t m = bli_obj_length( x ); \ dim_t n = bli_obj_width( x ); \ void* buf_x = bli_obj_buffer_at_off( x ); \ @@ -301,58 +288,37 @@ void PASTEMAC(opname,EX_SUF) \ inc_t cs_x = bli_obj_col_stride( x ); \ \ if ( bli_error_checking_is_enabled() ) \ - PASTEMAC(opname,_check)( file, s1, x, format, s2 ); \ -\ - /* Handle constants up front. */ \ - if ( dt == BLIS_CONSTANT ) \ - { \ - float* sp = bli_obj_buffer_for_const( BLIS_FLOAT, x ); \ - double* dp = bli_obj_buffer_for_const( BLIS_DOUBLE, x ); \ - scomplex* cp = bli_obj_buffer_for_const( BLIS_SCOMPLEX, x ); \ - dcomplex* zp = bli_obj_buffer_for_const( BLIS_DCOMPLEX, x ); \ - gint_t* ip = bli_obj_buffer_for_const( BLIS_INT, x ); \ -\ - fprintf( file, "%s\n", s1 ); \ - fprintf( file, " float: %9.2e\n", bli_sreal( *sp ) ); \ - fprintf( file, " double: %9.2e\n", bli_dreal( *dp ) ); \ - fprintf( file, " scomplex: %9.2e + %9.2e\n", bli_creal( *cp ), \ - bli_cimag( *cp ) ); \ - fprintf( file, " dcomplex: %9.2e + %9.2e\n", bli_zreal( *zp ), \ - bli_zimag( *zp ) ); \ - fprintf( file, " int: %ld\n", ( long )(*ip) ); \ - fprintf( file, "\n" ); \ - return; \ - } \ + PASTEMAC(opname,_check)( x ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ - PASTECH(opname,_vft) f = \ - PASTEMAC(opname,_qfp)( dt ); \ + void* for function arguments instead of typed pointers. */ \ + PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ + PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ f \ ( \ - file, \ - s1, \ - m, \ - n, \ - buf_x, rs_x, cs_x, \ - format, \ - s2 \ + diagoffx, \ + uplox, \ + m, \ + n, \ + buf_x, rs_x, cs_x, \ + cntx, \ + rntm \ ); \ } -GENFRONT( fprintm ) +GENFRONT( randm ) +GENFRONT( randnm ) #undef GENFRONT -#define GENFRONT( opname, varname ) \ +#define GENFRONT( opname ) \ \ void PASTEMAC(opname,EX_SUF) \ ( \ - char* s1, \ obj_t* x, \ - char* format, \ - char* s2 \ + obj_t* scale, \ + obj_t* sumsq \ BLIS_OAPI_EX_PARAMS \ ) \ { \ @@ -360,155 +326,348 @@ void PASTEMAC(opname,EX_SUF) \ \ BLIS_OAPI_EX_DECLS \ \ - /* Suppress compiler warning about unused variables. */ \ - ( void )cntx; \ + num_t dt = bli_obj_dt( x ); \ \ - /* Invoke the typed function. */ \ - PASTEMAC0(varname) \ + dim_t n = bli_obj_vector_dim( x ); \ + void* buf_x = bli_obj_buffer_at_off( x ); \ + inc_t incx = bli_obj_vector_inc( x ); \ + void* buf_scale = bli_obj_buffer_at_off( scale ); \ + void* buf_sumsq = bli_obj_buffer_at_off( sumsq ); \ +\ + if ( bli_error_checking_is_enabled() ) \ + PASTEMAC(opname,_check)( x, scale, sumsq ); \ +\ + /* Query a type-specific function pointer, except one that uses + void* for function arguments instead of typed pointers. */ \ + PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ + PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ +\ + f \ ( \ - stdout, \ - s1, \ - x, \ - format, \ - s2 \ + n, \ + buf_x, incx, \ + buf_scale, \ + buf_sumsq, \ + cntx, \ + rntm \ ); \ } -GENFRONT( printv, fprintv ) -GENFRONT( printm, fprintm ) +GENFRONT( sumsqv ) + +// ----------------------------------------------------------------------------- + +// Operations with only basic interfaces. +#ifdef BLIS_OAPI_BASIC #undef GENFRONT #define GENFRONT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +void PASTEMAC0(opname) \ ( \ - obj_t* x \ - BLIS_OAPI_EX_PARAMS \ + obj_t* chi, \ + obj_t* psi, \ + bool* is_eq \ ) \ { \ bli_init_once(); \ \ - BLIS_OAPI_EX_DECLS \ + num_t dt_chi = bli_obj_dt( chi ); \ + num_t dt_psi = bli_obj_dt( psi ); \ + num_t dt; \ +\ + if ( bli_error_checking_is_enabled() ) \ + PASTEMAC(opname,_check)( chi, psi, is_eq ); \ +\ + /* Decide which datatype will be used to query the buffer from the + constant object (if there is one). */ \ + if ( bli_is_constant( dt_psi ) ) dt = dt_chi; \ + else dt = dt_psi; \ +\ + /* If chi and psi are both constants, then we compare only the dcomplex + fields. */ \ + if ( bli_is_constant( dt ) ) dt = BLIS_DCOMPLEX; \ +\ + void* buf_chi = bli_obj_buffer_for_1x1( dt, chi ); \ + void* buf_psi = bli_obj_buffer_for_1x1( dt, psi ); \ +\ + /* Integer objects are handled separately. */ \ + if ( bli_is_int( dt ) ) \ + { \ + *is_eq = bli_ieqa( buf_chi, buf_psi ); \ + return; \ + } \ +\ + /* Query the conj status of each object and use the two to come up with a + single "net" conj_t value. */ \ + conj_t conjchi = bli_obj_conj_status( chi ); \ + conj_t conjpsi = bli_obj_conj_status( psi ); \ + conj_t conj = bli_apply_conj( conjchi, conjpsi ); \ +\ + /* Query a type-specific function pointer, except one that uses + void* for function arguments instead of typed pointers. */ \ + PASTECH(opname,_vft) f = \ + PASTEMAC(opname,_qfp)( dt ); \ +\ + f \ + ( \ + conj, \ + buf_chi, \ + buf_psi, \ + is_eq \ + ); \ +} + +GENFRONT( eqsc ) + + +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC0(opname) \ + ( \ + obj_t* x, \ + obj_t* y, \ + bool* is_eq \ + ) \ +{ \ + bli_init_once(); \ \ num_t dt = bli_obj_dt( x ); \ \ dim_t n = bli_obj_vector_dim( x ); \ void* buf_x = bli_obj_buffer_at_off( x ); \ - inc_t incx = bli_obj_vector_inc( x ); \ + inc_t inc_x = bli_obj_vector_inc( x ); \ + void* buf_y = bli_obj_buffer_at_off( y ); \ + inc_t inc_y = bli_obj_vector_inc( y ); \ \ if ( bli_error_checking_is_enabled() ) \ - PASTEMAC(opname,_check)( x ); \ + PASTEMAC(opname,_check)( x, y, is_eq ); \ +\ + /* Query the conj status of each object and use the two to come up with a + single "net" conj_t value. */ \ + conj_t conjx = bli_obj_conj_status( x ); \ + conj_t conjy = bli_obj_conj_status( y ); \ + conj_t conj = bli_apply_conj( conjx, conjy ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ - PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ - PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ + void* for function arguments instead of typed pointers. */ \ + PASTECH(opname,_vft) f = \ + PASTEMAC(opname,_qfp)( dt ); \ \ f \ ( \ - n, \ - buf_x, incx, \ - cntx, \ - rntm \ + conj, \ + n, \ + buf_x, inc_x, \ + buf_y, inc_y, \ + is_eq \ ); \ } -GENFRONT( randv ) -GENFRONT( randnv ) +GENFRONT( eqv ) #undef GENFRONT #define GENFRONT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +void PASTEMAC0(opname) \ ( \ - obj_t* x \ - BLIS_OAPI_EX_PARAMS \ + obj_t* x, \ + obj_t* y, \ + bool* is_eq \ ) \ { \ bli_init_once(); \ -\ - BLIS_OAPI_EX_DECLS \ \ num_t dt = bli_obj_dt( x ); \ \ doff_t diagoffx = bli_obj_diag_offset( x ); \ + diag_t diagx = bli_obj_diag( x ); \ uplo_t uplox = bli_obj_uplo( x ); \ - dim_t m = bli_obj_length( x ); \ - dim_t n = bli_obj_width( x ); \ + dim_t m = bli_obj_length( y ); \ + dim_t n = bli_obj_width( y ); \ void* buf_x = bli_obj_buffer_at_off( x ); \ inc_t rs_x = bli_obj_row_stride( x ); \ inc_t cs_x = bli_obj_col_stride( x ); \ + void* buf_y = bli_obj_buffer_at_off( y ); \ + inc_t rs_y = bli_obj_row_stride( y ); \ + inc_t cs_y = bli_obj_col_stride( y ); \ \ if ( bli_error_checking_is_enabled() ) \ - PASTEMAC(opname,_check)( x ); \ + PASTEMAC(opname,_check)( x, y, is_eq ); \ +\ + /* Query the combined trans and conj status of each object and use the two + to come up with a single "net" trans_t value. */ \ + trans_t transx = bli_obj_conjtrans_status( x ); \ + trans_t transy = bli_obj_conjtrans_status( y ); \ + trans_t trans = bli_apply_trans( transy, transx ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ - PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ - PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ + void* for function arguments instead of typed pointers. */ \ + PASTECH(opname,_vft) f = \ + PASTEMAC(opname,_qfp)( dt ); \ \ f \ ( \ - diagoffx, \ - uplox, \ - m, \ - n, \ - buf_x, rs_x, cs_x, \ - cntx, \ - rntm \ + diagoffx, \ + diagx, \ + uplox, \ + trans, \ + m, \ + n, \ + buf_x, rs_x, cs_x, \ + buf_y, rs_y, cs_y, \ + is_eq \ ); \ } -GENFRONT( randm ) -GENFRONT( randnm ) +GENFRONT( eqm ) #undef GENFRONT #define GENFRONT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +void PASTEMAC0(opname) \ ( \ + FILE* file, \ + char* s1, \ obj_t* x, \ - obj_t* scale, \ - obj_t* sumsq \ - BLIS_OAPI_EX_PARAMS \ + char* format, \ + char* s2 \ ) \ { \ bli_init_once(); \ -\ - BLIS_OAPI_EX_DECLS \ \ num_t dt = bli_obj_dt( x ); \ \ dim_t n = bli_obj_vector_dim( x ); \ void* buf_x = bli_obj_buffer_at_off( x ); \ inc_t incx = bli_obj_vector_inc( x ); \ - void* buf_scale = bli_obj_buffer_at_off( scale ); \ - void* buf_sumsq = bli_obj_buffer_at_off( sumsq ); \ \ if ( bli_error_checking_is_enabled() ) \ - PASTEMAC(opname,_check)( x, scale, sumsq ); \ + PASTEMAC(opname,_check)( file, s1, x, format, s2 ); \ +\ + /* Handle constants up front. */ \ + if ( dt == BLIS_CONSTANT ) \ + { \ + bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ + } \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ - PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ - PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ + void* for function arguments instead of typed pointers. */ \ + PASTECH(opname,_vft) f = \ + PASTEMAC(opname,_qfp)( dt ); \ \ f \ ( \ - n, \ - buf_x, incx, \ - buf_scale, \ - buf_sumsq, \ - cntx, \ - rntm \ + file, \ + s1, \ + n, \ + buf_x, incx, \ + format, \ + s2 \ ); \ } -GENFRONT( sumsqv ) +GENFRONT( fprintv ) + + +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC0(opname) \ + ( \ + FILE* file, \ + char* s1, \ + obj_t* x, \ + char* format, \ + char* s2 \ + ) \ +{ \ + bli_init_once(); \ +\ + num_t dt = bli_obj_dt( x ); \ +\ + dim_t m = bli_obj_length( x ); \ + dim_t n = bli_obj_width( x ); \ + void* buf_x = bli_obj_buffer_at_off( x ); \ + inc_t rs_x = bli_obj_row_stride( x ); \ + inc_t cs_x = bli_obj_col_stride( x ); \ +\ + if ( bli_error_checking_is_enabled() ) \ + PASTEMAC(opname,_check)( file, s1, x, format, s2 ); \ +\ + /* Handle constants up front. */ \ + if ( dt == BLIS_CONSTANT ) \ + { \ + float* sp = bli_obj_buffer_for_const( BLIS_FLOAT, x ); \ + double* dp = bli_obj_buffer_for_const( BLIS_DOUBLE, x ); \ + scomplex* cp = bli_obj_buffer_for_const( BLIS_SCOMPLEX, x ); \ + dcomplex* zp = bli_obj_buffer_for_const( BLIS_DCOMPLEX, x ); \ + gint_t* ip = bli_obj_buffer_for_const( BLIS_INT, x ); \ +\ + fprintf( file, "%s\n", s1 ); \ + fprintf( file, " float: %9.2e\n", bli_sreal( *sp ) ); \ + fprintf( file, " double: %9.2e\n", bli_dreal( *dp ) ); \ + fprintf( file, " scomplex: %9.2e + %9.2e\n", bli_creal( *cp ), \ + bli_cimag( *cp ) ); \ + fprintf( file, " dcomplex: %9.2e + %9.2e\n", bli_zreal( *zp ), \ + bli_zimag( *zp ) ); \ + fprintf( file, " int: %ld\n", ( long )(*ip) ); \ + fprintf( file, "\n" ); \ + return; \ + } \ +\ + /* Query a type-specific function pointer, except one that uses + void* for function arguments instead of typed pointers. */ \ + PASTECH(opname,_vft) f = \ + PASTEMAC(opname,_qfp)( dt ); \ +\ + f \ + ( \ + file, \ + s1, \ + m, \ + n, \ + buf_x, rs_x, cs_x, \ + format, \ + s2 \ + ); \ +} + +GENFRONT( fprintm ) + + +#undef GENFRONT +#define GENFRONT( opname, varname ) \ +\ +void PASTEMAC0(opname) \ + ( \ + char* s1, \ + obj_t* x, \ + char* format, \ + char* s2 \ + ) \ +{ \ + bli_init_once(); \ +\ + /* Invoke the typed function. */ \ + PASTEMAC0(varname) \ + ( \ + stdout, \ + s1, \ + x, \ + format, \ + s2 \ + ); \ +} + +GENFRONT( printv, fprintv ) +GENFRONT( printm, fprintm ) +#endif // #ifdef BLIS_OAPI_BASIC #endif diff --git a/frame/util/bli_util_oapi.h b/frame/util/bli_util_oapi.h index 1acce16065..92ce6c95f7 100644 --- a/frame/util/bli_util_oapi.h +++ b/frame/util/bli_util_oapi.h @@ -99,16 +99,12 @@ GENPROT( normim ) \ BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ - FILE* file, \ - char* s1, \ - obj_t* x, \ - char* format, \ - char* s2 \ + obj_t* x \ BLIS_OAPI_EX_PARAMS \ ); -GENPROT( fprintv ) -GENPROT( fprintm ) +GENPROT( randv ) +GENPROT( randnv ) #undef GENPROT @@ -116,15 +112,12 @@ GENPROT( fprintm ) \ BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ - char* s1, \ - obj_t* x, \ - char* format, \ - char* s2 \ + obj_t* x \ BLIS_OAPI_EX_PARAMS \ ); -GENPROT( printv ) -GENPROT( printm ) +GENPROT( randm ) +GENPROT( randnm ) #undef GENPROT @@ -132,37 +125,92 @@ GENPROT( printm ) \ BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ - obj_t* x \ + obj_t* x, \ + obj_t* scale, \ + obj_t* sumsq \ BLIS_OAPI_EX_PARAMS \ ); -GENPROT( randv ) -GENPROT( randnv ) +GENPROT( sumsqv ) + +// ----------------------------------------------------------------------------- + +// Operations with basic interfaces only. +#ifdef BLIS_OAPI_BASIC +/* #undef GENPROT #define GENPROT( opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ - obj_t* x \ - BLIS_OAPI_EX_PARAMS \ + obj_t* chi, \ + obj_t* psi, \ + bool* is_eq \ ); -GENPROT( randm ) -GENPROT( randnm ) +GENPROT( eqsc ) #undef GENPROT #define GENPROT( opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* x, \ - obj_t* scale, \ - obj_t* sumsq \ - BLIS_OAPI_EX_PARAMS \ + obj_t* y, \ + bool* is_eq \ ); -GENPROT( sumsqv ) +GENPROT( eqv ) +*/ + + +#undef GENPROT +#define GENPROT( opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ + ( \ + obj_t* x, \ + obj_t* y, \ + bool* is_eq \ + ); + +GENPROT( eqsc ) +GENPROT( eqv ) +GENPROT( eqm ) + + +#undef GENPROT +#define GENPROT( opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ + ( \ + FILE* file, \ + char* s1, \ + obj_t* x, \ + char* format, \ + char* s2 \ + ); + +GENPROT( fprintv ) +GENPROT( fprintm ) + + +#undef GENPROT +#define GENPROT( opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ + ( \ + char* s1, \ + obj_t* x, \ + char* format, \ + char* s2 \ + ); + +GENPROT( printv ) +GENPROT( printm ) + +#endif // #ifdef BLIS_OAPI_BASIC diff --git a/frame/util/bli_util_tapi.c b/frame/util/bli_util_tapi.c index 6bef27d43a..ca0b3c279d 100644 --- a/frame/util/bli_util_tapi.c +++ b/frame/util/bli_util_tapi.c @@ -213,71 +213,137 @@ INSERT_GENTFUNCR_BASIC0( normfm ) INSERT_GENTFUNCR_BASIC0( normim ) -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname, varname ) \ +#undef GENTFUNCR +#define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ \ void PASTEMAC2(ch,opname,EX_SUF) \ ( \ - char* s1, \ - dim_t n, \ - void* x, inc_t incx, \ - char* format, \ - char* s2 \ + dim_t n, \ + ctype* x, inc_t incx \ + BLIS_TAPI_EX_PARAMS \ ) \ { \ bli_init_once(); \ \ - PASTEMAC(ch,varname) \ - ( \ - stdout, \ - s1, \ - n, \ - x, incx, \ - format, \ - s2 \ - ); \ + BLIS_TAPI_EX_DECLS \ +\ + /* If the vector length is zero, return early. */ \ + if ( bli_zero_dim1( n ) ) return; \ +\ + /* Obtain a valid context from the gks if necessary. */ \ + /*if ( cntx == NULL ) cntx = bli_gks_query_cntx();*/ \ +\ + ctype_r norm; \ +\ + /* Set the norm to zero. */ \ + PASTEMAC(chr,set0s)( norm ); \ +\ + /* Iterate at least once, but continue iterating until the norm is not zero. */ \ + while ( PASTEMAC(chr,eq0)( norm ) ) \ + { \ + /* Invoke the helper variant, which loops over the appropriate kernel + to implement the current operation. */ \ + PASTEMAC2(ch,opname,_unb_var1) \ + ( \ + n, \ + x, incx, \ + cntx, \ + rntm \ + ); \ +\ + /* Check the 1-norm of the randomzied vector. In the unlikely event that + the 1-norm is zero, it means that *all* elements are zero, in which + case we want to re-randomize until the 1-norm is not zero. */ \ + PASTEMAC2(ch,norm1v,BLIS_TAPI_EX_SUF) \ + ( \ + n, \ + x, incx, \ + &norm, \ + cntx, \ + rntm \ + ); \ + } \ } -INSERT_GENTFUNC_BASIC_I( printv, fprintv ) +INSERT_GENTFUNCR_BASIC0( randv ) +INSERT_GENTFUNCR_BASIC0( randnv ) -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname, varname ) \ +#undef GENTFUNCR +#define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ \ void PASTEMAC2(ch,opname,EX_SUF) \ ( \ - char* s1, \ - dim_t m, \ - dim_t n, \ - void* x, inc_t rs_x, inc_t cs_x, \ - char* format, \ - char* s2 \ + doff_t diagoffx, \ + uplo_t uplox, \ + dim_t m, \ + dim_t n, \ + ctype* x, inc_t rs_x, inc_t cs_x \ + BLIS_TAPI_EX_PARAMS \ ) \ { \ bli_init_once(); \ \ - PASTEMAC(ch,varname) \ - ( \ - stdout, \ - s1, \ - m, \ - n, \ - x, rs_x, cs_x, \ - format, \ - s2 \ - ); \ + BLIS_TAPI_EX_DECLS \ +\ + /* If either dimension is zero, return early. */ \ + if ( bli_zero_dim2( m, n ) ) return; \ +\ + /* Obtain a valid context from the gks if necessary. */ \ + /*if ( cntx == NULL ) cntx = bli_gks_query_cntx();*/ \ +\ + ctype_r norm; \ +\ + /* Set the norm to zero. */ \ + PASTEMAC(chr,set0s)( norm ); \ +\ + /* Iterate at least once, but continue iterating until the norm is not zero. */ \ + while ( PASTEMAC(chr,eq0)( norm ) ) \ + { \ + /* Invoke the helper variant, which loops over the appropriate kernel + to implement the current operation. */ \ + PASTEMAC2(ch,opname,_unb_var1) \ + ( \ + diagoffx, \ + uplox, \ + m, \ + n, \ + x, rs_x, cs_x, \ + cntx, \ + rntm \ + ); \ +\ + /* Check the 1-norm of the randomzied matrix. In the unlikely event that + the 1-norm is zero, it means that *all* elements are zero, in which + case we want to re-randomize until the 1-norm is not zero. */ \ + PASTEMAC2(ch,norm1m,BLIS_TAPI_EX_SUF) \ + ( \ + diagoffx, \ + BLIS_NONUNIT_DIAG, \ + uplox, \ + m, \ + n, \ + x, rs_x, cs_x, \ + &norm, \ + cntx, \ + rntm \ + ); \ + } \ } -INSERT_GENTFUNC_BASIC_I( printm, fprintm ) +INSERT_GENTFUNCR_BASIC0( randm ) +INSERT_GENTFUNCR_BASIC0( randnm ) -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ +#undef GENTFUNCR +#define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ \ void PASTEMAC2(ch,opname,EX_SUF) \ ( \ dim_t n, \ - ctype* x, inc_t incx \ + ctype* x, inc_t incx, \ + ctype_r* scale, \ + ctype_r* sumsq \ BLIS_TAPI_EX_PARAMS \ ) \ { \ @@ -285,7 +351,7 @@ void PASTEMAC2(ch,opname,EX_SUF) \ \ BLIS_TAPI_EX_DECLS \ \ - /* If the vector length is zero, return early. */ \ + /* If x is zero length, return with scale and sumsq unchanged. */ \ if ( bli_zero_dim1( n ) ) return; \ \ /* Obtain a valid context from the gks if necessary. */ \ @@ -297,92 +363,176 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ( \ n, \ x, incx, \ + scale, \ + sumsq, \ cntx, \ rntm \ ); \ } -INSERT_GENTFUNC_BASIC0( randv ) -INSERT_GENTFUNC_BASIC0( randnv ) +INSERT_GENTFUNCR_BASIC0( sumsqv ) + +// ----------------------------------------------------------------------------- + +// Operations with only basic interfaces. +#ifdef BLIS_TAPI_BASIC #undef GENTFUNC #define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ + ( \ + conj_t conjchi, \ + ctype* chi, \ + ctype* psi, \ + bool* is_eq \ + ) \ +{ \ + bli_init_once(); \ +\ + ctype chi_conj; \ +\ + PASTEMAC(ch,copycjs)( conjchi, *chi, chi_conj ); \ +\ + *is_eq = PASTEMAC(ch,eq)( chi_conj, *psi ); \ +} + +INSERT_GENTFUNC_BASIC0( eqsc ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + conj_t conjx, \ + dim_t n, \ + ctype* x, inc_t incx, \ + ctype* y, inc_t incy, \ + bool* is_eq \ + ) \ +{ \ + bli_init_once(); \ +\ + /* If x is zero length, return with a result of TRUE. */ \ + if ( bli_zero_dim1( n ) ) { *is_eq = TRUE; return; } \ +\ + /* Obtain a valid context from the gks if necessary. */ \ + /*if ( cntx == NULL ) cntx = bli_gks_query_cntx();*/ \ +\ + *is_eq = PASTEMAC2(ch,opname,_unb_var1) \ + ( \ + conjx, \ + n, \ + x, incx, \ + y, incy \ + ); \ +} + +INSERT_GENTFUNC_BASIC0( eqv ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ ( \ doff_t diagoffx, \ + diag_t diagx, \ uplo_t uplox, \ + trans_t transx, \ dim_t m, \ dim_t n, \ - ctype* x, inc_t rs_x, inc_t cs_x \ - BLIS_TAPI_EX_PARAMS \ + ctype* x, inc_t rs_x, inc_t cs_x, \ + ctype* y, inc_t rs_y, inc_t cs_y, \ + bool* is_eq \ ) \ { \ bli_init_once(); \ \ - BLIS_TAPI_EX_DECLS \ -\ - /* If either dimension is zero, return early. */ \ - if ( bli_zero_dim2( m, n ) ) return; \ + /* If x has a zero dimension, return with a result of TRUE. See the + _unb_var() variant for why we return TRUE in this scenario. */ \ + if ( bli_zero_dim2( m, n ) ) { *is_eq = TRUE; return; } \ \ /* Obtain a valid context from the gks if necessary. */ \ /*if ( cntx == NULL ) cntx = bli_gks_query_cntx();*/ \ \ - /* Invoke the helper variant, which loops over the appropriate kernel - to implement the current operation. */ \ - PASTEMAC2(ch,opname,_unb_var1) \ + /* Invoke the helper variant. */ \ + *is_eq = PASTEMAC2(ch,opname,_unb_var1) \ ( \ diagoffx, \ + diagx, \ uplox, \ + transx, \ m, \ n, \ x, rs_x, cs_x, \ - cntx, \ - rntm \ + y, rs_y, cs_y \ ); \ } -INSERT_GENTFUNC_BASIC0( randm ) -INSERT_GENTFUNC_BASIC0( randnm ) +INSERT_GENTFUNC_BASIC0( eqm ) -#undef GENTFUNCR -#define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, varname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ - dim_t n, \ - ctype* x, inc_t incx, \ - ctype_r* scale, \ - ctype_r* sumsq \ - BLIS_TAPI_EX_PARAMS \ + char* s1, \ + dim_t n, \ + void* x, inc_t incx, \ + char* format, \ + char* s2 \ ) \ { \ bli_init_once(); \ \ - BLIS_TAPI_EX_DECLS \ -\ - /* If x is zero length, return with scale and sumsq unchanged. */ \ - if ( bli_zero_dim1( n ) ) return; \ + PASTEMAC(ch,varname) \ + ( \ + stdout, \ + s1, \ + n, \ + x, incx, \ + format, \ + s2 \ + ); \ +} + +INSERT_GENTFUNC_BASIC_I( printv, fprintv ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, varname ) \ \ - /* Obtain a valid context from the gks if necessary. */ \ - /*if ( cntx == NULL ) cntx = bli_gks_query_cntx();*/ \ +void PASTEMAC(ch,opname) \ + ( \ + char* s1, \ + dim_t m, \ + dim_t n, \ + void* x, inc_t rs_x, inc_t cs_x, \ + char* format, \ + char* s2 \ + ) \ +{ \ + bli_init_once(); \ \ - /* Invoke the helper variant, which loops over the appropriate kernel - to implement the current operation. */ \ - PASTEMAC2(ch,opname,_unb_var1) \ + PASTEMAC(ch,varname) \ ( \ + stdout, \ + s1, \ + m, \ n, \ - x, incx, \ - scale, \ - sumsq, \ - cntx, \ - rntm \ + x, rs_x, cs_x, \ + format, \ + s2 \ ); \ } -INSERT_GENTFUNCR_BASIC0( sumsqv ) +INSERT_GENTFUNC_BASIC_I( printm, fprintm ) + +#endif // #ifdef BLIS_TAPI_BASIC #endif diff --git a/frame/util/bli_util_tapi.h b/frame/util/bli_util_tapi.h index c35702cbc4..43fbbdb063 100644 --- a/frame/util/bli_util_tapi.h +++ b/frame/util/bli_util_tapi.h @@ -103,37 +103,6 @@ INSERT_GENTPROTR_BASIC0( normfm ) INSERT_GENTPROTR_BASIC0( normim ) -#undef GENTPROT -#define GENTPROT( ctype, ch, opname ) \ -\ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ - ( \ - char* s1, \ - dim_t n, \ - void* x, inc_t incx, \ - char* format, \ - char* s2 \ - ); - -INSERT_GENTPROT_BASIC0_I( printv ) - - -#undef GENTPROT -#define GENTPROT( ctype, ch, opname ) \ -\ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ - ( \ - char* s1, \ - dim_t m, \ - dim_t n, \ - void* x, inc_t rs_x, inc_t cs_x, \ - char* format, \ - char* s2 \ - ); - -INSERT_GENTPROT_BASIC0_I( printm ) - - #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ @@ -179,4 +148,89 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ INSERT_GENTPROTR_BASIC0( sumsqv ) +// ----------------------------------------------------------------------------- + +// Operations with basic interfaces only. + +#ifdef BLIS_TAPI_BASIC + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ + ( \ + conj_t conjchi, \ + ctype* chi, \ + ctype* psi, \ + bool* is_eq \ + ); + +INSERT_GENTPROT_BASIC0( eqsc ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ + ( \ + conj_t conjx, \ + dim_t n, \ + ctype* x, inc_t incx, \ + ctype* y, inc_t incy, \ + bool* is_eq \ + ); + +INSERT_GENTPROT_BASIC0( eqv ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ + ( \ + doff_t diagoffx, \ + diag_t diagx, \ + uplo_t uplox, \ + trans_t transx, \ + dim_t m, \ + dim_t n, \ + ctype* x, inc_t rs_x, inc_t cs_x, \ + ctype* y, inc_t rs_y, inc_t cs_y, \ + bool* is_eq \ + ); + +INSERT_GENTPROT_BASIC0( eqm ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ + ( \ + char* s1, \ + dim_t n, \ + void* x, inc_t incx, \ + char* format, \ + char* s2 \ + ); + +INSERT_GENTPROT_BASIC0_I( printv ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ + ( \ + char* s1, \ + dim_t m, \ + dim_t n, \ + void* x, inc_t rs_x, inc_t cs_x, \ + char* format, \ + char* s2 \ + ); + +INSERT_GENTPROT_BASIC0_I( printm ) + +#endif // #ifdef BLIS_TAPI_BASIC diff --git a/frame/util/bli_util_unb_var1.c b/frame/util/bli_util_unb_var1.c index 203a63a1dd..af550681aa 100644 --- a/frame/util/bli_util_unb_var1.c +++ b/frame/util/bli_util_unb_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -311,7 +311,14 @@ GENTFUNCR( scomplex, float, c, s, normfv_unb_var1, sumsqv_unb_var1 ) GENTFUNCR( dcomplex, double, z, d, normfv_unb_var1, sumsqv_unb_var1 ) #undef GENTFUNCR -#ifdef FE_OVERFLOW +// We've disabled the dotv-based implementation because that method of +// computing the sum of the squares of x inherently does not check for +// overflow. Instead, we use the fallback method based on sumsqv, which +// takes care to not overflow unnecessarily (ie: takes care for the +// sqrt( sum of the squares of x ) to not overflow if the sum of the +// squares of x would normally overflow. See GitHub issue #332 for +// discussion. +#if 0 //defined(FE_OVERFLOW) && !defined(__APPLE__) #define GENTFUNCR( ctype, ctype_r, ch, chr, varname, kername ) \ \ void PASTEMAC(ch,varname) \ @@ -466,7 +473,7 @@ void PASTEMAC(ch,varname) \ /* If the absolute value of the current element exceeds that of the previous largest, save it and its index. If NaN is encountered, then treat it the same as if it were a valid - value that was smaller than any previously seen. This + value that was larger than any previously seen. This behavior mimics that of LAPACK's ?lange(). */ \ if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \ { \ @@ -855,85 +862,6 @@ void PASTEMAC(ch,varname) \ INSERT_GENTFUNCR_BASIC( normim_unb_var1, norm1m_unb_var1 ) -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - FILE* file, \ - char* s1, \ - dim_t n, \ - ctype* x, inc_t incx, \ - char* format, \ - char* s2 \ - ) \ -{ \ - dim_t i; \ - ctype* chi1; \ - char default_spec[32] = PASTEMAC(ch,formatspec)(); \ -\ - if ( format == NULL ) format = default_spec; \ -\ - chi1 = x; \ -\ - fprintf( file, "%s\n", s1 ); \ -\ - for ( i = 0; i < n; ++i ) \ - { \ - PASTEMAC(ch,fprints)( file, format, *chi1 ); \ - fprintf( file, "\n" ); \ -\ - chi1 += incx; \ - } \ -\ - fprintf( file, "%s\n", s2 ); \ -} - -INSERT_GENTFUNC_BASIC0_I( fprintv ) - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - FILE* file, \ - char* s1, \ - dim_t m, \ - dim_t n, \ - ctype* x, inc_t rs_x, inc_t cs_x, \ - char* format, \ - char* s2 \ - ) \ -{ \ - dim_t i, j; \ - ctype* chi1; \ - char default_spec[32] = PASTEMAC(ch,formatspec)(); \ -\ - if ( format == NULL ) format = default_spec; \ -\ - fprintf( file, "%s\n", s1 ); \ -\ - for ( i = 0; i < m; ++i ) \ - { \ - for ( j = 0; j < n; ++j ) \ - { \ - chi1 = (( ctype* ) x) + i*rs_x + j*cs_x; \ -\ - PASTEMAC(ch,fprints)( file, format, *chi1 ); \ - fprintf( file, " " ); \ - } \ -\ - fprintf( file, "\n" ); \ - } \ -\ - fprintf( file, "%s\n", s2 ); \ - fflush( file ); \ -} - -INSERT_GENTFUNC_BASIC0_I( fprintm ) - - #undef GENTFUNC #define GENTFUNC( ctype, ch, varname, randmac ) \ \ @@ -1012,7 +940,8 @@ void PASTEMAC(ch,varname) \ \ x1 = x + (j )*ldx + (0 )*incx; \ \ - PASTEMAC2(ch,kername,BLIS_TAPI_EX_SUF) \ + /*PASTEMAC2(ch,kername,BLIS_TAPI_EX_SUF)*/ \ + PASTEMAC(ch,kername) \ ( \ n_elem, \ x1, incx, \ @@ -1039,7 +968,8 @@ void PASTEMAC(ch,varname) \ x0 = x1; \ chi1 = x1 + (n_elem-1)*incx; \ \ - PASTEMAC2(ch,kername,BLIS_TAPI_EX_SUF) \ + /*PASTEMAC2(ch,kername,BLIS_TAPI_EX_SUF)*/ \ + PASTEMAC(ch,kername) \ ( \ n_elem, \ x1, incx, \ @@ -1079,7 +1009,8 @@ void PASTEMAC(ch,varname) \ x2 = x1 + incx; \ chi1 = x1; \ \ - PASTEMAC2(ch,kername,BLIS_TAPI_EX_SUF) \ + /*PASTEMAC2(ch,kername,BLIS_TAPI_EX_SUF)*/ \ + PASTEMAC(ch,kername) \ ( \ n_elem, \ x1, incx, \ @@ -1111,8 +1042,8 @@ void PASTEMAC(ch,varname) \ } \ } -INSERT_GENTFUNC_BASIC( randm_unb_var1, randv ) -INSERT_GENTFUNC_BASIC( randnm_unb_var1, randnv ) +INSERT_GENTFUNC_BASIC( randm_unb_var1, randv_unb_var1 ) +INSERT_GENTFUNC_BASIC( randnm_unb_var1, randnv_unb_var1 ) #undef GENTFUNCR @@ -1205,3 +1136,238 @@ void PASTEMAC(ch,varname) \ INSERT_GENTFUNCR_BASIC0( sumsqv_unb_var1 ) +// ----------------------------------------------------------------------------- + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +bool PASTEMAC(ch,opname) \ + ( \ + conj_t conjx, \ + dim_t n, \ + ctype* x, inc_t incx, \ + ctype* y, inc_t incy \ + ) \ +{ \ + for ( dim_t i = 0; i < n; ++i ) \ + { \ + ctype* chi1 = x + (i )*incx; \ + ctype* psi1 = y + (i )*incy; \ +\ + ctype chi1c; \ +\ + if ( bli_is_conj( conjx ) ) { PASTEMAC(ch,copyjs)( *chi1, chi1c ); } \ + else { PASTEMAC(ch,copys)( *chi1, chi1c ); } \ +\ + if ( !PASTEMAC(ch,eq)( chi1c, *psi1 ) ) \ + return FALSE; \ + } \ +\ + return TRUE; \ +} + +INSERT_GENTFUNC_BASIC0( eqv_unb_var1 ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +bool PASTEMAC(ch,opname) \ + ( \ + doff_t diagoffx, \ + diag_t diagx, \ + uplo_t uplox, \ + trans_t transx, \ + dim_t m, \ + dim_t n, \ + ctype* x, inc_t rs_x, inc_t cs_x, \ + ctype* y, inc_t rs_y, inc_t cs_y \ + ) \ +{ \ + uplo_t uplox_eff; \ + conj_t conjx; \ + dim_t n_iter; \ + dim_t n_elem_max; \ + inc_t ldx, incx; \ + inc_t ldy, incy; \ + dim_t ij0, n_shift; \ +\ + /* Set various loop parameters. */ \ + bli_set_dims_incs_uplo_2m \ + ( \ + diagoffx, diagx, transx, \ + uplox, m, n, rs_x, cs_x, rs_y, cs_y, \ + &uplox_eff, &n_elem_max, &n_iter, &incx, &ldx, &incy, &ldy, \ + &ij0, &n_shift \ + ); \ +\ + /* In the odd case where we are comparing against a complete unstored + matrix, we assert equality. Why? We assume the matrices are equal + unless we can find two corresponding elements that are unequal. So + if there are no elements, there is no inequality. Granted, this logic + is strange to think about no matter what, and thankfully it should + never be used under normal usage. */ \ + if ( bli_is_zeros( uplox_eff ) ) return TRUE; \ +\ + /* Extract the conjugation component from the transx parameter. */ \ + conjx = bli_extract_conj( transx ); \ +\ + /* Handle dense and upper/lower storage cases separately. */ \ + if ( bli_is_dense( uplox_eff ) ) \ + { \ + for ( dim_t j = 0; j < n_iter; ++j ) \ + { \ + const dim_t n_elem = n_elem_max; \ +\ + ctype* x1 = x + (j )*ldx + (0 )*incx; \ + ctype* y1 = y + (j )*ldy + (0 )*incy; \ +\ + for ( dim_t i = 0; i < n_elem; ++i ) \ + { \ + ctype* x11 = x1 + (i )*incx; \ + ctype* y11 = y1 + (i )*incy; \ + ctype x11c; \ +\ + if ( bli_is_conj( conjx ) ) { PASTEMAC(ch,copyjs)( *x11, x11c ); } \ + else { PASTEMAC(ch,copys)( *x11, x11c ); } \ +\ + if ( !PASTEMAC(ch,eq)( x11c, *y11 ) ) \ + return FALSE; \ + } \ + } \ + } \ + else \ + { \ + if ( bli_is_upper( uplox_eff ) ) \ + { \ + for ( dim_t j = 0; j < n_iter; ++j ) \ + { \ + const dim_t n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ +\ + ctype* x1 = x + (ij0+j )*ldx + (0 )*incx; \ + ctype* y1 = y + (ij0+j )*ldy + (0 )*incy; \ +\ + for ( dim_t i = 0; i < n_elem; ++i ) \ + { \ + ctype* x11 = x1 + (i )*incx; \ + ctype* y11 = y1 + (i )*incy; \ + ctype x11c; \ +\ + if ( bli_is_conj( conjx ) ) { PASTEMAC(ch,copyjs)( *x11, x11c ); } \ + else { PASTEMAC(ch,copys)( *x11, x11c ); } \ +\ + if ( !PASTEMAC(ch,eq)( x11c, *y11 ) ) \ + return FALSE; \ + } \ + } \ + } \ + else if ( bli_is_lower( uplox_eff ) ) \ + { \ + for ( dim_t j = 0; j < n_iter; ++j ) \ + { \ + const dim_t offi = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ + const dim_t n_elem = n_elem_max - offi; \ +\ + ctype* x1 = x + (j )*ldx + (ij0+offi )*incx; \ + ctype* y1 = y + (j )*ldy + (ij0+offi )*incy; \ +\ + for ( dim_t i = 0; i < n_elem; ++i ) \ + { \ + ctype* x11 = x1 + (i )*incx; \ + ctype* y11 = y1 + (i )*incy; \ + ctype x11c; \ +\ + if ( bli_is_conj( conjx ) ) { PASTEMAC(ch,copyjs)( *x11, x11c ); } \ + else { PASTEMAC(ch,copys)( *x11, x11c ); } \ +\ + if ( !PASTEMAC(ch,eq)( x11c, *y11 ) ) \ + return FALSE; \ + } \ + } \ + } \ + } \ +\ + return TRUE; \ +} + +INSERT_GENTFUNC_BASIC0( eqm_unb_var1 ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + FILE* file, \ + char* s1, \ + dim_t n, \ + ctype* x, inc_t incx, \ + char* format, \ + char* s2 \ + ) \ +{ \ + dim_t i; \ + ctype* chi1; \ + char default_spec[32] = PASTEMAC(ch,formatspec)(); \ +\ + if ( format == NULL ) format = default_spec; \ +\ + chi1 = x; \ +\ + fprintf( file, "%s\n", s1 ); \ +\ + for ( i = 0; i < n; ++i ) \ + { \ + PASTEMAC(ch,fprints)( file, format, *chi1 ); \ + fprintf( file, "\n" ); \ +\ + chi1 += incx; \ + } \ +\ + fprintf( file, "%s\n", s2 ); \ +} + +INSERT_GENTFUNC_BASIC0_I( fprintv ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + FILE* file, \ + char* s1, \ + dim_t m, \ + dim_t n, \ + ctype* x, inc_t rs_x, inc_t cs_x, \ + char* format, \ + char* s2 \ + ) \ +{ \ + dim_t i, j; \ + ctype* chi1; \ + char default_spec[32] = PASTEMAC(ch,formatspec)(); \ +\ + if ( format == NULL ) format = default_spec; \ +\ + fprintf( file, "%s\n", s1 ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + for ( j = 0; j < n; ++j ) \ + { \ + chi1 = (( ctype* ) x) + i*rs_x + j*cs_x; \ +\ + PASTEMAC(ch,fprints)( file, format, *chi1 ); \ + fprintf( file, " " ); \ + } \ +\ + fprintf( file, "\n" ); \ + } \ +\ + fprintf( file, "%s\n", s2 ); \ + fflush( file ); \ +} + +INSERT_GENTFUNC_BASIC0_I( fprintm ) + diff --git a/frame/util/bli_util_unb_var1.h b/frame/util/bli_util_unb_var1.h index 3fb517eec9..f878488568 100644 --- a/frame/util/bli_util_unb_var1.h +++ b/frame/util/bli_util_unb_var1.h @@ -107,39 +107,6 @@ INSERT_GENTPROTR_BASIC0( normfm_unb_var1 ) INSERT_GENTPROTR_BASIC0( normim_unb_var1 ) -#undef GENTPROT -#define GENTPROT( ctype, ch, opname ) \ -\ -BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ - ( \ - FILE* file, \ - char* s1, \ - dim_t n, \ - ctype* x, inc_t incx, \ - char* format, \ - char* s2 \ - ); - -INSERT_GENTPROT_BASIC0_I( fprintv ) - - -#undef GENTPROT -#define GENTPROT( ctype, ch, opname ) \ -\ -BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ - ( \ - FILE* file, \ - char* s1, \ - dim_t m, \ - dim_t n, \ - ctype* x, inc_t rs_x, inc_t cs_x, \ - char* format, \ - char* s2 \ - ); - -INSERT_GENTPROT_BASIC0_I( fprintm ) - - #undef GENTPROT #define GENTPROT( ctype, ch, varname ) \ \ @@ -188,3 +155,70 @@ void PASTEMAC(ch,varname) \ INSERT_GENTPROTR_BASIC0( sumsqv_unb_var1 ) +// ----------------------------------------------------------------------------- + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +bool PASTEMAC(ch,varname) \ + ( \ + conj_t conjx, \ + dim_t n, \ + ctype* x, inc_t incx, \ + ctype* y, inc_t incy \ + ); + +INSERT_GENTPROT_BASIC0( eqv_unb_var1 ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +bool PASTEMAC(ch,varname) \ + ( \ + doff_t diagoffx, \ + diag_t diagx, \ + uplo_t uplox, \ + trans_t transx, \ + dim_t m, \ + dim_t n, \ + ctype* x, inc_t rs_x, inc_t cs_x, \ + ctype* y, inc_t rs_y, inc_t cs_y \ + ); + +INSERT_GENTPROT_BASIC0( eqm_unb_var1 ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ + ( \ + FILE* file, \ + char* s1, \ + dim_t n, \ + ctype* x, inc_t incx, \ + char* format, \ + char* s2 \ + ); + +INSERT_GENTPROT_BASIC0_I( fprintv ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ + ( \ + FILE* file, \ + char* s1, \ + dim_t m, \ + dim_t n, \ + ctype* x, inc_t rs_x, inc_t cs_x, \ + char* format, \ + char* s2 \ + ); + +INSERT_GENTPROT_BASIC0_I( fprintm ) + + diff --git a/kernels/armsve/1m/armsve512_asm_transpose_d8x2.h b/kernels/armsve/1m/armsve512_asm_transpose_d8x2.h new file mode 100644 index 0000000000..31dd5704ab --- /dev/null +++ b/kernels/armsve/1m/armsve512_asm_transpose_d8x2.h @@ -0,0 +1,45 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#define SVE512_IN_REG_TRANSPOSE_d8x2(DST0,DST1,DST2,DST3,DST4,DST5,DST6SRC0,DST7SRC1,PT,P2C,P4C,P6C) \ + "trn1 " #DST0".d, " #DST6SRC0".d, " #DST7SRC1".d \n\t" \ + "trn2 " #DST1".d, " #DST6SRC0".d, " #DST7SRC1".d \n\t" \ + "compact " #DST2".d, " #P2C", " #DST0".d \n\t" \ + "compact " #DST3".d, " #P2C", " #DST1".d \n\t" \ + "compact " #DST4".d, " #P4C", " #DST0".d \n\t" \ + "compact " #DST5".d, " #P4C", " #DST1".d \n\t" \ + "compact " #DST6SRC0".d, " #P6C", " #DST0".d \n\t" \ + "compact " #DST7SRC1".d, " #P6C", " #DST1".d \n\t" + diff --git a/kernels/armsve/1m/armsve512_asm_transpose_d8x8.h b/kernels/armsve/1m/armsve512_asm_transpose_d8x8.h new file mode 100644 index 0000000000..98426c9476 --- /dev/null +++ b/kernels/armsve/1m/armsve512_asm_transpose_d8x8.h @@ -0,0 +1,97 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#define SVE512_IN_REG_TRANSPOSE_d8x8_PREPARE(XTMP,PT,P2C,P4C,P6C,PTFTF,P4,P6) \ + "ptrue " #PT".d \n\t" \ + "mov " #XTMP", #2 \n\t" \ + "whilelo " #P2C".d, xzr, " #XTMP" \n\t" \ + "mov " #XTMP", #4 \n\t" \ + "whilelo " #P4".d, xzr, " #XTMP" \n\t" \ + "mov " #XTMP", #6 \n\t" \ + "whilelo " #P6".d, xzr, " #XTMP" \n\t" \ + \ + "eor " #PTFTF".b, " #PT"/z, " #P6".b, " #P4".b \n\t" /***** o o | o */ \ + "orr " #PTFTF".b, " #PT"/z, " #PTFTF".b, " #P2C".b \n\t" /* | o | o */ \ + \ + "not " #P2C".b, " #PT"/z, " #P2C".b \n\t" \ + "not " #P4C".b, " #PT"/z, " #P4".b \n\t" \ + "not " #P6C".b, " #PT"/z, " #P6".b \n\t" \ + +#define SVE512_IN_REG_TRANSPOSE_d8x8(DST0,DST1,DST2,DST3,DST4,DST5,DST6,DST7,SRC0,SRC1,SRC2,SRC3,SRC4,SRC5,SRC6,SRC7,PT,P2C,P4C,P6C,PTFTF,P4,P6) \ + "trn1 " #DST0".d, " #SRC0".d, " #SRC1".d \n\t" \ + "trn2 " #DST1".d, " #SRC0".d, " #SRC1".d \n\t" \ + "trn1 " #DST2".d, " #SRC2".d, " #SRC3".d \n\t" \ + "trn2 " #DST3".d, " #SRC2".d, " #SRC3".d \n\t" \ + "trn1 " #DST4".d, " #SRC4".d, " #SRC5".d \n\t" \ + "trn2 " #DST5".d, " #SRC4".d, " #SRC5".d \n\t" \ + "trn1 " #DST6".d, " #SRC6".d, " #SRC7".d \n\t" \ + "trn2 " #DST7".d, " #SRC6".d, " #SRC7".d \n\t" \ + \ + "compact " #SRC0".d, " #P2C", " #DST0".d \n\t" \ + "compact " #SRC2".d, " #P2C", " #DST1".d \n\t" \ + "ext " #SRC1".b, " #SRC1".b, " #DST2".b, #48 \n\t" \ + "ext " #SRC3".b, " #SRC3".b, " #DST3".b, #48 \n\t" \ + "compact " #SRC4".d, " #P2C", " #DST4".d \n\t" \ + "compact " #SRC6".d, " #P2C", " #DST5".d \n\t" \ + "ext " #SRC5".b, " #SRC5".b, " #DST6".b, #48 \n\t" \ + "ext " #SRC7".b, " #SRC7".b, " #DST7".b, #48 \n\t" \ + \ + "sel " #DST0".d, " #PTFTF", " #DST0".d, " #SRC1".d \n\t" \ + "sel " #DST2".d, " #PTFTF", " #SRC0".d, " #DST2".d \n\t" \ + "sel " #DST1".d, " #PTFTF", " #DST1".d, " #SRC3".d \n\t" \ + "sel " #DST3".d, " #PTFTF", " #SRC2".d, " #DST3".d \n\t" \ + "sel " #DST4".d, " #PTFTF", " #DST4".d, " #SRC5".d \n\t" \ + "sel " #DST6".d, " #PTFTF", " #SRC4".d, " #DST6".d \n\t" \ + "sel " #DST5".d, " #PTFTF", " #DST5".d, " #SRC7".d \n\t" \ + "sel " #DST7".d, " #PTFTF", " #SRC6".d, " #DST7".d \n\t" \ + \ + "compact " #SRC0".d, " #P4C", " #DST0".d \n\t" \ + "compact " #SRC1".d, " #P4C", " #DST1".d \n\t" \ + "compact " #SRC2".d, " #P4C", " #DST2".d \n\t" \ + "compact " #SRC3".d, " #P4C", " #DST3".d \n\t" \ + "ext " #SRC4".b, " #SRC4".b, " #DST4".b, #32 \n\t" \ + "ext " #SRC5".b, " #SRC5".b, " #DST5".b, #32 \n\t" \ + "ext " #SRC6".b, " #SRC6".b, " #DST6".b, #32 \n\t" \ + "ext " #SRC7".b, " #SRC7".b, " #DST7".b, #32 \n\t" \ + \ + "sel " #DST0".d, " #P4", " #DST0".d, " #SRC4".d \n\t" \ + "sel " #DST1".d, " #P4", " #DST1".d, " #SRC5".d \n\t" \ + "sel " #DST2".d, " #P4", " #DST2".d, " #SRC6".d \n\t" \ + "sel " #DST3".d, " #P4", " #DST3".d, " #SRC7".d \n\t" \ + "sel " #DST4".d, " #P4", " #SRC0".d, " #DST4".d \n\t" \ + "sel " #DST5".d, " #P4", " #SRC1".d, " #DST5".d \n\t" \ + "sel " #DST6".d, " #P4", " #SRC2".d, " #DST6".d \n\t" \ + "sel " #DST7".d, " #P4", " #SRC3".d, " #DST7".d \n\t" + diff --git a/kernels/armsve/1m/bli_dpackm_armsve256_int_8xk.c b/kernels/armsve/1m/bli_dpackm_armsve256_int_8xk.c new file mode 100644 index 0000000000..7171347bf1 --- /dev/null +++ b/kernels/armsve/1m/bli_dpackm_armsve256_int_8xk.c @@ -0,0 +1,231 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Linaro Limited + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#if !defined(BLIS_FAMILY_A64FX) +#include + +// assumption: +// SVE vector length = 256 bits. +// + +void bli_dpackm_armsve256_int_8xk + ( + conj_t conja, + pack_t schema, + dim_t cdim_, + dim_t n_, + dim_t n_max_, + double* restrict kappa, + double* restrict a, inc_t inca_, inc_t lda_, + double* restrict p, inc_t ldp_, + cntx_t* restrict cntx + ) +{ + const int64_t cdim = cdim_; + const int64_t mnr = 8; + const int64_t n = n_; + const int64_t n_max = n_max_; + const int64_t inca = inca_; + const int64_t lda = lda_; + const int64_t ldp = ldp_; + + double* restrict alpha1 = a; + double* restrict alpha1_4 = alpha1 + 4 * inca; + double* restrict pi1 = p; + const svbool_t all_active = svptrue_b64(); + svfloat64_t z_a0; + svfloat64_t z_a4; + svuint64_t z_index; + + // creating index for gather/scatter + // with each element as: 0, 1*inca, 2*inca, 3*inca + z_index = svindex_u64( 0, inca * sizeof( double ) ); + + if ( cdim == mnr ) + { + if ( bli_deq1( *kappa ) ) + { + if ( inca == 1 ) // continous memory. packA style + { + for ( dim_t k = n; k != 0; --k ) + { + // svld1_f64 retrieves all zero's into z_a0 and z_a4, + // which is not correct. + // qemu-aarch64 or gcc interpretation of svld1_f64 + // should be blamed. + + // load 8 continuous elments from *a + // z_a0 = svld1_f64( all_active, alpha1 ); + // z_a4 = svld1_vnum_f64( all_active, alpha1, 1 ); + + // as a workaround, using gather load + // gather load from *a + z_a0 = svld1_gather_u64offset_f64( all_active, alpha1, z_index ); + z_a4 = svld1_gather_u64offset_f64( all_active, alpha1_4, z_index ); + + // store them into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( all_active, pi1, 1, z_a4 ); + + alpha1 += lda; + alpha1_4 = alpha1 + 4 * inca; + pi1 += ldp; + } + } + else // gather/scatter load/store. packB style + { + for ( dim_t k = n; k != 0; --k ) + { + // gather load from *a + z_a0 = svld1_gather_u64offset_f64( all_active, alpha1, z_index ); + z_a4 = svld1_gather_u64offset_f64( all_active, alpha1_4, z_index ); + + // scatter store into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( all_active, pi1, 1, z_a4 ); + + alpha1 += lda; + alpha1_4 = alpha1 + 4 * inca; + pi1 += ldp; + } + } + } + else // *kappa != 1.0 + { + // load kappa into vector + svfloat64_t z_kappa; + + z_kappa = svdup_f64( *kappa ); + + if ( inca == 1 ) // continous memory. packA style + { + for ( dim_t k = n; k != 0; --k ) + { + // load 8 continuous elments from *a + // z_a0 = svld1_f64( all_active, alpha1 ); + // z_a4 = svld1_vnum_f64( all_active, alpha1, 1 ); + // same reason as above. as a workaround, using gather load + // gather load from *a + z_a0 = svld1_gather_u64offset_f64( all_active, alpha1, z_index ); + z_a4 = svld1_gather_u64offset_f64( all_active, alpha1_4, z_index ); + + // multiply by *kappa + z_a0 = svmul_lane_f64( z_a0, z_kappa, 0 ); + z_a4 = svmul_lane_f64( z_a4, z_kappa, 0 ); + + // store them into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( all_active, pi1, 1, z_a4 ); + + alpha1 += lda; + alpha1_4 = alpha1 + 4 * inca; + pi1 += ldp; + } + } + else // gather/scatter load/store. packB style + { + for ( dim_t k = n; k != 0; --k ) + { + // gather load from *a + z_a0 = svld1_gather_u64offset_f64( all_active, alpha1, z_index ); + z_a4 = svld1_gather_u64offset_f64( all_active, alpha1_4, z_index ); + + // multiply by *kappa + z_a0 = svmul_lane_f64( z_a0, z_kappa, 0 ); + z_a4 = svmul_lane_f64( z_a4, z_kappa, 0 ); + + // scatter store into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( all_active, pi1, 1, z_a4 ); + + alpha1 += lda; + alpha1_4 = alpha1 + 4 * inca; + pi1 += ldp; + } + } + } // end of if ( *kappa == 1.0 ) + } + else // if ( cdim < mnr ) + { + bli_dscal2m_ex + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim, + n, + kappa, + a, inca, lda, + p, 1, ldp, + cntx, + NULL + ); + + // if ( cdim < mnr ) + { + const dim_t i = cdim; + const dim_t m_edge = mnr - i; + const dim_t n_edge = n_max; + double* restrict p_edge = p + (i )*1; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( n < n_max ) + { + const dim_t j = n; + const dim_t m_edge = mnr; + const dim_t n_edge = n_max - j; + double* restrict p_edge = p + (j )*ldp; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + +#endif // __has_include() diff --git a/kernels/armsve/1m/bli_dpackm_armsve512_asm_10xk.c b/kernels/armsve/1m/bli_dpackm_armsve512_asm_10xk.c new file mode 100644 index 0000000000..a086b3a76e --- /dev/null +++ b/kernels/armsve/1m/bli_dpackm_armsve512_asm_10xk.c @@ -0,0 +1,365 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "armsve512_asm_transpose_d8x8.h" +#include "armsve512_asm_transpose_d8x2.h" +#include "../3/armsve_asm_macros.h" + +// assumption: +// SVE vector length = 512 bits. + +void bli_dpackm_armsve512_asm_10xk + ( + conj_t conja, + pack_t schema, + dim_t cdim_, + dim_t n_, + dim_t n_max_, + double* restrict kappa, + double* restrict a, inc_t inca_, inc_t lda_, + double* restrict p, inc_t ldp_, + cntx_t* restrict cntx + ) +{ + const int64_t cdim = cdim_; + const int64_t mnr = 10; + const int64_t n = n_; + const int64_t n_max = n_max_; + const int64_t inca = inca_; + const int64_t lda = lda_; + const int64_t ldp = ldp_; + const bool gs = inca != 1 && lda != 1; + const bool unitk = bli_deq1( *kappa ); + +#ifdef _A64FX + { + // Infer whether A or B is being packed. + if ( schema == BLIS_PACKED_ROWS ) + p = ( (uint64_t)0x1 << 56 ) | (uint64_t)p; + if ( schema == BLIS_PACKED_COLUMNS ) + p = ( (uint64_t)0x2 << 56 ) | (uint64_t)p; + } +#endif + + if ( cdim == mnr && !gs && unitk ) + { + uint64_t n_mker = n / 8; + uint64_t n_left = n % 8; + __asm__ volatile ( + "mov x0, %[a] \n\t" + "mov x1, %[p] \n\t" + "mov x2, %[ldp] \n\t" + "mov x3, %[lda] \n\t" + "mov x4, %[inca] \n\t" + "cmp x4, #1 \n\t" + // Skips by sizeof(double). + "mov x8, #8 \n\t" + "madd x2, x2, x8, xzr \n\t" + "madd x3, x3, x8, xzr \n\t" + "madd x4, x4, x8, xzr \n\t" + // Loop constants. + "mov x8, %[n_mker] \n\t" + "mov x9, %[n_left] \n\t" + "ptrue p0.d \n\t" + BNE(AROWSTOR) + // A stored in columns. + LABEL(ACOLSTOR) + // Prefetch distance. + "mov x17, #8 \n\t" + "madd x17, x17, x3, xzr \n\t" +#ifdef _A64FX + // Disable hardware prefetch for A. + "mov x16, 0x6 \n\t" + "lsl x16, x16, #60 \n\t" + "orr x0, x0, x16 \n\t" +#endif + LABEL(ACOLSTORMKER) + "cmp x8, xzr \n\t" + BEQ(ACOLSTORMKEREND) + "add x5, x0, x3 \n\t" + "add x6, x5, x3 \n\t" + "add x7, x6, x3 \n\t" + "ld1d z0.d, p0/z, [x0] \n\t" + "ldr q1, [x0, #64] \n\t" + "ld1d z2.d, p0/z, [x5] \n\t" + "ldr q3, [x5, #64] \n\t" + "ld1d z4.d, p0/z, [x6] \n\t" + "ldr q5, [x6, #64] \n\t" + "ld1d z6.d, p0/z, [x7] \n\t" + "ldr q7, [x7, #64] \n\t" + "add x18, x17, x0 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x5 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x6 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x7 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x0, x7, x3 \n\t" + "add x5, x0, x3 \n\t" + "add x6, x5, x3 \n\t" + "add x7, x6, x3 \n\t" + "ld1d z8.d, p0/z, [x0] \n\t" + "ldr q9, [x0, #64] \n\t" + "ld1d z10.d, p0/z, [x5] \n\t" + "ldr q11, [x5, #64] \n\t" + "ld1d z12.d, p0/z, [x6] \n\t" + "ldr q13, [x6, #64] \n\t" + "ld1d z14.d, p0/z, [x7] \n\t" + "ldr q15, [x7, #64] \n\t" + "add x18, x17, x0 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x5 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x6 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x7 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + // Plain storage + "add x10, x1, x2 \n\t" + "add x11, x10, x2 \n\t" + "add x12, x11, x2 \n\t" + "add x13, x12, x2 \n\t" + "add x14, x13, x2 \n\t" + "add x15, x14, x2 \n\t" + "add x16, x15, x2 \n\t" + "st1d z0.d, p0, [x1] \n\t" + "str q1, [x1, #64] \n\t" + "st1d z2.d, p0, [x10] \n\t" + "str q3, [x10, #64] \n\t" + "st1d z4.d, p0, [x11] \n\t" + "str q5, [x11, #64] \n\t" + "st1d z6.d, p0, [x12] \n\t" + "str q7, [x12, #64] \n\t" + "st1d z8.d, p0, [x13] \n\t" + "str q9, [x13, #64] \n\t" + "st1d z10.d, p0, [x14] \n\t" + "str q11, [x14, #64] \n\t" + "st1d z12.d, p0, [x15] \n\t" + "str q13, [x15, #64] \n\t" + "st1d z14.d, p0, [x16] \n\t" + "str q15, [x16, #64] \n\t" + "add x1, x16, x2 \n\t" + // Realign and store. + // "ext z1.b, z1.b, z1.b, #16 \n\t" + // "ext z1.b, z1.b, z2.b, #48 \n\t" + // "ext z2.b, z2.b, z3.b, #16 \n\t" + // "ext z2.b, z2.b, z4.b, #32 \n\t" + // "ext z4.b, z4.b, z5.b, #16 \n\t" + // "ext z4.b, z4.b, z6.b, #16 \n\t" + // "ext z6.b, z6.b, z7.b, #16 \n\t" + // "ext z9.b, z9.b, z9.b, #16 \n\t" + // "ext z9.b, z9.b, z10.b, #48 \n\t" + // "ext z10.b, z10.b, z11.b, #16 \n\t" + // "ext z10.b, z10.b, z12.b, #32 \n\t" + // "ext z12.b, z12.b, z13.b, #16 \n\t" + // "ext z12.b, z12.b, z14.b, #16 \n\t" + // "ext z14.b, z14.b, z15.b, #16 \n\t" + // "st1d z0.d, p0, [x1] \n\t" + // "st1d z1.d, p0, [x1, #1, mul vl] \n\t" + // "st1d z2.d, p0, [x1, #2, mul vl] \n\t" + // "st1d z4.d, p0, [x1, #3, mul vl] \n\t" + // "st1d z6.d, p0, [x1, #4, mul vl] \n\t" + // "add x1, x1, #320 \n\t" + // "st1d z8.d, p0, [x1] \n\t" + // "st1d z9.d, p0, [x1, #1, mul vl] \n\t" + // "st1d z10.d, p0, [x1, #2, mul vl] \n\t" + // "st1d z12.d, p0, [x1, #3, mul vl] \n\t" + // "st1d z14.d, p0, [x1, #4, mul vl] \n\t" + // "add x1, x1, #320 \n\t" + "add x0, x7, x3 \n\t" + "sub x8, x8, #1 \n\t" + BRANCH(ACOLSTORMKER) + LABEL(ACOLSTORMKEREND) + LABEL(ACOLSTORLEFT) + "cmp x9, xzr \n\t" + BEQ(UNITKDONE) + "ld1d z0.d, p0/z, [x0] \n\t" + "ldr q1, [x0, #64] \n\t" + "st1d z0.d, p0, [x1] \n\t" + "str q1, [x1, #64] \n\t" + "add x0, x0, x3 \n\t" + "add x1, x1, x2 \n\t" + "sub x9, x9, #1 \n\t" + BRANCH(ACOLSTORLEFT) + // A stored in rows. + LABEL(AROWSTOR) + // Prepare predicates for in-reg transpose. + SVE512_IN_REG_TRANSPOSE_d8x8_PREPARE(x16,p0,p1,p2,p3,p8,p4,p6) + LABEL(AROWSTORMKER) // X[10-16] for A here not P. Be careful. + "cmp x8, xzr \n\t" + BEQ(AROWSTORMKEREND) + "add x10, x0, x4 \n\t" + "add x11, x10, x4 \n\t" + "add x12, x11, x4 \n\t" + "add x13, x12, x4 \n\t" + "add x14, x13, x4 \n\t" + "add x15, x14, x4 \n\t" + "add x16, x15, x4 \n\t" + "add x17, x16, x4 \n\t" + "add x18, x17, x4 \n\t" + "ld1d z0.d, p0/z, [x0] \n\t" + "ld1d z1.d, p0/z, [x10] \n\t" + "ld1d z2.d, p0/z, [x11] \n\t" + "ld1d z3.d, p0/z, [x12] \n\t" + "ld1d z4.d, p0/z, [x13] \n\t" + "ld1d z5.d, p0/z, [x14] \n\t" + "ld1d z6.d, p0/z, [x15] \n\t" + "ld1d z7.d, p0/z, [x16] \n\t" + "ld1d z22.d, p0/z, [x17] \n\t" + "ld1d z23.d, p0/z, [x18] \n\t" + // Transpose first 8 rows. + SVE512_IN_REG_TRANSPOSE_d8x8(z8,z9,z10,z11,z12,z13,z14,z15,z0,z1,z2,z3,z4,z5,z6,z7,p0,p1,p2,p3,p8,p4,p6) + // Transpose last 2 rows. + SVE512_IN_REG_TRANSPOSE_d8x2(z16,z17,z18,z19,z20,z21,z22,z23,p0,p1,p2,p3) + // Plain storage. + "add x10, x1, x2 \n\t" + "add x11, x10, x2 \n\t" + "add x12, x11, x2 \n\t" + "add x13, x12, x2 \n\t" + "add x14, x13, x2 \n\t" + "add x15, x14, x2 \n\t" + "add x16, x15, x2 \n\t" + "st1d z8.d, p0, [x1] \n\t" + "str q16, [x1, #64] \n\t" + "st1d z9.d, p0, [x10] \n\t" + "str q17, [x10, #64] \n\t" + "st1d z10.d, p0, [x11] \n\t" + "str q18, [x11, #64] \n\t" + "st1d z11.d, p0, [x12] \n\t" + "str q19, [x12, #64] \n\t" + "st1d z12.d, p0, [x13] \n\t" + "str q20, [x13, #64] \n\t" + "st1d z13.d, p0, [x14] \n\t" + "str q21, [x14, #64] \n\t" + "st1d z14.d, p0, [x15] \n\t" + "str q22, [x15, #64] \n\t" + "st1d z15.d, p0, [x16] \n\t" + "str q23, [x16, #64] \n\t" + "add x1, x16, x2 \n\t" + "add x0, x0, #64 \n\t" + "sub x8, x8, #1 \n\t" + BRANCH(AROWSTORMKER) + LABEL(AROWSTORMKEREND) + "mov x4, %[inca] \n\t" // Restore unshifted inca. + "index z30.d, xzr, x4 \n\t" // Generate index. + "lsl x4, x4, #3 \n\t" // Shift again. + "lsl x5, x4, #3 \n\t" // Virtual column vl. + LABEL(AROWSTORLEFT) + "cmp x9, xzr \n\t" + BEQ(UNITKDONE) + "add x6, x0, x5 \n\t" + "add x7, x6, x4 \n\t" + "ld1d z0.d, p0/z, [x0, z30.d, lsl #3] \n\t" + "ldr d1, [x6] \n\t" + "ldr d2, [x7] \n\t" + "trn1 v1.2d, v1.2d, v2.2d \n\t" + "st1d z0.d, p0, [x1] \n\t" + "str q1, [x1, #64] \n\t" + "add x1, x1, x2 \n\t" + "add x0, x0, #8 \n\t" + "sub x9, x9, #1 \n\t" + BRANCH(AROWSTORLEFT) + LABEL(UNITKDONE) + "mov x0, #0 \n\t" + : + : [a] "r" (a), + [p] "r" (p), + [lda] "r" (lda), + [ldp] "r" (ldp), + [inca] "r" (inca), + [n_mker] "r" (n_mker), + [n_left] "r" (n_left) + : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14","x15", + "x16","x17","x18", + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", + "z8", "z9", "z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19","z20","z21","z22","z23", + // "z24","z25","z26","z27","z28","z29", + "z30","z31", + "p0", "p1", "p2", "p3", "p4", // "p5", + "p6", "p7", "p8" + ); + } + else // if ( cdim < mnr ) + { + bli_dscal2m_ex + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim, + n, + kappa, + a, inca, lda, + p, 1, ldp, + cntx, + NULL + ); + + // if ( cdim < mnr ) + { + const dim_t i = cdim; + const dim_t m_edge = mnr - i; + const dim_t n_edge = n_max; + double* restrict p_edge = p + (i )*1; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( n < n_max ) + { + const dim_t j = n; + const dim_t m_edge = mnr; + const dim_t n_edge = n_max - j; + double* restrict p_edge = p + (j )*ldp; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} diff --git a/kernels/armsve/1m/bli_dpackm_armsve512_asm_16xk.c b/kernels/armsve/1m/bli_dpackm_armsve512_asm_16xk.c new file mode 100644 index 0000000000..aeb323c0ca --- /dev/null +++ b/kernels/armsve/1m/bli_dpackm_armsve512_asm_16xk.c @@ -0,0 +1,363 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "armsve512_asm_transpose_d8x8.h" +#include "../3/armsve_asm_macros.h" + +// assumption: +// SVE vector length = 512 bits. + +void bli_dpackm_armsve512_asm_16xk + ( + conj_t conja, + pack_t schema, + dim_t cdim_, + dim_t n_, + dim_t n_max_, + double* restrict kappa, + double* restrict a, inc_t inca_, inc_t lda_, + double* restrict p, inc_t ldp_, + cntx_t* restrict cntx + ) +{ + const int64_t cdim = cdim_; + const int64_t mnr = 16; + const int64_t n = n_; + const int64_t n_max = n_max_; + const int64_t inca = inca_; + const int64_t lda = lda_; + const int64_t ldp = ldp_; + const bool gs = inca != 1 && lda != 1; + const bool unitk = bli_deq1( *kappa ); + +#ifdef _A64FX + { + // Infer whether A or B is being packed. + if ( schema == BLIS_PACKED_ROWS ) + p = ( (uint64_t)0x1 << 56 ) | (uint64_t)p; + if ( schema == BLIS_PACKED_COLUMNS ) + p = ( (uint64_t)0x2 << 56 ) | (uint64_t)p; + } +#endif + + if ( cdim == mnr && !gs && unitk ) + { + uint64_t n_mker = n / 8; + uint64_t n_left = n % 8; + __asm__ volatile ( + "mov x0, %[a] \n\t" + "mov x1, %[p] \n\t" + "mov x2, %[ldp] \n\t" + "mov x3, %[lda] \n\t" + "mov x4, %[inca] \n\t" + "cmp x4, #1 \n\t" + // Skips by sizeof(double). + "mov x8, #8 \n\t" + "madd x2, x2, x8, xzr \n\t" + "madd x3, x3, x8, xzr \n\t" + "madd x4, x4, x8, xzr \n\t" + + // "mov x8, 0x8 \n\t" // Control#0 for A address. + // "mov x8, 0x24 \n\t" // Higher 6bit for Control#0: + // "lsl x8, x8, #58 \n\t" // Valid|Strong|Strong|Alloc|Load|Strong + // "orr x8, x8, x3 \n\t" // Stride. + // "msr S3_3_C11_C6_0, x8 \n\t" // Write system register. + + // Loop constants. + "mov x8, %[n_mker] \n\t" + "mov x9, %[n_left] \n\t" + "ptrue p0.d \n\t" + BNE(AROWSTOR) + // A stored in columns. + LABEL(ACOLSTOR) + // Prefetch distance. + "mov x17, #8 \n\t" + "madd x17, x17, x3, xzr \n\t" +#ifdef _A64FX + "mov x16, 0x6 \n\t" // Disable hardware prefetch for A. + "lsl x16, x16, #60 \n\t" + "orr x0, x0, x16 \n\t" +#endif + // "add x5, x0, x3 \n\t" + // "add x6, x5, x3 \n\t" + // "add x7, x6, x3 \n\t" + // "prfm PLDL1STRM, [x0] \n\t" + // "prfm PLDL1STRM, [x5] \n\t" + // "prfm PLDL1STRM, [x6] \n\t" + // "prfm PLDL1STRM, [x7] \n\t" + // "add x18, x7, x3 \n\t" + // "add x5, x18, x3 \n\t" + // "add x6, x5, x3 \n\t" + // "add x7, x6, x3 \n\t" + // "prfm PLDL1STRM, [x18] \n\t" + // "prfm PLDL1STRM, [x5] \n\t" + // "prfm PLDL1STRM, [x6] \n\t" + // "prfm PLDL1STRM, [x7] \n\t" + LABEL(ACOLSTORMKER) + "cmp x8, xzr \n\t" + BEQ(ACOLSTORMKEREND) + "add x5, x0, x3 \n\t" + "add x6, x5, x3 \n\t" + "add x7, x6, x3 \n\t" + "add x10, x1, x2 \n\t" + "add x11, x10, x2 \n\t" + "add x12, x11, x2 \n\t" + "add x13, x12, x2 \n\t" + "add x14, x13, x2 \n\t" + "add x15, x14, x2 \n\t" + "add x16, x15, x2 \n\t" + "ld1d z0.d, p0/z, [x0] \n\t" + "ld1d z1.d, p0/z, [x0, #1, mul vl] \n\t" + "ld1d z2.d, p0/z, [x5] \n\t" + "ld1d z3.d, p0/z, [x5, #1, mul vl] \n\t" + "ld1d z4.d, p0/z, [x6] \n\t" + "ld1d z5.d, p0/z, [x6, #1, mul vl] \n\t" + "ld1d z6.d, p0/z, [x7] \n\t" + "ld1d z7.d, p0/z, [x7, #1, mul vl] \n\t" + "add x18, x17, x0 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x5 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x6 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x7 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x0, x7, x3 \n\t" + "add x5, x0, x3 \n\t" + "add x6, x5, x3 \n\t" + "add x7, x6, x3 \n\t" + "ld1d z8.d, p0/z, [x0] \n\t" + "ld1d z9.d, p0/z, [x0, #1, mul vl] \n\t" + "ld1d z10.d, p0/z, [x5] \n\t" + "ld1d z11.d, p0/z, [x5, #1, mul vl] \n\t" + "ld1d z12.d, p0/z, [x6] \n\t" + "ld1d z13.d, p0/z, [x6, #1, mul vl] \n\t" + "ld1d z14.d, p0/z, [x7] \n\t" + "ld1d z15.d, p0/z, [x7, #1, mul vl] \n\t" + "add x18, x17, x0 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x5 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x6 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x7 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "st1d z0.d, p0, [x1] \n\t" + "st1d z1.d, p0, [x1, #1, mul vl] \n\t" + "st1d z2.d, p0, [x10] \n\t" + "st1d z3.d, p0, [x10, #1, mul vl] \n\t" + "st1d z4.d, p0, [x11] \n\t" + "st1d z5.d, p0, [x11, #1, mul vl] \n\t" + "st1d z6.d, p0, [x12] \n\t" + "st1d z7.d, p0, [x12, #1, mul vl] \n\t" + "st1d z8.d, p0, [x13] \n\t" + "st1d z9.d, p0, [x13, #1, mul vl] \n\t" + "st1d z10.d, p0, [x14] \n\t" + "st1d z11.d, p0, [x14, #1, mul vl] \n\t" + "st1d z12.d, p0, [x15] \n\t" + "st1d z13.d, p0, [x15, #1, mul vl] \n\t" + "st1d z14.d, p0, [x16] \n\t" + "st1d z15.d, p0, [x16, #1, mul vl] \n\t" + "add x0, x7, x3 \n\t" + "add x1, x16, x2 \n\t" + "sub x8, x8, #1 \n\t" + BRANCH(ACOLSTORMKER) + LABEL(ACOLSTORMKEREND) + LABEL(ACOLSTORLEFT) + "cmp x9, xzr \n\t" + BEQ(UNITKDONE) + "ld1d z0.d, p0/z, [x0] \n\t" + "ld1d z1.d, p0/z, [x0, #1, mul vl] \n\t" + "st1d z0.d, p0, [x1] \n\t" + "st1d z1.d, p0, [x1, #1, mul vl] \n\t" + "add x0, x0, x3 \n\t" + "add x1, x1, x2 \n\t" + "sub x9, x9, #1 \n\t" + BRANCH(ACOLSTORLEFT) + // A stored in rows. + LABEL(AROWSTOR) + // Prepare predicates for in-reg transpose. + SVE512_IN_REG_TRANSPOSE_d8x8_PREPARE(x16,p0,p1,p2,p3,p8,p4,p6) + LABEL(AROWSTORMKER) // X[10-16] for A here not P. Be careful. + "cmp x8, xzr \n\t" + BEQ(AROWSTORMKEREND) + "add x10, x0, x4 \n\t" + "add x11, x10, x4 \n\t" + "add x12, x11, x4 \n\t" + "add x13, x12, x4 \n\t" + "add x14, x13, x4 \n\t" + "add x15, x14, x4 \n\t" + "add x16, x15, x4 \n\t" + "ld1d z0.d, p0/z, [x0] \n\t" + "ld1d z1.d, p0/z, [x10] \n\t" + "ld1d z2.d, p0/z, [x11] \n\t" + "ld1d z3.d, p0/z, [x12] \n\t" + "ld1d z4.d, p0/z, [x13] \n\t" + "ld1d z5.d, p0/z, [x14] \n\t" + "ld1d z6.d, p0/z, [x15] \n\t" + "ld1d z7.d, p0/z, [x16] \n\t" + "add x5, x16, x4 \n\t" + "add x10, x5, x4 \n\t" + "add x11, x10, x4 \n\t" + "add x12, x11, x4 \n\t" + "add x13, x12, x4 \n\t" + "add x14, x13, x4 \n\t" + "add x15, x14, x4 \n\t" + "add x16, x15, x4 \n\t" + "ld1d z16.d, p0/z, [x5] \n\t" + "ld1d z17.d, p0/z, [x10] \n\t" + "ld1d z18.d, p0/z, [x11] \n\t" + "ld1d z19.d, p0/z, [x12] \n\t" + "ld1d z20.d, p0/z, [x13] \n\t" + "ld1d z21.d, p0/z, [x14] \n\t" + "ld1d z22.d, p0/z, [x15] \n\t" + "ld1d z23.d, p0/z, [x16] \n\t" + // Transpose first 8 rows. + SVE512_IN_REG_TRANSPOSE_d8x8(z8,z9,z10,z11,z12,z13,z14,z15,z0,z1,z2,z3,z4,z5,z6,z7,p0,p1,p2,p3,p8,p4,p6) + // Transpose last 8 rows. + SVE512_IN_REG_TRANSPOSE_d8x8(z24,z25,z26,z27,z28,z29,z30,z31,z16,z17,z18,z19,z20,z21,z22,z23,p0,p1,p2,p3,p8,p4,p6) + "add x10, x1, x2 \n\t" + "add x11, x10, x2 \n\t" + "add x12, x11, x2 \n\t" + "add x13, x12, x2 \n\t" + "add x14, x13, x2 \n\t" + "add x15, x14, x2 \n\t" + "add x16, x15, x2 \n\t" + "st1d z8.d, p0, [x1] \n\t" + "st1d z24.d, p0, [x1, #1, mul vl] \n\t" + "st1d z9.d, p0, [x10] \n\t" + "st1d z25.d, p0, [x10, #1, mul vl] \n\t" + "st1d z10.d, p0, [x11] \n\t" + "st1d z26.d, p0, [x11, #1, mul vl] \n\t" + "st1d z11.d, p0, [x12] \n\t" + "st1d z27.d, p0, [x12, #1, mul vl] \n\t" + "st1d z12.d, p0, [x13] \n\t" + "st1d z28.d, p0, [x13, #1, mul vl] \n\t" + "st1d z13.d, p0, [x14] \n\t" + "st1d z29.d, p0, [x14, #1, mul vl] \n\t" + "st1d z14.d, p0, [x15] \n\t" + "st1d z30.d, p0, [x15, #1, mul vl] \n\t" + "st1d z15.d, p0, [x16] \n\t" + "st1d z31.d, p0, [x16, #1, mul vl] \n\t" + "add x0, x0, #64 \n\t" + "add x1, x16, x2 \n\t" + "sub x8, x8, #1 \n\t" + BRANCH(AROWSTORMKER) + LABEL(AROWSTORMKEREND) + "mov x4, %[inca] \n\t" // Restore unshifted inca. + "index z30.d, xzr, x4 \n\t" // Generate index. + "lsl x4, x4, #3 \n\t" // Shift again. + "lsl x5, x4, #3 \n\t" // Virtual column vl. + LABEL(AROWSTORLEFT) + "cmp x9, xzr \n\t" + BEQ(UNITKDONE) + "add x6, x0, x5 \n\t" + "ld1d z0.d, p0/z, [x0, z30.d, lsl #3] \n\t" + "ld1d z1.d, p0/z, [x6, z30.d, lsl #3] \n\t" + "st1d z0.d, p0, [x1] \n\t" + "st1d z1.d, p0, [x1, #1, mul vl] \n\t" + "add x1, x1, x2 \n\t" + "add x0, x0, #8 \n\t" + "sub x9, x9, #1 \n\t" + BRANCH(AROWSTORLEFT) + LABEL(UNITKDONE) + "mov x0, #0 \n\t" + : + : [a] "r" (a), + [p] "r" (p), + [lda] "r" (lda), + [ldp] "r" (ldp), + [inca] "r" (inca), + [n_mker] "r" (n_mker), + [n_left] "r" (n_left) + : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14","x15", + "x16","x17","x18", + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", + "z8", "z9", "z10","z11","z12","z13","z14","z15", + // "z16","z17","z18","z19","z20","z21","z22","z23", + // "z24","z25","z26","z27","z28","z29","z30","z31", + "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7" + ); + } + else // if ( cdim < mnr ) + { + bli_dscal2m_ex + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim, + n, + kappa, + a, inca, lda, + p, 1, ldp, + cntx, + NULL + ); + + // if ( cdim < mnr ) + { + const dim_t i = cdim; + const dim_t m_edge = mnr - i; + const dim_t n_edge = n_max; + double* restrict p_edge = p + (i )*1; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( n < n_max ) + { + const dim_t j = n; + const dim_t m_edge = mnr; + const dim_t n_edge = n_max - j; + double* restrict p_edge = p + (j )*ldp; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} diff --git a/kernels/armsve/1m/old/bli_dpackm_armsve512_int_12xk.c b/kernels/armsve/1m/old/bli_dpackm_armsve512_int_12xk.c new file mode 100644 index 0000000000..47b15b4375 --- /dev/null +++ b/kernels/armsve/1m/old/bli_dpackm_armsve512_int_12xk.c @@ -0,0 +1,358 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Linaro Limited + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include + +#if !defined(BLIS_FAMILY_A64FX) +#include + +// assumption: +// SVE vector length = 512 bits. +// TODO: +// 2-rows -> 3 vectors packing and use predicator only in odd num of rows to be packed. +// prefetching is needed. + +void bli_dpackm_armsve512_int_12xk + ( + conj_t conja, + pack_t schema, + dim_t cdim_, + dim_t n_, + dim_t n_max_, + double* restrict kappa, + double* restrict a, inc_t inca_, inc_t lda_, + double* restrict p, inc_t ldp_, + cntx_t* restrict cntx + ) +{ + const int64_t cdim = cdim_; + const int64_t mnr = 12; + const int64_t n = n_; + const int64_t n_max = n_max_; + const int64_t inca = inca_; + const int64_t lda = lda_; + const int64_t ldp = ldp_; + + double* restrict alpha1 = a; + double* restrict alpha1_8 = alpha1 + 8 * inca; + double* restrict alpha1_p4 = alpha1 + 4 * inca; + double* restrict alpha1_m4 = alpha1 - 4 * inca; + double* restrict pi1 = p; + const svbool_t all_active = svptrue_b64(); + const svbool_t first_half_active = svwhilelt_b64(0, 4); + const svbool_t last_half_active = svnot_z(all_active, first_half_active); + svfloat64_t z_a0; + svfloat64_t z_a8; + svfloat64_t z_a8_lh; + svfloat64_t z_a16; + svuint64_t z_index; + + // creating index for gather/scatter + // with each element as: 0, 1*inca, 2*inca, 3*inca + z_index = svindex_u64( 0, inca * sizeof( double ) ); + + if ( cdim == mnr ) + { + if ( bli_deq1( *kappa ) ) + { + if ( inca == 1 ) // continous memory. packA style + { + dim_t k = n; + // 2 pack into 3 case. + if ( ldp == mnr ) + { + for ( ; k > 1; k -= 2 ) + { + // load 12 continuous elments from *a + z_a0 = svld1_f64( all_active, alpha1 ); + z_a8 = svld1_vnum_f64( first_half_active, alpha1, 1 ); + + // forward address - 0 to 1 + alpha1 += lda; + alpha1_p4 = alpha1 + 4 * inca; + alpha1_m4 = alpha1 - 4 * inca; + + // load 12 continuous elments from *a, filling last half of z8. + z_a8_lh = svld1_f64( last_half_active, alpha1_m4 ); + z_a8 = svadd_f64_z( all_active, z_a8, z_a8_lh ); + z_a16 = svld1_f64( all_active, alpha1_p4 ); + + // stored packed data into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( all_active, pi1, 1, z_a8 ); + svst1_vnum_f64( all_active, pi1, 2, z_a16 ); + + // forward address - 1 to 0 + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += 2 * ldp; + } + } + // line-by-line packing case. + for ( ; k != 0; --k ) + { + // load 12 continuous elments from *a + z_a0 = svld1_f64( all_active, alpha1 ); + z_a8 = svld1_vnum_f64( first_half_active, alpha1, 1 ); + + // store them into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( first_half_active, pi1, 1, z_a8 ); + + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += ldp; + } + } + else // gather/scatter load/store. packB style + { + dim_t k = n; + if ( ldp == mnr ) + { + for ( ; k > 1; k -= 2 ) + { + // gather load from *a + z_a0 = svld1_gather_u64offset_f64( all_active, alpha1, z_index ); + z_a8 = svld1_gather_u64offset_f64( first_half_active, alpha1_8, z_index ); + + // forward address - 0 to 1 + alpha1 += lda; + alpha1_p4 = alpha1 + 4 * inca; + alpha1_m4 = alpha1 - 4 * inca; + + // gather load from *a, filling last half of z8. + z_a8_lh = svld1_gather_u64offset_f64( last_half_active, alpha1_m4, z_index ); + z_a8 = svadd_f64_z( all_active, z_a8, z_a8_lh ); + z_a16 = svld1_gather_u64offset_f64( all_active, alpha1_p4, z_index ); + + // stored packed data into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( all_active, pi1, 1, z_a8 ); + svst1_vnum_f64( all_active, pi1, 2, z_a16 ); + + // forward address - 1 to 0 + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += 2 * ldp; + } + } + for ( ; k != 0; --k ) + { + // gather load from *a + z_a0 = svld1_gather_u64offset_f64( all_active, alpha1, z_index ); + z_a8 = svld1_gather_u64offset_f64( first_half_active, alpha1_8, z_index ); + + // scatter store into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( first_half_active, pi1, 1, z_a8 ); + + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += ldp; + } + } + } + else // *kappa != 1.0 + { + // load kappa into vector + svfloat64_t z_kappa; + + z_kappa = svdup_f64( *kappa ); + + if ( inca == 1 ) // continous memory. packA style + { + dim_t k = n; + if ( ldp == mnr ) + { + for ( ; k > 1; k -= 2 ) + { + // load 12 continuous elments from *a + z_a0 = svld1_f64( all_active, alpha1 ); + z_a8 = svld1_vnum_f64( first_half_active, alpha1, 1 ); + + // forward address - 0 to 1 + alpha1 += lda; + alpha1_p4 = alpha1 + 4 * inca; + alpha1_m4 = alpha1 - 4 * inca; + + // load 12 continuous elments from *a, filling last half of z8. + z_a8_lh = svld1_f64( last_half_active, alpha1_m4 ); + z_a8 = svadd_f64_z( all_active, z_a8, z_a8_lh ); + z_a16 = svld1_f64( all_active, alpha1_p4 ); + + // multiply by *kappa + z_a0 = svmul_lane_f64( z_a0, z_kappa, 0 ); + z_a8 = svmul_lane_f64( z_a8, z_kappa, 0 ); + z_a16 = svmul_lane_f64( z_a16, z_kappa, 0 ); + + // stored packed data into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( all_active, pi1, 1, z_a8 ); + svst1_vnum_f64( all_active, pi1, 2, z_a16 ); + + // forward address - 1 to 0 + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += 2 * ldp; + } + } + for ( ; k != 0; --k ) + { + // load 12 continuous elments from *a + z_a0 = svld1_f64( all_active, alpha1 ); + z_a8 = svld1_vnum_f64( first_half_active, alpha1, 1 ); + + // multiply by *kappa + z_a0 = svmul_lane_f64( z_a0, z_kappa, 0 ); + z_a8 = svmul_lane_f64( z_a8, z_kappa, 0 ); + + // store them into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( first_half_active, pi1, 1, z_a8 ); + + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += ldp; + } + } + else // gather/scatter load/store. packB style + { + dim_t k = n; + if ( ldp == mnr ) + { + for ( ; k > 1; k -= 2 ) + { + // gather load from *a + z_a0 = svld1_gather_u64offset_f64( all_active, alpha1, z_index ); + z_a8 = svld1_gather_u64offset_f64( first_half_active, alpha1_8, z_index ); + + // forward address - 0 to 1 + alpha1 += lda; + alpha1_p4 = alpha1 + 4 * inca; + alpha1_m4 = alpha1 - 4 * inca; + + // gather load from *a, filling last half of z8. + z_a8_lh = svld1_gather_u64offset_f64( last_half_active, alpha1_m4, z_index ); + z_a8 = svadd_f64_z( all_active, z_a8, z_a8_lh ); + z_a16 = svld1_gather_u64offset_f64( all_active, alpha1_p4, z_index ); + + // multiply by *kappa + z_a0 = svmul_lane_f64( z_a0, z_kappa, 0 ); + z_a8 = svmul_lane_f64( z_a8, z_kappa, 0 ); + z_a16 = svmul_lane_f64( z_a16, z_kappa, 0 ); + + // stored packed data into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( all_active, pi1, 1, z_a8 ); + svst1_vnum_f64( all_active, pi1, 2, z_a16 ); + + // forward address - 1 to 0 + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += 2 * ldp; + } + } + for ( ; k != 0; --k ) + { + // gather load from *a + z_a0 = svld1_gather_u64offset_f64( all_active, alpha1, z_index ); + z_a8 = svld1_gather_u64offset_f64( first_half_active, alpha1_8, z_index ); + + // multiply by *kappa + z_a0 = svmul_lane_f64( z_a0, z_kappa, 0 ); + z_a8 = svmul_lane_f64( z_a8, z_kappa, 0 ); + + // scatter store into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( first_half_active, pi1, 1, z_a8 ); + + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += ldp; + } + } + } // end of if ( *kappa == 1.0 ) + } + else // if ( cdim < mnr ) + { + bli_dscal2m_ex + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim, + n, + kappa, + a, inca, lda, + p, 1, ldp, + cntx, + NULL + ); + + // if ( cdim < mnr ) + { + const dim_t i = cdim; + const dim_t m_edge = mnr - i; + const dim_t n_edge = n_max; + double* restrict p_edge = p + (i )*1; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( n < n_max ) + { + const dim_t j = n; + const dim_t m_edge = mnr; + const dim_t n_edge = n_max - j; + double* restrict p_edge = p + (j )*ldp; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + +#endif // __has_include() diff --git a/kernels/armsve/3/armsve_asm_2vx10.h b/kernels/armsve/3/armsve_asm_2vx10.h new file mode 100644 index 0000000000..ae89fa1ece --- /dev/null +++ b/kernels/armsve/3/armsve_asm_2vx10.h @@ -0,0 +1,198 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#define GEMM_2VX10_MKER_LOOP_PLAIN_C_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BRSBIT) \ + GEMM_FMLA2_LD1R(C0FH,C0LH,PT,ACOLFH,ACOLLH,BV0,BADDR,8) \ + GEMM_FMLA2_LD1R(C1FH,C1LH,PT,ACOLFH,ACOLLH,BV1,BADDR,9) \ +" add "#BADDR", "#BRSBIT", "#BADDR" \n\t" /* B address forward */ \ + GEMM_FMLA2_LD1R(C2FH,C2LH,PT,ACOLFH,ACOLLH,BV2,BADDR,0) \ + GEMM_FMLA2_LD1R(C3FH,C3LH,PT,ACOLFH,ACOLLH,BV3,BADDR,1) \ + GEMM_FMLA2_LD1R(C4FH,C4LH,PT,ACOLFH,ACOLLH,BV4,BADDR,2) \ + GEMM_FMLA2_LD1R(C5FH,C5LH,PT,ACOLFH,ACOLLH,BV5,BADDR,3) \ + GEMM_FMLA2_LD1R(C6FH,C6LH,PT,ACOLFH,ACOLLH,BV6,BADDR,4) \ + GEMM_FMLA2_LD1R(C7FH,C7LH,PT,ACOLFH,ACOLLH,BV7,BADDR,5) \ + \ + GEMM_FMLA2_LD1R(C8FH,C8LH,PT,ACOLFH,ACOLLH,BV0,BADDR,6) \ + GEMM_FMLA2_LD1R(C9FH,C9LH,PT,ACOLFH,ACOLLH,BV1,BADDR,7) + +// Second through forth microkernels are the first one with B vectors rotated. +#define GEMM_2VX10_MKER_LOOP_PLAIN_C_2(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BRSBIT) \ + GEMM_2VX10_MKER_LOOP_PLAIN_C_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV2,BV3,BV4,BV5,BV6,BV7,BV0,BV1,BADDR,BRSBIT) + +#define GEMM_2VX10_MKER_LOOP_PLAIN_C_3(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BRSBIT) \ + GEMM_2VX10_MKER_LOOP_PLAIN_C_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV4,BV5,BV6,BV7,BV0,BV1,BV2,BV3,BADDR,BRSBIT) + +#define GEMM_2VX10_MKER_LOOP_PLAIN_C_4(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BRSBIT) \ + GEMM_2VX10_MKER_LOOP_PLAIN_C_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV6,BV7,BV0,BV1,BV2,BV3,BV4,BV5,BADDR,BRSBIT) +// NOTE: +// The microkernel (PLAIN_1-4 as a whole) satisfies on entry/exit +// (sth. akin to loop-invariant): +// - BV[0-7] holds B[0:7, 4*k_cur] +// - B's address stops at B[0, 4*k_cur+1] + +// Final loop inside K=4 microkernels. +#define GEMM_2VX10_MKER_LOOP_PLAIN_C_4_RESIDUAL(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BRSBIT) \ + GEMM_FMLA2_LD1R(C0FH,C0LH,PT,ACOLFH,ACOLLH,BV6,BADDR,8) \ + GEMM_FMLA2_LD1R(C1FH,C1LH,PT,ACOLFH,ACOLLH,BV7,BADDR,9) \ +" add "#BADDR", "#BRSBIT", "#BADDR" \n\t" /* B address forward */ \ + GEMM_FMLA2(C2FH,C2LH,PT,ACOLFH,ACOLLH,BV0) \ + GEMM_FMLA2(C3FH,C3LH,PT,ACOLFH,ACOLLH,BV1) \ + GEMM_FMLA2(C4FH,C4LH,PT,ACOLFH,ACOLLH,BV2) \ + GEMM_FMLA2(C5FH,C5LH,PT,ACOLFH,ACOLLH,BV3) \ + GEMM_FMLA2(C6FH,C6LH,PT,ACOLFH,ACOLLH,BV4) \ + GEMM_FMLA2(C7FH,C7LH,PT,ACOLFH,ACOLLH,BV5) \ + GEMM_FMLA2(C8FH,C8LH,PT,ACOLFH,ACOLLH,BV6) \ + GEMM_FMLA2(C9FH,C9LH,PT,ACOLFH,ACOLLH,BV7) + +// K=4 MKer loop with B memory scattered. +#define GEMM_2VX10_MKER_LOOP_PLAIN_G_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BELMADDR,BRSBIT,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C0FH,C0LH,PT,ACOLFH,ACOLLH,BV0,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C1FH,C1LH,PT,ACOLFH,ACOLLH,BV1,BELMADDR,BCSBIT) \ +" add "#BADDR", "#BRSBIT", "#BADDR" \n\t" /* B address forward */ \ +" mov "#BELMADDR", "#BADDR" \n\t" \ + GEMM_FMLA2_LD1R_G_ELMFWD(C2FH,C2LH,PT,ACOLFH,ACOLLH,BV2,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C3FH,C3LH,PT,ACOLFH,ACOLLH,BV3,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C4FH,C4LH,PT,ACOLFH,ACOLLH,BV4,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C5FH,C5LH,PT,ACOLFH,ACOLLH,BV5,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C6FH,C6LH,PT,ACOLFH,ACOLLH,BV6,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C7FH,C7LH,PT,ACOLFH,ACOLLH,BV7,BELMADDR,BCSBIT) \ + \ + GEMM_FMLA2_LD1R_G_ELMFWD(C8FH,C8LH,PT,ACOLFH,ACOLLH,BV0,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C9FH,C9LH,PT,ACOLFH,ACOLLH,BV1,BELMADDR,BCSBIT) + +#define GEMM_2VX10_MKER_LOOP_PLAIN_G_2(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BELMADDR,BRSBIT,BCSBIT) \ + GEMM_2VX10_MKER_LOOP_PLAIN_G_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV2,BV3,BV4,BV5,BV6,BV7,BV0,BV1,BADDR,BELMADDR,BRSBIT,BCSBIT) + +#define GEMM_2VX10_MKER_LOOP_PLAIN_G_3(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BELMADDR,BRSBIT,BCSBIT) \ + GEMM_2VX10_MKER_LOOP_PLAIN_G_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV4,BV5,BV6,BV7,BV0,BV1,BV2,BV3,BADDR,BELMADDR,BRSBIT,BCSBIT) + +#define GEMM_2VX10_MKER_LOOP_PLAIN_G_4(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BELMADDR,BRSBIT,BCSBIT) \ + GEMM_2VX10_MKER_LOOP_PLAIN_G_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV6,BV7,BV0,BV1,BV2,BV3,BV4,BV5,BADDR,BELMADDR,BRSBIT,BCSBIT) + +#define GEMM_2VX10_MKER_LOOP_PLAIN_G_4_RESIDUAL(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BELMADDR,BRSBIT,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C0FH,C0LH,PT,ACOLFH,ACOLLH,BV6,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C1FH,C1LH,PT,ACOLFH,ACOLLH,BV7,BELMADDR,BCSBIT) \ +" add "#BADDR", "#BRSBIT", "#BADDR" \n\t" /* B address forward */ \ +" mov "#BELMADDR", "#BADDR" \n\t" \ + GEMM_FMLA2(C2FH,C2LH,PT,ACOLFH,ACOLLH,BV0) \ + GEMM_FMLA2(C3FH,C3LH,PT,ACOLFH,ACOLLH,BV1) \ + GEMM_FMLA2(C4FH,C4LH,PT,ACOLFH,ACOLLH,BV2) \ + GEMM_FMLA2(C5FH,C5LH,PT,ACOLFH,ACOLLH,BV3) \ + GEMM_FMLA2(C6FH,C6LH,PT,ACOLFH,ACOLLH,BV4) \ + GEMM_FMLA2(C7FH,C7LH,PT,ACOLFH,ACOLLH,BV5) \ + GEMM_FMLA2(C8FH,C8LH,PT,ACOLFH,ACOLLH,BV6) \ + GEMM_FMLA2(C9FH,C9LH,PT,ACOLFH,ACOLLH,BV7) + + +#define CLEAR_COL20(Z00,Z01,Z02,Z03,Z04,Z05,Z06,Z07,Z08,Z09,Z10,Z11,Z12,Z13,Z14,Z15,Z16,Z17,Z18,Z19) \ + CLEAR_COL4(Z00,Z01,Z02,Z03) \ + CLEAR_COL4(Z04,Z05,Z06,Z07) \ + CLEAR_COL4(Z08,Z09,Z10,Z11) \ + CLEAR_COL4(Z12,Z13,Z14,Z15) \ + CLEAR_COL4(Z16,Z17,Z18,Z19) + +#define SCALE_COL20(Z00,Z01,Z02,Z03,Z04,Z05,Z06,Z07,Z08,Z09,Z10,Z11,Z12,Z13,Z14,Z15,Z16,Z17,Z18,Z19,ZFACTOR) \ + SCALE_COL4(Z00,Z01,Z02,Z03,ZFACTOR) \ + SCALE_COL4(Z04,Z05,Z06,Z07,ZFACTOR) \ + SCALE_COL4(Z08,Z09,Z10,Z11,ZFACTOR) \ + SCALE_COL4(Z12,Z13,Z14,Z15,ZFACTOR) \ + SCALE_COL4(Z16,Z17,Z18,Z19,ZFACTOR) + +#define GEMM_C_FMLA_UKER(C0FH,C1FH,C2FH,C3FH,C4FH,C0LH,C1LH,C2LH,C3LH,C4LH,PT,Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,ZSCALE) \ + GEMM_FMLA2(C0FH,C0LH,PT,Z0FH,Z0LH,ZSCALE) \ + GEMM_FMLA2(C1FH,C1LH,PT,Z1FH,Z1LH,ZSCALE) \ + GEMM_FMLA2(C2FH,C2LH,PT,Z2FH,Z2LH,ZSCALE) \ + GEMM_FMLA2(C3FH,C3LH,PT,Z3FH,Z3LH,ZSCALE) \ + GEMM_FMLA2(C4FH,C4LH,PT,Z4FH,Z4LH,ZSCALE) + +#define GEMM_C_FMAD_UKER(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,PFH,PLH,C0FH,C1FH,C2FH,C3FH,C4FH,C0LH,C1LH,C2LH,C3LH,C4LH,ZSCALE) \ + GEMM_CCOL_FMAD(Z0FH,Z0LH,PFH,PLH,C0FH,C0LH,ZSCALE) \ + GEMM_CCOL_FMAD(Z1FH,Z1LH,PFH,PLH,C1FH,C1LH,ZSCALE) \ + GEMM_CCOL_FMAD(Z2FH,Z2LH,PFH,PLH,C2FH,C2LH,ZSCALE) \ + GEMM_CCOL_FMAD(Z3FH,Z3LH,PFH,PLH,C3FH,C3LH,ZSCALE) \ + GEMM_CCOL_FMAD(Z4FH,Z4LH,PFH,PLH,C4FH,C4LH,ZSCALE) + +#define GEMM_C_LOAD_UKER_C(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(Z0FH,Z0LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(Z1FH,Z1LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(Z2FH,Z2LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(Z3FH,Z3LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(Z4FH,Z4LH,PFH,PLH,CADDR,CCS) + +#define GEMM_C_STORE_UKER_C(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_STORE_FWD(Z0FH,Z0LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_STORE_FWD(Z1FH,Z1LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_STORE_FWD(Z2FH,Z2LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_STORE_FWD(Z3FH,Z3LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_STORE_FWD(Z4FH,Z4LH,PFH,PLH,CADDR,CCS) + +#define GEMM_C_FMAD_LOAD_UKER_C(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,PFH,PLH,C0FH,C1FH,C2FH,C3FH,C4FH,C0LH,C1LH,C2LH,C3LH,C4LH,ZSCALE,CADDR,CCS) \ + GEMM_CCOL_FMAD(Z0FH,Z0LH,PFH,PLH,C0FH,C0LH,ZSCALE) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(C0FH,C0LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_FMAD(Z1FH,Z1LH,PFH,PLH,C1FH,C1LH,ZSCALE) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(C1FH,C1LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_FMAD(Z2FH,Z2LH,PFH,PLH,C2FH,C2LH,ZSCALE) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(C2FH,C2LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_FMAD(Z3FH,Z3LH,PFH,PLH,C3FH,C3LH,ZSCALE) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(C3FH,C3LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_FMAD(Z4FH,Z4LH,PFH,PLH,C4FH,C4LH,ZSCALE) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(C4FH,C4LH,PFH,PLH,CADDR,CCS) + +#define GEMM_C_LOAD_UKER_G(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_GATHER_LOAD_FWD(Z0FH,Z0LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_GATHER_LOAD_FWD(Z1FH,Z1LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_GATHER_LOAD_FWD(Z2FH,Z2LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_GATHER_LOAD_FWD(Z3FH,Z3LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_GATHER_LOAD_FWD(Z4FH,Z4LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) + +#define GEMM_C_STORE_UKER_G(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_SCATTER_STORE_FWD(Z0FH,Z0LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_SCATTER_STORE_FWD(Z1FH,Z1LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_SCATTER_STORE_FWD(Z2FH,Z2LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_SCATTER_STORE_FWD(Z3FH,Z3LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_SCATTER_STORE_FWD(Z4FH,Z4LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) + +#define GEMM_C_FMAD_LOAD_UKER_G(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,PFH,PLH,C0FH,C1FH,C2FH,C3FH,C4FH,C0LH,C1LH,C2LH,C3LH,C4LH,ZSCALE,ZIDX,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_FMAD(Z0FH,Z0LH,PFH,PLH,C0FH,C0LH,ZSCALE) \ + GEMM_CCOL_GATHER_LOAD_FWD(C0FH,C0LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_FMAD(Z1FH,Z1LH,PFH,PLH,C1FH,C1LH,ZSCALE) \ + GEMM_CCOL_GATHER_LOAD_FWD(C1FH,C1LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_FMAD(Z2FH,Z2LH,PFH,PLH,C2FH,C2LH,ZSCALE) \ + GEMM_CCOL_GATHER_LOAD_FWD(C2FH,C2LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_FMAD(Z3FH,Z3LH,PFH,PLH,C3FH,C3LH,ZSCALE) \ + GEMM_CCOL_GATHER_LOAD_FWD(C3FH,C3LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_FMAD(Z4FH,Z4LH,PFH,PLH,C4FH,C4LH,ZSCALE) \ + GEMM_CCOL_GATHER_LOAD_FWD(C4FH,C4LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) + diff --git a/kernels/armsve/3/armsve_asm_2vx10cmplx.h b/kernels/armsve/3/armsve_asm_2vx10cmplx.h new file mode 100644 index 0000000000..1b67d0d169 --- /dev/null +++ b/kernels/armsve/3/armsve_asm_2vx10cmplx.h @@ -0,0 +1,130 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#define GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C8Re,C9Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,C8Im,C9Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BAddr,BRSBit) \ + GEMM_FMLA2_LD1R(C0Re,C0Im,PT,AColRe,AColIm,BV0,BAddr,16) \ + GEMM_FMLA2_LD1R(C1Re,C1Im,PT,AColRe,AColIm,BV1,BAddr,18) \ + GEMM_FMLA2_LD1R(C2Re,C2Im,PT,AColRe,AColIm,BV2,BAddr,1) \ + GEMM_FMLA2_LD1R(C3Re,C3Im,PT,AColRe,AColIm,BV3,BAddr,3) \ + GEMM_FMLA2_LD1R(C4Re,C4Im,PT,AColRe,AColIm,BV4,BAddr,5) \ + GEMM_FMLA2_LD1R(C5Re,C5Im,PT,AColRe,AColIm,BV5,BAddr,7) \ + GEMM_FMLA2_LD1R(C6Re,C6Im,PT,AColRe,AColIm,BV6,BAddr,9) \ + GEMM_FMLA2_LD1R(C7Re,C7Im,PT,AColRe,AColIm,BV7,BAddr,11) \ + GEMM_FMLA2_LD1R(C8Re,C8Im,PT,AColRe,AColIm,BV0,BAddr,13) \ + GEMM_FMLA2_LD1R(C9Re,C9Im,PT,AColRe,AColIm,BV1,BAddr,15) \ + \ + GEMM_FMLX2_LD1R(C0Im,C0Re,PT,AColRe,AColIm,BV2,BAddr,17) \ + GEMM_FMLX2_LD1R(C1Im,C1Re,PT,AColRe,AColIm,BV3,BAddr,19) \ +" add "#BAddr", "#BRSBit", "#BAddr" \n\t" /* B address forward */ \ + GEMM_FMLX2_LD1R(C2Im,C2Re,PT,AColRe,AColIm,BV4,BAddr,0) \ + GEMM_FMLX2_LD1R(C3Im,C3Re,PT,AColRe,AColIm,BV5,BAddr,2) \ + GEMM_FMLX2_LD1R(C4Im,C4Re,PT,AColRe,AColIm,BV6,BAddr,4) \ + GEMM_FMLX2_LD1R(C5Im,C5Re,PT,AColRe,AColIm,BV7,BAddr,6) \ + GEMM_FMLX2_LD1R(C6Im,C6Re,PT,AColRe,AColIm,BV0,BAddr,8) \ + GEMM_FMLX2_LD1R(C7Im,C7Re,PT,AColRe,AColIm,BV1,BAddr,10) \ + GEMM_FMLX2_LD1R(C8Im,C8Re,PT,AColRe,AColIm,BV2,BAddr,12) \ + GEMM_FMLX2_LD1R(C9Im,C9Re,PT,AColRe,AColIm,BV3,BAddr,14) + +#define GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C8Re,C9Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,C8Im,C9Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BAddr,BRSBit) \ + GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C8Re,C9Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,C8Im,C9Im,PT,AColRe,AColIm,BV4,BV5,BV6,BV7,BV0,BV1,BV2,BV3,BAddr,BRSBit) + +#define GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C8Re,C9Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,C8Im,C9Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BAddr,BRSBit) \ + GEMM_FMLA2_LD1R(C0Re,C0Im,PT,AColRe,AColIm,BV0,BAddr,16) \ + GEMM_FMLA2_LD1R(C1Re,C1Im,PT,AColRe,AColIm,BV1,BAddr,18) \ + GEMM_FMLA2_LD1R(C2Re,C2Im,PT,AColRe,AColIm,BV2,BAddr,1) \ + GEMM_FMLA2_LD1R(C3Re,C3Im,PT,AColRe,AColIm,BV3,BAddr,3) \ + GEMM_FMLA2_LD1R(C4Re,C4Im,PT,AColRe,AColIm,BV4,BAddr,5) \ + GEMM_FMLA2_LD1R(C5Re,C5Im,PT,AColRe,AColIm,BV5,BAddr,7) \ + GEMM_FMLA2_LD1R(C6Re,C6Im,PT,AColRe,AColIm,BV6,BAddr,9) \ + GEMM_FMLA2_LD1R(C7Re,C7Im,PT,AColRe,AColIm,BV7,BAddr,11) \ + GEMM_FMLA2_LD1R(C8Re,C8Im,PT,AColRe,AColIm,BV0,BAddr,13) \ + GEMM_FMLA2_LD1R(C9Re,C9Im,PT,AColRe,AColIm,BV1,BAddr,15) \ + \ + GEMM_FMLX2_LD1R(C0Im,C0Re,PT,AColRe,AColIm,BV2,BAddr,17) \ + GEMM_FMLX2_LD1R(C1Im,C1Re,PT,AColRe,AColIm,BV3,BAddr,19) \ +" add "#BAddr", "#BRSBit", "#BAddr" \n\t" /* B address forward */ \ + GEMM_FMLX2(C2Im,C2Re,PT,AColRe,AColIm,BV4) \ + GEMM_FMLX2(C3Im,C3Re,PT,AColRe,AColIm,BV5) \ + GEMM_FMLX2(C4Im,C4Re,PT,AColRe,AColIm,BV6) \ + GEMM_FMLX2(C5Im,C5Re,PT,AColRe,AColIm,BV7) \ + GEMM_FMLX2(C6Im,C6Re,PT,AColRe,AColIm,BV0) \ + GEMM_FMLX2(C7Im,C7Re,PT,AColRe,AColIm,BV1) \ + GEMM_FMLX2(C8Im,C8Re,PT,AColRe,AColIm,BV2) \ + GEMM_FMLX2(C9Im,C9Re,PT,AColRe,AColIm,BV3) + +#define GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C8Re,C9Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,C8Im,C9Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BAddr,BRSBit) \ + GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C8Re,C9Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,C8Im,C9Im,PT,AColRe,AColIm,BV4,BV5,BV6,BV7,BV0,BV1,BV2,BV3,BAddr,BRSBit) + +#define CLEAR_COL20(Z00,Z01,Z02,Z03,Z04,Z05,Z06,Z07,Z08,Z09,Z10,Z11,Z12,Z13,Z14,Z15,Z16,Z17,Z18,Z19) \ + CLEAR_COL4(Z00,Z01,Z02,Z03) \ + CLEAR_COL4(Z04,Z05,Z06,Z07) \ + CLEAR_COL4(Z08,Z09,Z10,Z11) \ + CLEAR_COL4(Z12,Z13,Z14,Z15) \ + CLEAR_COL4(Z16,Z17,Z18,Z19) + +// Moving is always .d. +// Never use .DT here! +#define MOV_COL2(ZD0Re,ZD0Im,ZD1Re,ZD1Im,Z0Re,Z0Im,Z1Re,Z1Im) \ +" mov "#ZD0Re".d, "#Z0Re".d \n\t" \ +" mov "#ZD0Im".d, "#Z0Im".d \n\t" \ +" mov "#ZD1Re".d, "#Z1Re".d \n\t" \ +" mov "#ZD1Im".d, "#Z1Im".d \n\t" + +#define GEMM_FMULCMPLX_COL2(ZD0Re,ZD0Im,ZD1Re,ZD1Im,PT,Z0Re,Z0Im,Z1Re,Z1Im,ZFactorRe,ZFactorIm) \ + FMUL_COL2(ZD0Re,ZD0Im,Z0Re,Z0Im,ZFactorRe) \ + FMUL_COL2(ZD1Re,ZD1Im,Z1Re,Z1Im,ZFactorRe) \ + GEMM_FMLX2(ZD0Im,ZD0Re,PT,Z0Re,Z0Im,ZFactorIm) \ + GEMM_FMLX2(ZD1Im,ZD1Re,PT,Z1Re,Z1Im,ZFactorIm) + +#define GEMM_FMLACMPLX_COL2(ZD0Re,ZD0Im,ZD1Re,ZD1Im,PT,Z0Re,Z0Im,Z1Re,Z1Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD0Re,ZD0Im,PT,Z0Re,Z0Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD1Re,ZD1Im,PT,Z1Re,Z1Im,ZFactorRe,ZFactorIm) + +#define GEMM_CCMPLX_LOAD_COL2_C(Z0Re,Z0Im,Z1Re,Z1Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z0Re,Z0Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z1Re,Z1Im,PT,CAddr,CCS) + +#define GEMM_CCMPLX_STORE_COL2_C(Z0Re,Z0Im,Z1Re,Z1Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z0Re,Z0Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z1Re,Z1Im,PT,CAddr,CCS) + +#define GEMM_CCMPLX_LOAD_COL2_G(Z0Re,Z0Im,Z1Re,Z1Im,PT,ZIndex,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z0Re,Z0Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z1Re,Z1Im,ZIndex,PT,PT,CAddr,CCS,CTemp) + +#define GEMM_CCMPLX_STORE_COL2_G(Z0Re,Z0Im,Z1Re,Z1Im,PT,ZIndex,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z0Re,Z0Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z1Re,Z1Im,ZIndex,PT,PT,CAddr,CCS,CTemp) + diff --git a/kernels/armsve/3/armsve_asm_macros.h b/kernels/armsve/3/armsve_asm_macros.h new file mode 100644 index 0000000000..9cbbeab920 --- /dev/null +++ b/kernels/armsve/3/armsve_asm_macros.h @@ -0,0 +1,136 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +// Clang's label requirements. +#if defined(__clang__) +#define LABEL(str) " L" #str"%=: \n\t" +#define BEQ(str) "b.eq L" #str"%= \n\t" +#define BNE(str) "b.ne L" #str"%= \n\t" +#define BRANCH(str) "b L" #str"%= \n\t" +#else +#define LABEL(str) " ." #str": \n\t" +#define BEQ(str) "b.eq ." #str" \n\t" +#define BNE(str) "b.ne ." #str" \n\t" +#define BRANCH(str) "b ." #str" \n\t" +#endif + +#define CLEAR_COL2(Z0,Z1) \ +" dup "#Z0"."DT", #0 \n\t" \ +" dup "#Z1"."DT", #0 \n\t" + +#define CLEAR_COL4(Z0,Z1,Z2,Z3) \ + CLEAR_COL2(Z0,Z1) \ + CLEAR_COL2(Z2,Z3) + +#define SCALE_COL2(Z0,Z1,ZFACTOR) \ +" fmul "#Z0"."DT", "#Z0"."DT", "#ZFACTOR"."DT" \n\t" \ +" fmul "#Z1"."DT", "#Z1"."DT", "#ZFACTOR"."DT" \n\t" \ + +#define SCALE_COL4(Z0,Z1,Z2,Z3,ZFACTOR) \ + SCALE_COL2(Z0,Z1,ZFACTOR) \ + SCALE_COL2(Z2,Z3,ZFACTOR) + +// Prefetch or not. +#define PREFETCH_CONTIGUOUS_noprfm(LV,PROP,ADDR,SHIFT) +#define PREFETCH_CONTIGUOUS_prfm(LV,PROP,ADDR,SHIFT) \ +" prfm PLD"#LV""#PROP", ["#ADDR", "#SHIFT"] \n\t" + +#define GEMM_FMLA2(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV) \ +" fmla "#CCOLFH"."DT", "#PT"/m, "#ACOLFH"."DT", "#BV"."DT" \n\t" /* A Row 0 :VL */ \ +" fmla "#CCOLLH"."DT", "#PT"/m, "#ACOLLH"."DT", "#BV"."DT" \n\t" /* A Row VL:2VL */ + +#define GEMM_FMLA2_LD1R(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV,BADDR,NSHIFT) \ + GEMM_FMLA2(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV) \ +" "LD1R" "#BV"."DT", "#PT"/z, ["#BADDR", #"#NSHIFT"*"SZ"]\n\t" + +#define GEMM_FMLA2_LD1R_G_ELMFWD(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV,BELMADDR,BCSBIT) \ + GEMM_FMLA2(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV) \ +" "LD1R" "#BV"."DT", "#PT"/z, ["#BELMADDR"] \n\t" /* Load B */ \ +" add "#BELMADDR", "#BELMADDR", "#BCSBIT" \n\t" /* Forward B element */ + +#define GEMM_ACOL_CONTIGUOUS_LOAD(ZFH,ZLH,PFH,PLH,AADDR) \ +" "LD1" "#ZFH"."DT", "#PFH"/z, ["#AADDR"] \n\t" \ +" "LD1" "#ZLH"."DT", "#PLH"/z, ["#AADDR", #1, mul vl]\n\t" + +#define GEMM_ACOL_GATHER_LOAD(ZFH,ZLH,ZIDX,PFH,PLH,AADDR,AVSKIP,ATEMP) \ +" "LD1" "#ZFH"."DT", "#PFH"/z, ["#AADDR", "#ZIDX"."DT", "OFFS"]\n\t" \ +" add "#ATEMP", "#AADDR", "#AVSKIP" \n\t" \ +" "LD1" "#ZLH"."DT", "#PLH"/z, ["#ATEMP", "#ZIDX"."DT", "OFFS"]\n\t" + +// Prefetch or not. +#define GEMM_ACOL_GATHER_noprfm(LV,PROP,ZIDX,PFH,PLH,AADDR,AVSKIP,ATEMP) +#define GEMM_ACOL_GATHER_prfm(LV,PROP,ZIDX,PFH,PLH,AADDR,AVSKIP,ATEMP) \ +" "PRFG" PLD"#LV""#PROP", "#PFH", ["#AADDR", "#ZIDX"."DT", "OFFS"] \n\t" \ +" add "#ATEMP", "#AADDR", "#AVSKIP" \n\t" \ +" "PRFG" PLD"#LV""#PROP", "#PLH", ["#ATEMP", "#ZIDX"."DT", "OFFS"] \n\t" + +#define GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_C(ZFH,ZLH,PFH,PLH,AADDR,A4KS,ACS,ATEMP,PREFMODE) \ +" add "#ATEMP", "#AADDR", "#A4KS" \n\t" \ +" add "#AADDR", "#AADDR", "#ACS" \n\t" /* Forward A's address to the next column. */ \ + GEMM_ACOL_CONTIGUOUS_LOAD(ZFH,ZLH,PFH,PLH,AADDR) \ + PREFETCH_CONTIGUOUS_ ##PREFMODE(L1,STRM,ATEMP,0) + +#define GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_G(ZFH,ZLH,ZIDX,PFH,PLH,AADDR,A4KS,APS,ACS,AVSKIP,ATEMP,PREFMODEL1,PREFMODEL2) \ +" add "#ATEMP", "#AADDR", "#A4KS" \n\t" \ + GEMM_ACOL_GATHER_ ##PREFMODEL1(L1,STRM,ZIDX,PFH,PLH,ATEMP,AVSKIP,ATEMP) \ +" add "#ATEMP", "#AADDR", "#APS" \n\t" \ + GEMM_ACOL_GATHER_ ##PREFMODEL2(L2,STRM,ZIDX,PFH,PLH,ATEMP,AVSKIP,ATEMP) \ +" add "#AADDR", "#AADDR", "#ACS" \n\t" /* Forward A's address to the next column. */ \ + GEMM_ACOL_GATHER_LOAD(ZFH,ZLH,ZIDX,PFH,PLH,AADDR,AVSKIP,ATEMP) + +#define GEMM_CCOL_CONTIGUOUS_LOAD_FWD(ZFH,ZLH,PFH,PLH,CADDR,CCS) \ + GEMM_ACOL_CONTIGUOUS_LOAD(ZFH,ZLH,PFH,PLH,CADDR) \ +" add "#CADDR", "#CADDR", "#CCS" \n\t" /* Forward C address (load) to next column. */ + +#define GEMM_CCOL_CONTIGUOUS_STORE_FWD(ZFH,ZLH,PFH,PLH,CADDR,CCS) \ +" "ST1" "#ZFH"."DT", "#PFH", ["#CADDR"] \n\t" \ +" "ST1" "#ZLH"."DT", "#PLH", ["#CADDR", #1, mul vl] \n\t" \ +" add "#CADDR", "#CADDR", "#CCS" \n\t" /* Forward C address (store) to next column. */ + +#define GEMM_CCOL_FMAD(ZFH,ZLH,PFH,PLH,CFH,CLH,ZSCALE) \ +" fmad "#ZFH"."DT", "#PFH"/m, "#ZSCALE"."DT", "#CFH"."DT" \n\t" \ +" fmad "#ZLH"."DT", "#PLH"/m, "#ZSCALE"."DT", "#CLH"."DT" \n\t" + +#define GEMM_CCOL_GATHER_LOAD_FWD(ZFH,ZLH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_ACOL_GATHER_LOAD(ZFH,ZLH,ZIDX,PFH,PLH,CADDR,CVSKIP,CTEMP) \ +" add "#CADDR", "#CADDR", "#CCS" \n\t" + +#define GEMM_CCOL_SCATTER_STORE_FWD(ZFH,ZLH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ +" "ST1" "#ZFH"."DT", "#PFH", ["#CADDR", "#ZIDX"."DT", "OFFS"]\n\t" \ +" add "#CTEMP", "#CADDR", "#CVSKIP" \n\t" \ +" "ST1" "#ZLH"."DT", "#PLH", ["#CTEMP", "#ZIDX"."DT", "OFFS"]\n\t" \ +" add "#CADDR", "#CADDR", "#CCS" \n\t" + + diff --git a/kernels/armsve/3/armsve_asm_macros_cmplx.h b/kernels/armsve/3/armsve_asm_macros_cmplx.h new file mode 100644 index 0000000000..10097700c8 --- /dev/null +++ b/kernels/armsve/3/armsve_asm_macros_cmplx.h @@ -0,0 +1,89 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "armsve_asm_macros.h" + +#define FMUL_COL2(ZD0,ZD1,Z0,Z1,ZFACTOR) \ +" fmul "#ZD0"."DT", "#Z0"."DT", "#ZFACTOR"."DT" \n\t" \ +" fmul "#ZD1"."DT", "#Z1"."DT", "#ZFACTOR"."DT" \n\t" \ + +#define GEMM_FMLX2(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV) \ +" fmla "#CCOLFH"."DT", "#PT"/m, "#ACOLFH"."DT", "#BV"."DT" \n\t" \ +" fmls "#CCOLLH"."DT", "#PT"/m, "#ACOLLH"."DT", "#BV"."DT" \n\t" + +#define GEMM_FMLX2_LD1R(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV,BADDR,NSHIFT) \ + GEMM_FMLX2(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV) \ +" "LD1R" "#BV"."DT", "#PT"/z, ["#BADDR", #"#NSHIFT"*"SZ"]\n\t" + +#define GEMM_FMULCMPLX(ZDRe,ZDIm,PT,Z0Re,Z0Im,Z1Re,Z1Im) \ + FMUL_COL2(ZDRe,ZDIm,Z0Re,Z0Im,Z1Re) \ + GEMM_FMLX2(ZDIm,ZDRe,PT,Z0Re,Z0Im,Z1Im) + +#define GEMM_FMLACMPLX(ZDRe,ZDIm,PT,Z0Re,Z0Im,Z1Re,Z1Im) \ + GEMM_FMLA2(ZDRe,ZDIm,PT,Z0Re,Z0Im,Z1Re) \ + GEMM_FMLX2(ZDIm,ZDRe,PT,Z0Re,Z0Im,Z1Im) + +#define GEMM_ACOLCMPLX_CONTIGUOUS_LOAD(ZRe,ZIm,PT,AAddr) \ +" "LD2" {"#ZRe"."DT", "#ZIm"."DT"}, "#PT"/z, ["#AAddr"] \n\t" + +#define GEMM_ACOLCMPLX_CONTIGUOUS_STORE(ZRe,ZIm,PT,AAddr) \ +" "ST2" {"#ZRe"."DT", "#ZIm"."DT"}, "#PT", ["#AAddr"] \n\t" + +#define GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(ZRe,ZIm,PT,AAddr,ACS) \ + GEMM_ACOLCMPLX_CONTIGUOUS_LOAD(ZRe,ZIm,PT,AAddr) \ +" add "#AAddr", "#AAddr", "#ACS" \n\t" /* Forward A address (load) to next column. */ + +#define GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(ZRe,ZIm,PT,CAddr,CCS) \ + GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(ZRe,ZIm,PT,CAddr,CCS) + +#define GEMM_ACOLCMPLX_CONTIGUOUS_STORE_FWD(ZRe,ZIm,PT,AAddr,ACS) \ + GEMM_ACOLCMPLX_CONTIGUOUS_STORE(ZRe,ZIm,PT,AAddr) \ +" add "#AAddr", "#AAddr", "#ACS" \n\t" /* Forward A address (load) to next column. */ + +#define GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(ZRe,ZIm,PT,CAddr,CCS) \ + GEMM_ACOLCMPLX_CONTIGUOUS_STORE_FWD(ZRe,ZIm,PT,CAddr,CCS) + +#define GEMM_CCOLCMPLX_GATHER_LOAD_FWD(ZRe,ZIm,ZIndex,PRe,PIm,CAddr,CCS,CTemp) \ +" add "#CTemp", "#CAddr", #"SZ" \n\t" /* Imaginary skip */ \ +" "LD1" "#ZRe"."DT", "#PRe"/z, ["#CAddr", "#ZIndex"."DT", "OFFS"]\n\t" \ +" "LD1" "#ZIm"."DT", "#PRe"/z, ["#CTemp", "#ZIndex"."DT", "OFFS"]\n\t" \ +" add "#CAddr", "#CAddr", "#CCS" \n\t" + +#define GEMM_CCOLCMPLX_SCATTER_STORE_FWD(ZRe,ZIm,ZIndex,PRe,PIm,CAddr,CCS,CTemp) \ +" add "#CTemp", "#CAddr", #"SZ" \n\t" /* Imaginary skip */ \ +" "ST1" "#ZRe"."DT", "#PRe", ["#CAddr", "#ZIndex"."DT", "OFFS"]\n\t" \ +" "ST1" "#ZIm"."DT", "#PRe", ["#CTemp", "#ZIndex"."DT", "OFFS"]\n\t" \ +" add "#CAddr", "#CAddr", "#CCS" \n\t" + diff --git a/frame/3/bli_l3_tapi_ba.c b/kernels/armsve/3/armsve_asm_macros_dcomplex.h similarity index 83% rename from frame/3/bli_l3_tapi_ba.c rename to kernels/armsve/3/armsve_asm_macros_dcomplex.h index 748863f844..0beb5d2316 100644 --- a/frame/3/bli_l3_tapi_ba.c +++ b/kernels/armsve/3/armsve_asm_macros_dcomplex.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -30,17 +31,18 @@ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*/ - -#include "blis.h" - -// Include cpp macros that instantiate the API definition templates as -// omitting expert parameters. -#include "bli_tapi_ba.h" -// Define the macro protecting the typed API definitions. -#define BLIS_ENABLE_TAPI - -// Include the typed API definitions here. -#include "bli_l3_tapi.c" +*/ +// Specify to use double precision. +#define DT "d" +#define LD1 "ld1d" +#define ST1 "st1d" +#define LD2 "ld2d" +#define ST2 "st2d" +#define LD1R "ld1rd" +#define PRFG "prfd" +#define SZ "8" +#define OFFS "lsl #3" +// Include macros. +#include "armsve_asm_macros_cmplx.h" diff --git a/kernels/armsve/3/armsve_asm_macros_double.h b/kernels/armsve/3/armsve_asm_macros_double.h new file mode 100644 index 0000000000..f93d3f3821 --- /dev/null +++ b/kernels/armsve/3/armsve_asm_macros_double.h @@ -0,0 +1,46 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +// Specify to use double precision. +#define DT "d" +#define LD1 "ld1d" +#define ST1 "st1d" +#define LD1R "ld1rd" +#define PRFG "prfd" +#define SZ "8" +#define OFFS "lsl #3" +// Include macros. +#include "armsve_asm_macros.h" + diff --git a/frame/3/bli_l3_oapi_ba.c b/kernels/armsve/3/armsve_asm_macros_scomplex.h similarity index 83% rename from frame/3/bli_l3_oapi_ba.c rename to kernels/armsve/3/armsve_asm_macros_scomplex.h index d6e3b2f3d5..f49cfedfba 100644 --- a/frame/3/bli_l3_oapi_ba.c +++ b/kernels/armsve/3/armsve_asm_macros_scomplex.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -30,17 +31,18 @@ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*/ - -#include "blis.h" - -// Include cpp macros that instantiate the API definition templates as -// omitting expert parameters. -#include "bli_oapi_ba.h" -// Define the macro protecting the object API definitions. -#define BLIS_ENABLE_OAPI - -// Include the object API definitions here. -#include "bli_l3_oapi.c" +*/ +// Specify to use single precision. +#define DT "s" +#define LD1 "ld1w" +#define ST1 "st1w" +#define LD2 "ld2w" +#define ST2 "st2w" +#define LD1R "ld1rw" +#define PRFG "prfw" +#define SZ "4" +#define OFFS "uxtw #2" +// Include macros. +#include "armsve_asm_macros_cmplx.h" diff --git a/kernels/armsve/3/armsve_asm_macros_single.h b/kernels/armsve/3/armsve_asm_macros_single.h new file mode 100644 index 0000000000..2203de3453 --- /dev/null +++ b/kernels/armsve/3/armsve_asm_macros_single.h @@ -0,0 +1,46 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +// Specify to use single precision. +#define DT "s" +#define LD1 "ld1w" +#define ST1 "st1w" +#define LD1R "ld1rw" +#define PRFG "prfw" +#define SZ "4" +#define OFFS "uxtw #2" +// Include macros. +#include "armsve_asm_macros.h" + diff --git a/kernels/armsve/3/bli_armsve_utils.c b/kernels/armsve/3/bli_armsve_utils.c new file mode 100644 index 0000000000..2ebafa655d --- /dev/null +++ b/kernels/armsve/3/bli_armsve_utils.c @@ -0,0 +1,99 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" + +dim_t bli_vl_bytes_armsve(void) +{ \ + uint64_t vl = 0; + __asm__ ( + " mov x0, xzr \n\t" + " incb x0 \n\t" + " mov %[vl], x0 \n\t" + : [vl] "=r" (vl) + : + : "x0" + ); + return vl; +} + + +#define EXPANDMAC_BLKSZ_ARMSVE(ch, S_Data) \ +void PASTEMAC(ch, _blksz_armsve) (dim_t *m_r_, dim_t *n_r_, \ + dim_t *k_c_, dim_t *m_c_, dim_t *n_c_) \ +{ \ + dim_t W_L1 = bli_env_get_var("BLIS_SVE_W_L1", W_L1_SVE_DEFAULT); \ + dim_t N_L1 = bli_env_get_var("BLIS_SVE_N_L1", N_L1_SVE_DEFAULT); \ + dim_t C_L1 = bli_env_get_var("BLIS_SVE_C_L1", C_L1_SVE_DEFAULT); \ + dim_t W_L2 = bli_env_get_var("BLIS_SVE_W_L2", W_L2_SVE_DEFAULT); \ + dim_t N_L2 = bli_env_get_var("BLIS_SVE_N_L2", N_L2_SVE_DEFAULT); \ + dim_t C_L2 = bli_env_get_var("BLIS_SVE_C_L2", C_L2_SVE_DEFAULT); \ + dim_t W_L3 = bli_env_get_var("BLIS_SVE_W_L3", W_L3_SVE_DEFAULT); \ + dim_t N_L3 = bli_env_get_var("BLIS_SVE_N_L3", N_L3_SVE_DEFAULT); \ + dim_t C_L3 = bli_env_get_var("BLIS_SVE_C_L3", C_L3_SVE_DEFAULT); \ +\ + dim_t vl_b = bli_vl_bytes_armsve(); \ + dim_t vl = vl_b / S_Data; \ + dim_t m_r = 2 * vl; \ + dim_t n_r = 10; \ +\ + dim_t k_c = (dim_t)( floor((W_L1 - 1.0)/(1.0 + (double)n_r/m_r)) * N_L1 * C_L1 ) \ + / (n_r * S_Data); \ +\ + dim_t C_Ac = W_L2 - 1 - ceil( (2.0 * k_c * n_r * S_Data)/(C_L2 * N_L2) ); \ + dim_t m_c = C_Ac * (N_L2 * C_L2)/(k_c * S_Data); \ + m_c -= m_c % m_r; \ +\ + dim_t C_Bc = W_L3 - 1 - ceil( (2.0 * k_c * m_c * S_Data)/(C_L3 * N_L3) ); \ + dim_t n_c = C_Bc * (N_L3 * C_L3)/(k_c * S_Data); \ + n_c -= n_c % n_r; \ +\ + /* Ensure non-zero block sizes. */ \ + m_c = bli_max(m_c, m_r); \ + n_c = bli_max(n_c, n_r); \ + k_c = bli_max(k_c, 128); \ +\ + *m_r_ = m_r; \ + *n_r_ = n_r; \ + *k_c_ = k_c; \ + *m_c_ = m_c; \ + *n_c_ = n_c; \ +} + +EXPANDMAC_BLKSZ_ARMSVE( s, 4 ) +EXPANDMAC_BLKSZ_ARMSVE( d, 8 ) +EXPANDMAC_BLKSZ_ARMSVE( c, 8 ) +EXPANDMAC_BLKSZ_ARMSVE( z, 16 ) + diff --git a/kernels/armsve/3/bli_armsve_utils.h b/kernels/armsve/3/bli_armsve_utils.h new file mode 100644 index 0000000000..6d3aab05d7 --- /dev/null +++ b/kernels/armsve/3/bli_armsve_utils.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" + +dim_t bli_vl_bytes_armsve(void); + +void bli_s_blksz_armsve(dim_t *m_r_, dim_t *n_r_, dim_t *k_c_, dim_t *m_c_, dim_t *n_c_); +void bli_d_blksz_armsve(dim_t *m_r_, dim_t *n_r_, dim_t *k_c_, dim_t *m_c_, dim_t *n_c_); +void bli_c_blksz_armsve(dim_t *m_r_, dim_t *n_r_, dim_t *k_c_, dim_t *m_c_, dim_t *n_c_); +void bli_z_blksz_armsve(dim_t *m_r_, dim_t *n_r_, dim_t *k_c_, dim_t *m_c_, dim_t *n_c_); + diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c new file mode 100644 index 0000000000..098d5d4b5e --- /dev/null +++ b/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c @@ -0,0 +1,321 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Single-precision composite instructions. +#include "armsve_asm_macros_scomplex.h" + +// 2vx10 microkernels. +#include "armsve_asm_2vx10cmplx.h" + + +void bli_cgemm_armsve_asm_2vx10_unindexed + ( + dim_t m, + dim_t n, + dim_t k, + scomplex* restrict alpha, + scomplex* restrict a, + scomplex* restrict b, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t info = 0; + + GEMM_UKR_SETUP_CT( c, m, 10, false ); + + __asm__ volatile ( +" whilelo p0.s, xzr, %12 \n\t" +// " ldr x0, %[a] \n\t" +// " ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" incw x2, ALL, MUL #1 \n\t" // Column-skip of A. +" mov x3, #10 \n\t" // Row-skip of B. +" \n\t" +// " ldr x2, %[c] \n\t" +// " ldr x3, %[rs_c] \n\t" // Row-skip of C. +// " ldr x4, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %0, %0, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %1, %1, x16 \n\t" +" mov x16, 0x3 \n\t" // Tag C address. +" lsl x16, x16, #56 \n\t" +" orr %2, %2, x16 \n\t" +#endif +" \n\t" +" mov x16, #8 \n\t" // Multiply some address skips by sizeof(scomplex). +" madd x2, x16, x2, xzr \n\t" // cs_a +" madd x3, x16, x3, xzr \n\t" // rs_b +" madd %4, x16, %4, xzr \n\t" // cs_c +" \n\t" +// " ldr x5, %[k_mker] \n\t" // Number of loops. +// " ldr x6, %[k_left] \n\t" +" \n\t" +LABEL(LOAD_ABC) +" cmp %5, #0 \n\t" // Don't preload if no microkernel there. +BEQ(END_CCOL_PRFM) +" \n\t" +" ld1rw z20.s, p0/z, [%1, 4*0] \n\t" // Load B's real 8/10, no imaginary. +" ld1rw z21.s, p0/z, [%1, 4*2] \n\t" +" ld1rw z22.s, p0/z, [%1, 4*4] \n\t" +" ld1rw z23.s, p0/z, [%1, 4*6] \n\t" +" ld1rw z24.s, p0/z, [%1, 4*8] \n\t" +" ld1rw z25.s, p0/z, [%1, 4*10] \n\t" +" ld1rw z26.s, p0/z, [%1, 4*12] \n\t" +" ld1rw z27.s, p0/z, [%1, 4*14] \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +" \n\t" +LABEL(CCOL_PRFM) +// " cmp %3, #1 \n\t" +// BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage. +" mov x16, %2 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +LABEL(END_CCOL_PRFM) +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp %5, #0 \n\t" // If no 4-microkernel can be applied. +BEQ(K_LEFT_LOOP) +" \n\t" +LABEL(K_MKER_LOOP) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z30,z31,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z30,z31,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +" subs %5, %5, #1 \n\t" // Decrease counter before final replica. +BEQ(FIN_MKER_LOOP) // Branch early to avoid reading excess mem. +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +BRANCH(K_MKER_LOOP) +" \n\t" +LABEL(FIN_MKER_LOOP) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +LABEL(K_LEFT_LOOP) +" cmp %6, #0 \n\t" // End of execution. +BEQ(WRITE_MEM_PREP) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +" ld1rw z20.s, p0/z, [%1, 4*0] \n\t" // Load B's real 8/10, no imaginary. +" ld1rw z21.s, p0/z, [%1, 4*2] \n\t" +" ld1rw z22.s, p0/z, [%1, 4*4] \n\t" +" ld1rw z23.s, p0/z, [%1, 4*6] \n\t" +" ld1rw z24.s, p0/z, [%1, 4*8] \n\t" +" ld1rw z25.s, p0/z, [%1, 4*10] \n\t" +" ld1rw z26.s, p0/z, [%1, 4*12] \n\t" +" ld1rw z27.s, p0/z, [%1, 4*14] \n\t" +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" sub %6, %6, #1 \n\t" +BRANCH(K_LEFT_LOOP) +" \n\t" +LABEL(WRITE_MEM_PREP) +" \n\t" +// " ldr x7, %[alpha] \n\t" // Load alpha & beta (address). +// " ldr x8, %[beta] \n\t" +" ld1rw z28.s, p0/z, [%7] \n\t" // Real(alpha). +" ld1rw z29.s, p0/z, [%7, 4] \n\t" // Imag(alpha). +" ld1rw z30.s, p0/z, [%8] \n\t" // Real(beta). +" ld1rw z31.s, p0/z, [%8, 4] \n\t" // Imag(beta). +" \n\t" +LABEL(PREFETCH_ABNEXT) +// " ldr x9, %[a_next] \n\t" +// " ldr x10, %[b_next] \n\t" +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %9, %9, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %10, %10, x16 \n\t" +#endif +" prfm PLDL1STRM, [%9] \n\t" +" prfm PLDL1STRM, [%9, 256*1] \n\t" +" prfm PLDL1STRM, [%10] \n\t" +" prfm PLDL1STRM, [%10, 256*1] \n\t" +" \n\t" +LABEL(WRITE_MEM) +" fmov s27, #1.0 \n\t" +" fcmp s29, #0.0 \n\t" // Whether Imag(alpha) == 0. +" fccmp s28, s27, 0, eq \n\t" // Whether Real(alpha) == 1. +BEQ(UNIT_ALPHA) +" \n\t" +GEMM_FMULCMPLX_COL2(z20,z21,z22,z23,p0,z0 ,z1 ,z2 ,z3 ,z28,z29) +GEMM_FMULCMPLX_COL2(z24,z25,z26,z27,p0,z4 ,z5 ,z6 ,z7 ,z28,z29) +GEMM_FMULCMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z8, z9, z10,z11,z28,z29) +GEMM_FMULCMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z12,z13,z14,z15,z28,z29) +GEMM_FMULCMPLX_COL2(z8 ,z9 ,z10,z11,p0,z16,z17,z18,z19,z28,z29) +BRANCH(WRITE_MEM_EXEC) +" \n\t" +LABEL(UNIT_ALPHA) +MOV_COL2(z20,z21,z22,z23,z0 ,z1 ,z2 ,z3 ) +MOV_COL2(z24,z25,z26,z27,z4 ,z5 ,z6 ,z7 ) +MOV_COL2(z0 ,z1 ,z2 ,z3 ,z8, z9, z10,z11) +MOV_COL2(z4 ,z5 ,z6 ,z7 ,z12,z13,z14,z15) +MOV_COL2(z8 ,z9 ,z10,z11,z16,z17,z18,z19) +" \n\t" +LABEL(WRITE_MEM_EXEC) +" mov x9, %2 \n\t" // C address for loading. +" \n\t" // C address for storing is %2 itself. +// " cmp %3, #1 \n\t" +// BNE(WRITE_MEM_G) +" \n\t" +LABEL(WRITE_MEM_C) +" fmov s29, wzr \n\t" +" fcmp s31, #0.0 \n\t" // Whether Imag(beta) == 0. +" fccmp s30, s29, 0, eq \n\t" // Whether Real(beta) == 0. +BEQ(ZERO_BETA_C_0_1_2_3) +GEMM_CCMPLX_LOAD_COL2_C(z12,z13,z14,z15,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z16,z17,z18,z19,p0,x9,%4) +GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z30,z31) +GEMM_FMLACMPLX_COL2(z24,z25,z26,z27,p0,z16,z17,z18,z19,z30,z31) +LABEL(ZERO_BETA_C_0_1_2_3) +GEMM_CCMPLX_STORE_COL2_C(z20,z21,z22,z23,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z24,z25,z26,z27,p0,%2,%4) +" \n\t" +BEQ(ZERO_BETA_C_4_5_6_7_8_9) +GEMM_CCMPLX_LOAD_COL2_C(z12,z13,z14,z15,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z16,z17,z18,z19,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z20,z21,z22,z23,p0,x9,%4) +GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z12,z13,z14,z15,z30,z31) +GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z16,z17,z18,z19,z30,z31) +GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z20,z21,z22,z23,z30,z31) +LABEL(ZERO_BETA_C_4_5_6_7_8_9) +GEMM_CCMPLX_STORE_COL2_C(z0 ,z1 ,z2 ,z3 ,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z4 ,z5 ,z6 ,z7 ,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z8 ,z9 ,z10,z11,p0,%2,%4) +// BRANCH(END_WRITE_MEM) +// " \n\t" +// LABEL(WRITE_MEM_G) +// " add %3, %3, %3 \n\t" // Skips passed to index is multiplied by 2, +// " mov x3, %3 \n\t" // s.t. 2*sizeof(float) = 2*4 = 8. +// " index z28.s, wzr, w3 \n\t" +// " fmov s29, wzr \n\t" +// " fcmp s31, #0.0 \n\t" // Whether Imag(beta) == 0. +// " fccmp s30, s29, 0, eq \n\t" // Whether Real(beta) == 0. +// BEQ(ZERO_BETA_G_0_1_2_3) +// GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16) +// GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16) +// GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z30,z31) +// GEMM_FMLACMPLX_COL2(z24,z25,z26,z27,p0,z16,z17,z18,z19,z30,z31) +// LABEL(ZERO_BETA_G_0_1_2_3) +// GEMM_CCMPLX_STORE_COL2_G(z20,z21,z22,z23,p0,z28,%2,%4,x16) +// GEMM_CCMPLX_STORE_COL2_G(z24,z25,z26,z27,p0,z28,%2,%4,x16) +// " \n\t" +// BEQ(ZERO_BETA_G_4_5_6_7_8_9) +// GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16) +// GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16) +// GEMM_CCMPLX_LOAD_COL2_G(z20,z21,z22,z23,p0,z28,x9,%4,x16) +// GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z12,z13,z14,z15,z30,z31) +// GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z16,z17,z18,z19,z30,z31) +// GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z20,z21,z22,z23,z30,z31) +// LABEL(ZERO_BETA_G_4_5_6_7_8_9) +// GEMM_CCMPLX_STORE_COL2_G(z0 ,z1 ,z2 ,z3 ,p0,z28,%2,%4,x16) +// GEMM_CCMPLX_STORE_COL2_G(z4 ,z5 ,z6 ,z7 ,p0,z28,%2,%4,x16) +// GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16) +// " \n\t" +// LABEL(END_WRITE_MEM) +// BRANCH(END_EXEC) +" \n\t" +LABEL(END_EXEC) +" mov %11, #0 \n\t" // Return normal. +: "+r" (a), // %0 + "+r" (b), // %1 + "+r" (c), // %2 + "+r" (rs_c), // %3 + "+r" (cs_c), // %4 + "+r" (k_mker), // %5 + "+r" (k_left), // %6 + "+r" (alpha), // %7 + "+r" (beta), // %8 + "+r" (a_next), // %9 + "+r" (b_next), // %10 + "=r" (info) // %11 +: "r" (m) // %12 +: "x2","x3","x9","x16", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); + + GEMM_UKR_FLUSH_CT( c ); +} + diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c new file mode 100644 index 0000000000..0ee470f240 --- /dev/null +++ b/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c @@ -0,0 +1,322 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Double-precision composite instructions. +#include "armsve_asm_macros_double.h" + +// 2vx10 microkernels. +#include "armsve_asm_2vx10.h" + + +void bli_dgemm_armsve_asm_2vx10_unindexed + ( + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + GEMM_UKR_SETUP_CT( d, m, 10, false ); + + __asm__ volatile ( +" mov x0, xzr \n\t" +" ldr x1, %[m] \n\t" +" whilelo p0.d, x0, x1 \n\t" " incd x0 \n\t" +" whilelo p1.d, x0, x1 \n\t" +" \n\t" +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" incd x2, ALL, MUL #2 \n\t" // Column-skip of A. +" mov x3, #10 \n\t" // Row-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +// " ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x8, 0x3 \n\t" // Tag C address. +" lsl x8, x8, #56 \n\t" +" orr x5, x5, x8 \n\t" +" mov x8, 0x2 \n\t" // Tag B address. +" lsl x8, x8, #56 \n\t" +" orr x1, x1, x8 \n\t" +" mov x8, 0x1 \n\t" // Tag A address. +" lsl x8, x8, #56 \n\t" +" orr x0, x0, x8 \n\t" +#endif +" \n\t" +" mov x8, #8 \n\t" // Multiply some address skips by sizeof(double). +" madd x2, x8, x2, xzr \n\t" // cs_a +" madd x3, x8, x3, xzr \n\t" // rs_b +" madd x7, x8, x7, xzr \n\t" // cs_c +" \n\t" +" ldr x4, %[k_mker] \n\t" // Number of loops. +" ldr x8, %[k_left] \n\t" +" \n\t" +LABEL(LOAD_ABC) +" cmp x4, #0 \n\t" // Don't preload if no microkernel there. +BEQ(END_CCOL_PRFM) + +" ld1rd z20.d, p0/z, [x1] \n\t" // Load 8/10 of first B row. +" ld1rd z21.d, p0/z, [x1, 8] \n\t" +" ld1rd z22.d, p0/z, [x1, 16] \n\t" +" ld1rd z23.d, p0/z, [x1, 24] \n\t" +" ld1rd z24.d, p0/z, [x1, 32] \n\t" +" ld1rd z25.d, p0/z, [x1, 40] \n\t" +" ld1rd z26.d, p0/z, [x1, 48] \n\t" +" ld1rd z27.d, p0/z, [x1, 56] \n\t" +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0) +" \n\t" +LABEL(CCOL_PRFM) +// " cmp x6, #1 \n\t" +// BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage. +" mov x16, x5 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +LABEL(END_CCOL_PRFM) +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp x4, #0 \n\t" // If no 4-microkernel can be applied +BEQ(K_LEFT_LOOP) +" \n\t" +LABEL(K_MKER_LOOP) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p1,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p1,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" subs x4, x4, #1 \n\t" // Decrease counter before final replica. +BEQ(FIN_MKER_LOOP) // Branch early to avoid reading excess mem. +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_4(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +BRANCH(K_MKER_LOOP) +" \n\t" +LABEL(FIN_MKER_LOOP) +GEMM_2VX10_MKER_LOOP_PLAIN_C_4_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" add x0, x0, x2 \n\t" // Forward A to fill the blank. +" \n\t" +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of execution. +BEQ(WRITE_MEM_PREP) +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p1,x0) +" ld1rd z20.d, p0/z, [x1] \n\t" // Load 8/10 of first B row. +" ld1rd z21.d, p0/z, [x1, 8] \n\t" +" ld1rd z22.d, p0/z, [x1, 16] \n\t" +" ld1rd z23.d, p0/z, [x1, 24] \n\t" +" ld1rd z24.d, p0/z, [x1, 32] \n\t" +" ld1rd z25.d, p0/z, [x1, 40] \n\t" +" ld1rd z26.d, p0/z, [x1, 48] \n\t" +" ld1rd z27.d, p0/z, [x1, 56] \n\t" +" ld1rd z28.d, p0/z, [x1, 64] \n\t" +" ld1rd z29.d, p0/z, [x1, 72] \n\t" +GEMM_FMLA2(z0,z1,p0,z30,z31,z20) +GEMM_FMLA2(z2,z3,p0,z30,z31,z21) +GEMM_FMLA2(z4,z5,p0,z30,z31,z22) +GEMM_FMLA2(z6,z7,p0,z30,z31,z23) +GEMM_FMLA2(z8,z9,p0,z30,z31,z24) +GEMM_FMLA2(z10,z11,p0,z30,z31,z25) +GEMM_FMLA2(z12,z13,p0,z30,z31,z26) +GEMM_FMLA2(z14,z15,p0,z30,z31,z27) +GEMM_FMLA2(z16,z17,p0,z30,z31,z28) +GEMM_FMLA2(z18,z19,p0,z30,z31,z29) +" add x0, x0, x2 \n\t" // Forward A. +" add x1, x1, x3 \n\t" // Forward B. +" sub x8, x8, #1 \n\t" +BRANCH(K_LEFT_LOOP) +" \n\t" +LABEL(WRITE_MEM_PREP) +" \n\t" +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ldr x4, [x4] \n\t" // Load alpha & beta (value). +" ldr x8, [x8] \n\t" +" dup z30.d, x4 \n\t" // Broadcast alpha & beta into vectors. +" dup z31.d, x8 \n\t" +" fmov d28, #1.0 \n\t" // Prepare FP 1.0. +" fmov x16, d28 \n\t" +" \n\t" +LABEL(PREFETCH_ABNEXT) +" ldr x0, %[a_next] \n\t" +" ldr x1, %[b_next] \n\t" +#ifdef _A64FX +" mov x8, 0x2 \n\t" // Tag B address. +" lsl x8, x8, #56 \n\t" +" orr x1, x1, x8 \n\t" +" mov x8, 0x1 \n\t" // Tag A address. +" lsl x8, x8, #56 \n\t" +" orr x0, x0, x8 \n\t" +#endif +" prfm PLDL1STRM, [x0] \n\t" +" prfm PLDL1STRM, [x0, 256*1] \n\t" +// " prfm PLDL2KEEP, [x0, 256*2] \n\t" +// " prfm PLDL2KEEP, [x0, 256*3] \n\t" +// " prfm PLDL2KEEP, [x0, 256*4] \n\t" +// " prfm PLDL2KEEP, [x0, 256*5] \n\t" +// " prfm PLDL2KEEP, [x0, 256*6] \n\t" +// " prfm PLDL2KEEP, [x0, 256*7] \n\t" +// " prfm PLDL2KEEP, [x0, 256*8] \n\t" +// " prfm PLDL2KEEP, [x0, 256*9] \n\t" +// " prfm PLDL2KEEP, [x0, 256*10] \n\t" +// " prfm PLDL2KEEP, [x0, 256*11] \n\t" +// " prfm PLDL2KEEP, [x0, 256*12] \n\t" +// " prfm PLDL2KEEP, [x0, 256*13] \n\t" +// " prfm PLDL2KEEP, [x0, 256*14] \n\t" +// " prfm PLDL2KEEP, [x0, 256*15] \n\t" +" prfm PLDL1STRM, [x1] \n\t" +" prfm PLDL1STRM, [x1, 256*1] \n\t" +// " prfm PLDL2KEEP, [x1, 256*2] \n\t" +// " prfm PLDL2KEEP, [x1, 256*3] \n\t" +// " prfm PLDL2KEEP, [x1, 256*4] \n\t" +// " prfm PLDL2KEEP, [x1, 256*5] \n\t" +// " prfm PLDL2KEEP, [x1, 256*6] \n\t" +// " prfm PLDL2KEEP, [x1, 256*7] \n\t" +// " prfm PLDL2KEEP, [x1, 256*8] \n\t" +// " prfm PLDL2KEEP, [x1, 256*9] \n\t" +" \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +// " cmp x6, #1 \n\t" // Preload first half of C for contiguous case. +// BNE(WRITE_MEM) +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7) +" \n\t" +LABEL(WRITE_MEM) +" \n\t" +" cmp x16, x4 \n\t" +BEQ(UNIT_ALPHA) +" \n\t" +SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19,z30) +" \n\t" +LABEL(UNIT_ALPHA) +// " cmp x6, #1 \n\t" +// BNE(WRITE_MEM_G) +" \n\t" +LABEL(WRITE_MEM_C) +" \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-29]. +" fcmp d31, #0.0 \n\t" // Skip loading if *beta == 0 to override NaN. +BEQ(BETA_ZERO_C) +// First half of C is already loaded in this case. +// GEMM_C_FMAD_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31,x9,x7) +GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31) +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7) +GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31) +" \n\t" +LABEL(BETA_ZERO_C) +GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p1,x5,x7) +GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p1,x5,x7) +// BRANCH(END_WRITE_MEM) +// " \n\t" +// LABEL(END_WRITE_MEM) +// BRANCH(END_EXEC) +// " \n\t" +// LABEL(END_ERROR) +// " mov x0, #1 \n\t" // Return error. +LABEL(END_EXEC) +" mov x0, #0 \n\t" // Return normal. +: +: [m] "m" (m), + [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x16", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); + + GEMM_UKR_FLUSH_CT( d ); +} + diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c new file mode 100644 index 0000000000..d03af59230 --- /dev/null +++ b/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c @@ -0,0 +1,309 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + Copyright (C) 2019, Forschunszentrum Juelich + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Single-precision composite instructions. +#include "armsve_asm_macros_single.h" + +// 2vx10 microkernels. +#include "armsve_asm_2vx10.h" + + +void bli_sgemm_armsve_asm_2vx10_unindexed + ( + dim_t m, + dim_t n, + dim_t k, + float* restrict alpha, + float* restrict a, + float* restrict b, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + GEMM_UKR_SETUP_CT( s, m, 10, false ); + + __asm__ volatile ( +" mov x0, xzr \n\t" +" ldr x1, %[m] \n\t" +" whilelo p0.s, x0, x1 \n\t" " incw x0 \n\t" +" whilelo p1.s, x0, x1 \n\t" +" \n\t" +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" incw x2, ALL, MUL #2 \n\t" // Column-skip of A. +" mov x3, #10 \n\t" // Row-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +// " ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x8, 0x3 \n\t" // Tag C address. +" lsl x8, x8, #56 \n\t" +" orr x5, x5, x8 \n\t" +" mov x8, 0x2 \n\t" // Tag B address. +" lsl x8, x8, #56 \n\t" +" orr x1, x1, x8 \n\t" +" mov x8, 0x1 \n\t" // Tag A address. +" lsl x8, x8, #56 \n\t" +" orr x0, x0, x8 \n\t" +#endif +" \n\t" +" mov x8, #4 \n\t" // Multiply some address skips by sizeof(float). +" madd x2, x8, x2, xzr \n\t" // cs_a +" madd x3, x8, x3, xzr \n\t" // rs_b +" madd x7, x8, x7, xzr \n\t" // cs_c +" \n\t" +" ldr x4, %[k_mker] \n\t" // Number of loops. +" ldr x8, %[k_left] \n\t" +" \n\t" +LABEL(LOAD_ABC) +" cmp x4, #0 \n\t" // Don't preload if no microkernel there. +BEQ(END_CCOL_PRFM) + +" ld1rw z20.s, p0/z, [x1] \n\t" // Load 8/10 of first B row. +" ld1rw z21.s, p0/z, [x1, 4] \n\t" +" ld1rw z22.s, p0/z, [x1, 8] \n\t" +" ld1rw z23.s, p0/z, [x1, 12] \n\t" +" ld1rw z24.s, p0/z, [x1, 16] \n\t" +" ld1rw z25.s, p0/z, [x1, 20] \n\t" +" ld1rw z26.s, p0/z, [x1, 24] \n\t" +" ld1rw z27.s, p0/z, [x1, 28] \n\t" +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0) +" \n\t" +LABEL(CCOL_PRFM) +// " cmp x6, #1 \n\t" +// BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage. +" mov x16, x5 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +LABEL(END_CCOL_PRFM) +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp x4, #0 \n\t" // If no 4-microkernel can be applied +BEQ(K_LEFT_LOOP) +" \n\t" +LABEL(K_MKER_LOOP) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p1,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p1,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" subs x4, x4, #1 \n\t" // Decrease counter before final replica. +BEQ(FIN_MKER_LOOP) // Branch early to avoid reading excess mem. +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_4(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +BRANCH(K_MKER_LOOP) +" \n\t" +LABEL(FIN_MKER_LOOP) +GEMM_2VX10_MKER_LOOP_PLAIN_C_4_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" add x0, x0, x2 \n\t" // Forward A to fill the blank. +" \n\t" +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of execution. +BEQ(WRITE_MEM_PREP) +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p1,x0) +" ld1rw z20.s, p0/z, [x1] \n\t" // Load 8/10 of first B row. +" ld1rw z21.s, p0/z, [x1, 4] \n\t" +" ld1rw z22.s, p0/z, [x1, 8] \n\t" +" ld1rw z23.s, p0/z, [x1, 12] \n\t" +" ld1rw z24.s, p0/z, [x1, 16] \n\t" +" ld1rw z25.s, p0/z, [x1, 20] \n\t" +" ld1rw z26.s, p0/z, [x1, 24] \n\t" +" ld1rw z27.s, p0/z, [x1, 28] \n\t" +" ld1rw z28.s, p0/z, [x1, 32] \n\t" +" ld1rw z29.s, p0/z, [x1, 36] \n\t" +GEMM_FMLA2(z0,z1,p0,z30,z31,z20) +GEMM_FMLA2(z2,z3,p0,z30,z31,z21) +GEMM_FMLA2(z4,z5,p0,z30,z31,z22) +GEMM_FMLA2(z6,z7,p0,z30,z31,z23) +GEMM_FMLA2(z8,z9,p0,z30,z31,z24) +GEMM_FMLA2(z10,z11,p0,z30,z31,z25) +GEMM_FMLA2(z12,z13,p0,z30,z31,z26) +GEMM_FMLA2(z14,z15,p0,z30,z31,z27) +GEMM_FMLA2(z16,z17,p0,z30,z31,z28) +GEMM_FMLA2(z18,z19,p0,z30,z31,z29) +" add x0, x0, x2 \n\t" // Forward A. +" add x1, x1, x3 \n\t" // Forward B. +" sub x8, x8, #1 \n\t" +BRANCH(K_LEFT_LOOP) +" \n\t" +LABEL(WRITE_MEM_PREP) +" \n\t" +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ldr w4, [x4] \n\t" // Load alpha & beta (value). +" ldr w8, [x8] \n\t" +" dup z30.s, w4 \n\t" // Broadcast alpha & beta into vectors. +" dup z31.s, w8 \n\t" +" \n\t" +LABEL(PREFETCH_ABNEXT) +" ldr x0, %[a_next] \n\t" +" ldr x1, %[b_next] \n\t" +" prfm PLDL2KEEP, [x0] \n\t" +" prfm PLDL2KEEP, [x0, 256*1] \n\t" +" prfm PLDL2KEEP, [x0, 256*2] \n\t" +" prfm PLDL2KEEP, [x0, 256*3] \n\t" +" prfm PLDL2KEEP, [x0, 256*4] \n\t" +" prfm PLDL2KEEP, [x0, 256*5] \n\t" +" prfm PLDL2KEEP, [x0, 256*6] \n\t" +" prfm PLDL2KEEP, [x0, 256*7] \n\t" +" prfm PLDL2KEEP, [x0, 256*8] \n\t" +" prfm PLDL2KEEP, [x0, 256*9] \n\t" +" prfm PLDL2KEEP, [x0, 256*10] \n\t" +" prfm PLDL2KEEP, [x0, 256*11] \n\t" +" prfm PLDL2KEEP, [x0, 256*12] \n\t" +" prfm PLDL2KEEP, [x0, 256*13] \n\t" +" prfm PLDL2KEEP, [x0, 256*14] \n\t" +" prfm PLDL2KEEP, [x0, 256*15] \n\t" +" prfm PLDL2KEEP, [x1] \n\t" +" prfm PLDL2KEEP, [x1, 256*1] \n\t" +" prfm PLDL2KEEP, [x1, 256*2] \n\t" +" prfm PLDL2KEEP, [x1, 256*3] \n\t" +" prfm PLDL2KEEP, [x1, 256*4] \n\t" +" prfm PLDL2KEEP, [x1, 256*5] \n\t" +" prfm PLDL2KEEP, [x1, 256*6] \n\t" +" prfm PLDL2KEEP, [x1, 256*7] \n\t" +" prfm PLDL2KEEP, [x1, 256*8] \n\t" +" prfm PLDL2KEEP, [x1, 256*9] \n\t" +" \n\t" +LABEL(WRITE_MEM) +" \n\t" +" fmov s28, #1.0 \n\t" +" fmov w16, s28 \n\t" +" cmp w16, w4 \n\t" +BEQ(UNIT_ALPHA) +" \n\t" +SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19,z30) +" \n\t" +LABEL(UNIT_ALPHA) +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +// " cmp x6, #1 \n\t" +// BNE(WRITE_MEM_G) +" \n\t" +LABEL(WRITE_MEM_C) +" \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-29]. +" fcmp s31, #0.0 \n\t" +BEQ(BETA_ZERO_C) +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7) +GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31) +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7) +GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31) +" \n\t" +LABEL(BETA_ZERO_C) +GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p1,x5,x7) +GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p1,x5,x7) +// BRANCH(END_WRITE_MEM) +// " \n\t" +// LABEL(END_WRITE_MEM) +// BRANCH(END_EXEC) +// " \n\t" +// LABEL(END_ERROR) +// " mov x0, #1 \n\t" // Return error. +LABEL(END_EXEC) +" mov x0, #0 \n\t" // Return normal. +: +: [m] "m" (m), + [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x16", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); + + GEMM_UKR_FLUSH_CT( s ); +} + diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c new file mode 100644 index 0000000000..8636a527ba --- /dev/null +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c @@ -0,0 +1,320 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Double-precision composite instructions. +#include "armsve_asm_macros_dcomplex.h" + +// 2vx10 microkernels. +#include "armsve_asm_2vx10cmplx.h" + + +void bli_zgemm_armsve_asm_2vx10_unindexed + ( + dim_t m, + dim_t n, + dim_t k, + dcomplex* restrict alpha, + dcomplex* restrict a, + dcomplex* restrict b, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t info = 0; + + GEMM_UKR_SETUP_CT( z, m, 10, false ); + + __asm__ volatile ( +" whilelo p0.d, xzr, %12 \n\t" +// " ldr x0, %[a] \n\t" +// " ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" incd x2, ALL, MUL #1 \n\t" // Column-skip of A. +" mov x3, #10 \n\t" // Row-skip of B. +" \n\t" +// " ldr x2, %[c] \n\t" +// " ldr x3, %[rs_c] \n\t" // Row-skip of C. +// " ldr x4, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %0, %0, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %1, %1, x16 \n\t" +" mov x16, 0x3 \n\t" // Tag C address. +" lsl x16, x16, #56 \n\t" +" orr %2, %2, x16 \n\t" +#endif +" \n\t" +" mov x16, #16 \n\t" // Multiply some address skips by sizeof(dcomplex). +" madd x2, x16, x2, xzr \n\t" // cs_a +" madd x3, x16, x3, xzr \n\t" // rs_b +" madd %4, x16, %4, xzr \n\t" // cs_c +" \n\t" +// " ldr x5, %[k_mker] \n\t" // Number of loops. +// " ldr x6, %[k_left] \n\t" +" \n\t" +LABEL(LOAD_ABC) +" cmp %5, #0 \n\t" // Don't preload if no microkernel there. +BEQ(END_CCOL_PRFM) +" \n\t" +" ld1rd z20.d, p0/z, [%1, 8*0] \n\t" // Load B's real 8/10, no imaginary. +" ld1rd z21.d, p0/z, [%1, 8*2] \n\t" +" ld1rd z22.d, p0/z, [%1, 8*4] \n\t" +" ld1rd z23.d, p0/z, [%1, 8*6] \n\t" +" ld1rd z24.d, p0/z, [%1, 8*8] \n\t" +" ld1rd z25.d, p0/z, [%1, 8*10] \n\t" +" ld1rd z26.d, p0/z, [%1, 8*12] \n\t" +" ld1rd z27.d, p0/z, [%1, 8*14] \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +" \n\t" +LABEL(CCOL_PRFM) +// " cmp %3, #1 \n\t" +// BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage. +" mov x16, %2 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +LABEL(END_CCOL_PRFM) +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp %5, #0 \n\t" // If no 4-microkernel can be applied. +BEQ(K_LEFT_LOOP) +" \n\t" +LABEL(K_MKER_LOOP) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z30,z31,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z30,z31,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +" subs %5, %5, #1 \n\t" // Decrease counter before final replica. +BEQ(FIN_MKER_LOOP) // Branch early to avoid reading excess mem. +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +BRANCH(K_MKER_LOOP) +" \n\t" +LABEL(FIN_MKER_LOOP) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +LABEL(K_LEFT_LOOP) +" cmp %6, #0 \n\t" // End of execution. +BEQ(WRITE_MEM_PREP) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +" ld1rd z20.d, p0/z, [%1, 8*0] \n\t" // Load B's real 8/10, no imaginary. +" ld1rd z21.d, p0/z, [%1, 8*2] \n\t" +" ld1rd z22.d, p0/z, [%1, 8*4] \n\t" +" ld1rd z23.d, p0/z, [%1, 8*6] \n\t" +" ld1rd z24.d, p0/z, [%1, 8*8] \n\t" +" ld1rd z25.d, p0/z, [%1, 8*10] \n\t" +" ld1rd z26.d, p0/z, [%1, 8*12] \n\t" +" ld1rd z27.d, p0/z, [%1, 8*14] \n\t" +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" sub %6, %6, #1 \n\t" +BRANCH(K_LEFT_LOOP) +" \n\t" +LABEL(WRITE_MEM_PREP) +" \n\t" +// " ldr x7, %[alpha] \n\t" // Load alpha & beta (address). +// " ldr x8, %[beta] \n\t" +" ld1rd z28.d, p0/z, [%7] \n\t" // Real(alpha). +" ld1rd z29.d, p0/z, [%7, 8] \n\t" // Imag(alpha). +" ld1rd z30.d, p0/z, [%8] \n\t" // Real(beta). +" ld1rd z31.d, p0/z, [%8, 8] \n\t" // Imag(beta). +" \n\t" +LABEL(PREFETCH_ABNEXT) +// " ldr x9, %[a_next] \n\t" +// " ldr x10, %[b_next] \n\t" +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %9, %9, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %10, %10, x16 \n\t" +#endif +" prfm PLDL1STRM, [%9] \n\t" +" prfm PLDL1STRM, [%9, 256*1] \n\t" +" prfm PLDL1STRM, [%10] \n\t" +" prfm PLDL1STRM, [%10, 256*1] \n\t" +" \n\t" +LABEL(WRITE_MEM) +" fmov d27, #1.0 \n\t" +" fcmp d29, #0.0 \n\t" // Whether Imag(alpha) == 0. +" fccmp d28, d27, 0, eq \n\t" // Whether Real(alpha) == 1. +BEQ(UNIT_ALPHA) +" \n\t" +GEMM_FMULCMPLX_COL2(z20,z21,z22,z23,p0,z0 ,z1 ,z2 ,z3 ,z28,z29) +GEMM_FMULCMPLX_COL2(z24,z25,z26,z27,p0,z4 ,z5 ,z6 ,z7 ,z28,z29) +GEMM_FMULCMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z8, z9, z10,z11,z28,z29) +GEMM_FMULCMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z12,z13,z14,z15,z28,z29) +GEMM_FMULCMPLX_COL2(z8 ,z9 ,z10,z11,p0,z16,z17,z18,z19,z28,z29) +BRANCH(WRITE_MEM_EXEC) +" \n\t" +LABEL(UNIT_ALPHA) +MOV_COL2(z20,z21,z22,z23,z0 ,z1 ,z2 ,z3 ) +MOV_COL2(z24,z25,z26,z27,z4 ,z5 ,z6 ,z7 ) +MOV_COL2(z0 ,z1 ,z2 ,z3 ,z8, z9, z10,z11) +MOV_COL2(z4 ,z5 ,z6 ,z7 ,z12,z13,z14,z15) +MOV_COL2(z8 ,z9 ,z10,z11,z16,z17,z18,z19) +" \n\t" +LABEL(WRITE_MEM_EXEC) +" mov x9, %2 \n\t" // C address for loading. +" \n\t" // C address for storing is %2 itself. +// " cmp %3, #1 \n\t" +// BNE(WRITE_MEM_G) +" \n\t" +LABEL(WRITE_MEM_C) +" fmov d29, xzr \n\t" +" fcmp d31, #0.0 \n\t" // Whether Imag(beta) == 0. +" fccmp d30, d29, 0, eq \n\t" // Whether Real(beta) == 0. +BEQ(ZERO_BETA_C_0_1_2_3) +GEMM_CCMPLX_LOAD_COL2_C(z12,z13,z14,z15,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z16,z17,z18,z19,p0,x9,%4) +GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z30,z31) +GEMM_FMLACMPLX_COL2(z24,z25,z26,z27,p0,z16,z17,z18,z19,z30,z31) +LABEL(ZERO_BETA_C_0_1_2_3) +GEMM_CCMPLX_STORE_COL2_C(z20,z21,z22,z23,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z24,z25,z26,z27,p0,%2,%4) +" \n\t" +BEQ(ZERO_BETA_C_4_5_6_7_8_9) +GEMM_CCMPLX_LOAD_COL2_C(z12,z13,z14,z15,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z16,z17,z18,z19,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z20,z21,z22,z23,p0,x9,%4) +GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z12,z13,z14,z15,z30,z31) +GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z16,z17,z18,z19,z30,z31) +GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z20,z21,z22,z23,z30,z31) +LABEL(ZERO_BETA_C_4_5_6_7_8_9) +GEMM_CCMPLX_STORE_COL2_C(z0 ,z1 ,z2 ,z3 ,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z4 ,z5 ,z6 ,z7 ,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z8 ,z9 ,z10,z11,p0,%2,%4) +// BRANCH(END_WRITE_MEM) +// " \n\t" +// LABEL(WRITE_MEM_G) +// " add %3, %3, %3 \n\t" // Skips passed to index is multiplied by 2, +// " index z28.d, xzr, %3 \n\t" // s.t. 2*sizeof(double) = 2*8 = 16. +// " fmov d29, xzr \n\t" +// " fcmp d31, #0.0 \n\t" // Whether Imag(beta) == 0. +// " fccmp d30, d29, 0, eq \n\t" // Whether Real(beta) == 0. +// BEQ(ZERO_BETA_G_0_1_2_3) +// GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16) +// GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16) +// GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z30,z31) +// GEMM_FMLACMPLX_COL2(z24,z25,z26,z27,p0,z16,z17,z18,z19,z30,z31) +// LABEL(ZERO_BETA_G_0_1_2_3) +// GEMM_CCMPLX_STORE_COL2_G(z20,z21,z22,z23,p0,z28,%2,%4,x16) +// GEMM_CCMPLX_STORE_COL2_G(z24,z25,z26,z27,p0,z28,%2,%4,x16) +// " \n\t" +// BEQ(ZERO_BETA_G_4_5_6_7_8_9) +// GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16) +// GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16) +// GEMM_CCMPLX_LOAD_COL2_G(z20,z21,z22,z23,p0,z28,x9,%4,x16) +// GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z12,z13,z14,z15,z30,z31) +// GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z16,z17,z18,z19,z30,z31) +// GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z20,z21,z22,z23,z30,z31) +// LABEL(ZERO_BETA_G_4_5_6_7_8_9) +// GEMM_CCMPLX_STORE_COL2_G(z0 ,z1 ,z2 ,z3 ,p0,z28,%2,%4,x16) +// GEMM_CCMPLX_STORE_COL2_G(z4 ,z5 ,z6 ,z7 ,p0,z28,%2,%4,x16) +// GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16) +// " \n\t" +// LABEL(END_WRITE_MEM) +// BRANCH(END_EXEC) +// " \n\t" +LABEL(END_EXEC) +" mov %11, #0 \n\t" // Return normal. +: "+r" (a), // %0 + "+r" (b), // %1 + "+r" (c), // %2 + "+r" (rs_c), // %3 + "+r" (cs_c), // %4 + "+r" (k_mker), // %5 + "+r" (k_left), // %6 + "+r" (alpha), // %7 + "+r" (beta), // %8 + "+r" (a_next), // %9 + "+r" (b_next), // %10 + "=r" (info) // %11 +: "r" (m) // %12 +: "x2","x3","x9","x16", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); + + GEMM_UKR_FLUSH_CT( z ); +} + diff --git a/kernels/armsve/3/old/armsve_asm_2vx7cmplx.h b/kernels/armsve/3/old/armsve_asm_2vx7cmplx.h new file mode 100644 index 0000000000..43997deef4 --- /dev/null +++ b/kernels/armsve/3/old/armsve_asm_2vx7cmplx.h @@ -0,0 +1,135 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#define GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,PT,AColRe,AColIm,B0Re,B1Re,B2Re,B3Re,B4Re,B5Re,B6Re,B0Im,B1Im,B2Im,B3Im,B4Im,B5Im,B6Im,BAddr,BRSBit) \ + GEMM_FMLA2_LD1R(C0Re,C0Im,PT,AColRe,AColIm,B0Re,BAddr,0) \ + GEMM_FMLA2_LD1R(C1Re,C1Im,PT,AColRe,AColIm,B1Re,BAddr,2) \ + GEMM_FMLA2_LD1R(C2Re,C2Im,PT,AColRe,AColIm,B2Re,BAddr,4) \ + GEMM_FMLA2_LD1R(C3Re,C3Im,PT,AColRe,AColIm,B3Re,BAddr,6) \ + GEMM_FMLA2_LD1R(C4Re,C4Im,PT,AColRe,AColIm,B4Re,BAddr,8) \ + GEMM_FMLA2_LD1R(C5Re,C5Im,PT,AColRe,AColIm,B5Re,BAddr,10) \ + GEMM_FMLA2_LD1R(C6Re,C6Im,PT,AColRe,AColIm,B6Re,BAddr,12) \ + GEMM_FMLX2_LD1R(C0Im,C0Re,PT,AColRe,AColIm,B0Im,BAddr,1) \ + GEMM_FMLX2_LD1R(C1Im,C1Re,PT,AColRe,AColIm,B1Im,BAddr,3) \ + GEMM_FMLX2_LD1R(C2Im,C2Re,PT,AColRe,AColIm,B2Im,BAddr,5) \ + GEMM_FMLX2_LD1R(C3Im,C3Re,PT,AColRe,AColIm,B3Im,BAddr,7) \ + GEMM_FMLX2_LD1R(C4Im,C4Re,PT,AColRe,AColIm,B4Im,BAddr,9) \ + GEMM_FMLX2_LD1R(C5Im,C5Re,PT,AColRe,AColIm,B5Im,BAddr,11) \ + GEMM_FMLX2_LD1R(C6Im,C6Re,PT,AColRe,AColIm,B6Im,BAddr,13) \ +" add "#BAddr", "#BRSBit", "#BAddr" \n\t" + +#define GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,PT,AColRe,AColIm,B0Re,B1Re,B2Re,B3Re,B4Re,B5Re,B6Re,B0Im,B1Im,B2Im,B3Im,B4Im,B5Im,B6Im,BAddr,BRSBit) \ + GEMM_FMLA2(C0Re,C0Im,PT,AColRe,AColIm,B0Re) \ + GEMM_FMLA2(C1Re,C1Im,PT,AColRe,AColIm,B1Re) \ + GEMM_FMLA2(C2Re,C2Im,PT,AColRe,AColIm,B2Re) \ + GEMM_FMLA2(C3Re,C3Im,PT,AColRe,AColIm,B3Re) \ + GEMM_FMLA2(C4Re,C4Im,PT,AColRe,AColIm,B4Re) \ + GEMM_FMLA2(C5Re,C5Im,PT,AColRe,AColIm,B5Re) \ + GEMM_FMLA2(C6Re,C6Im,PT,AColRe,AColIm,B6Re) \ + GEMM_FMLX2(C0Im,C0Re,PT,AColRe,AColIm,B0Im) \ + GEMM_FMLX2(C1Im,C1Re,PT,AColRe,AColIm,B1Im) \ + GEMM_FMLX2(C2Im,C2Re,PT,AColRe,AColIm,B2Im) \ + GEMM_FMLX2(C3Im,C3Re,PT,AColRe,AColIm,B3Im) \ + GEMM_FMLX2(C4Im,C4Re,PT,AColRe,AColIm,B4Im) \ + GEMM_FMLX2(C5Im,C5Re,PT,AColRe,AColIm,B5Im) \ + GEMM_FMLX2(C6Im,C6Re,PT,AColRe,AColIm,B6Im) + +#define CLEAR_COL14(Z00,Z01,Z02,Z03,Z04,Z05,Z06,Z07,Z08,Z09,Z10,Z11,Z12,Z13) \ + CLEAR_COL4(Z00,Z01,Z02,Z03) \ + CLEAR_COL4(Z04,Z05,Z06,Z07) \ + CLEAR_COL4(Z08,Z09,Z10,Z11) \ + CLEAR_COL2(Z12,Z13) + +#define GEMM_FMULCMPLX_COL7(ZD0Re,ZD0Im,ZD1Re,ZD1Im,ZD2Re,ZD2Im,ZD3Re,ZD3Im,ZD4Re,ZD4Im,ZD5Re,ZD5Im,ZD6Re,ZD6Im,PT,Z0Re,Z0Im,Z1Re,Z1Im,Z2Re,Z2Im,Z3Re,Z3Im,Z4Re,Z4Im,Z5Re,Z5Im,Z6Re,Z6Im,ZFactorRe,ZFactorIm) \ + FMUL_COL2(ZD0Re,ZD0Im,Z0Re,Z0Im,ZFactorRe) \ + FMUL_COL2(ZD1Re,ZD1Im,Z1Re,Z1Im,ZFactorRe) \ + FMUL_COL2(ZD2Re,ZD2Im,Z2Re,Z2Im,ZFactorRe) \ + FMUL_COL2(ZD3Re,ZD3Im,Z3Re,Z3Im,ZFactorRe) \ + FMUL_COL2(ZD4Re,ZD4Im,Z4Re,Z4Im,ZFactorRe) \ + FMUL_COL2(ZD5Re,ZD5Im,Z5Re,Z5Im,ZFactorRe) \ + FMUL_COL2(ZD6Re,ZD6Im,Z6Re,Z6Im,ZFactorRe) \ + GEMM_FMLX2(ZD0Im,ZD0Re,PT,Z0Re,Z0Im,ZFactorIm) \ + GEMM_FMLX2(ZD1Im,ZD1Re,PT,Z1Re,Z1Im,ZFactorIm) \ + GEMM_FMLX2(ZD2Im,ZD2Re,PT,Z2Re,Z2Im,ZFactorIm) \ + GEMM_FMLX2(ZD3Im,ZD3Re,PT,Z3Re,Z3Im,ZFactorIm) \ + GEMM_FMLX2(ZD4Im,ZD4Re,PT,Z4Re,Z4Im,ZFactorIm) \ + GEMM_FMLX2(ZD5Im,ZD5Re,PT,Z5Re,Z5Im,ZFactorIm) \ + GEMM_FMLX2(ZD6Im,ZD6Re,PT,Z6Re,Z6Im,ZFactorIm) + +#define GEMM_FMLACMPLX_COL7(ZD0Re,ZD0Im,ZD1Re,ZD1Im,ZD2Re,ZD2Im,ZD3Re,ZD3Im,ZD4Re,ZD4Im,ZD5Re,ZD5Im,ZD6Re,ZD6Im,PT,Z0Re,Z0Im,Z1Re,Z1Im,Z2Re,Z2Im,Z3Re,Z3Im,Z4Re,Z4Im,Z5Re,Z5Im,Z6Re,Z6Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD0Re,ZD0Im,PT,Z0Re,Z0Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD1Re,ZD1Im,PT,Z1Re,Z1Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD2Re,ZD2Im,PT,Z2Re,Z2Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD3Re,ZD3Im,PT,Z3Re,Z3Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD4Re,ZD4Im,PT,Z4Re,Z4Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD5Re,ZD5Im,PT,Z5Re,Z5Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD6Re,ZD6Im,PT,Z6Re,Z6Im,ZFactorRe,ZFactorIm) + +#define GEMM_CCMPLX_LOAD_COL7_C(Z0Re,Z0Im,Z1Re,Z1Im,Z2Re,Z2Im,Z3Re,Z3Im,Z4Re,Z4Im,Z5Re,Z5Im,Z6Re,Z6Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z0Re,Z0Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z1Re,Z1Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z2Re,Z2Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z3Re,Z3Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z4Re,Z4Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z5Re,Z5Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z6Re,Z6Im,PT,CAddr,CCS) + +#define GEMM_CCMPLX_STORE_COL7_C(Z0Re,Z0Im,Z1Re,Z1Im,Z2Re,Z2Im,Z3Re,Z3Im,Z4Re,Z4Im,Z5Re,Z5Im,Z6Re,Z6Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z0Re,Z0Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z1Re,Z1Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z2Re,Z2Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z3Re,Z3Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z4Re,Z4Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z5Re,Z5Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z6Re,Z6Im,PT,CAddr,CCS) + +#define GEMM_CCMPLX_LOAD_COL7_G(Z0Re,Z0Im,Z1Re,Z1Im,Z2Re,Z2Im,Z3Re,Z3Im,Z4Re,Z4Im,Z5Re,Z5Im,Z6Re,Z6Im,PT,ZIndex,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z0Re,Z0Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z1Re,Z1Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z2Re,Z2Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z3Re,Z3Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z4Re,Z4Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z5Re,Z5Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z6Re,Z6Im,ZIndex,PT,PT,CAddr,CCS,CTemp) + +#define GEMM_CCMPLX_STORE_COL7_G(Z0Re,Z0Im,Z1Re,Z1Im,Z2Re,Z2Im,Z3Re,Z3Im,Z4Re,Z4Im,Z5Re,Z5Im,Z6Re,Z6Im,PT,ZIndex,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z0Re,Z0Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z1Re,Z1Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z2Re,Z2Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z3Re,Z3Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z4Re,Z4Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z5Re,Z5Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z6Re,Z6Im,ZIndex,PT,PT,CAddr,CCS,CTemp) + diff --git a/kernels/armsve/3/old/armsve_asm_2vx8cmplx.h b/kernels/armsve/3/old/armsve_asm_2vx8cmplx.h new file mode 100644 index 0000000000..16711930a4 --- /dev/null +++ b/kernels/armsve/3/old/armsve_asm_2vx8cmplx.h @@ -0,0 +1,116 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#define GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BV8,BV9,BV10,BV11,BAddr,BRSBit) \ + GEMM_FMLA2_LD1R(C0Re,C0Im,PT,AColRe,AColIm,BV0,BAddr,9) \ + GEMM_FMLA2_LD1R(C1Re,C1Im,PT,AColRe,AColIm,BV1,BAddr,11) \ + GEMM_FMLA2_LD1R(C2Re,C2Im,PT,AColRe,AColIm,BV2,BAddr,13) \ + GEMM_FMLA2_LD1R(C3Re,C3Im,PT,AColRe,AColIm,BV3,BAddr,15) \ +" add "#BAddr", "#BRSBit", "#BAddr" \n\t" /* B address forward */ \ + GEMM_FMLA2_LD1R(C4Re,C4Im,PT,AColRe,AColIm,BV4,BAddr,0) \ + GEMM_FMLA2_LD1R(C5Re,C5Im,PT,AColRe,AColIm,BV5,BAddr,2) \ + GEMM_FMLA2_LD1R(C6Re,C6Im,PT,AColRe,AColIm,BV6,BAddr,4) \ + GEMM_FMLA2_LD1R(C7Re,C7Im,PT,AColRe,AColIm,BV7,BAddr,6) \ + \ + GEMM_FMLX2_LD1R(C0Im,C0Re,PT,AColRe,AColIm,BV8,BAddr,8) \ + GEMM_FMLX2_LD1R(C1Im,C1Re,PT,AColRe,AColIm,BV9,BAddr,10) \ + GEMM_FMLX2_LD1R(C2Im,C2Re,PT,AColRe,AColIm,BV10,BAddr,12) \ + GEMM_FMLX2_LD1R(C3Im,C3Re,PT,AColRe,AColIm,BV11,BAddr,14) \ + GEMM_FMLX2_LD1R(C4Im,C4Re,PT,AColRe,AColIm,BV0,BAddr,1) \ + GEMM_FMLX2_LD1R(C5Im,C5Re,PT,AColRe,AColIm,BV1,BAddr,3) \ + GEMM_FMLX2_LD1R(C6Im,C6Re,PT,AColRe,AColIm,BV2,BAddr,5) \ + GEMM_FMLX2_LD1R(C7Im,C7Re,PT,AColRe,AColIm,BV3,BAddr,7) + +#define GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_2(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BV8,BV9,BV10,BV11,BAddr,BRSBit) \ + GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV4,BV5,BV6,BV7,BV8,BV9,BV10,BV11,BV0,BV1,BV2,BV3,BAddr,BRSBit) + +#define GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_3(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BV8,BV9,BV10,BV11,BAddr,BRSBit) \ + GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV8,BV9,BV10,BV11,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BAddr,BRSBit) + +#define GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BV8,BV9,BV10,BV11,BAddr,BRSBit) \ + GEMM_FMLA2_LD1R(C0Re,C0Im,PT,AColRe,AColIm,BV0,BAddr,9) \ + GEMM_FMLA2_LD1R(C1Re,C1Im,PT,AColRe,AColIm,BV1,BAddr,11) \ + GEMM_FMLA2_LD1R(C2Re,C2Im,PT,AColRe,AColIm,BV2,BAddr,13) \ + GEMM_FMLA2_LD1R(C3Re,C3Im,PT,AColRe,AColIm,BV3,BAddr,15) \ +" add "#BAddr", "#BRSBit", "#BAddr" \n\t" /* B address forward */ \ + GEMM_FMLA2(C4Re,C4Im,PT,AColRe,AColIm,BV4) \ + GEMM_FMLA2(C5Re,C5Im,PT,AColRe,AColIm,BV5) \ + GEMM_FMLA2(C6Re,C6Im,PT,AColRe,AColIm,BV6) \ + GEMM_FMLA2(C7Re,C7Im,PT,AColRe,AColIm,BV7) \ + \ + GEMM_FMLX2(C0Im,C0Re,PT,AColRe,AColIm,BV8) \ + GEMM_FMLX2(C1Im,C1Re,PT,AColRe,AColIm,BV9) \ + GEMM_FMLX2(C2Im,C2Re,PT,AColRe,AColIm,BV10) \ + GEMM_FMLX2(C3Im,C3Re,PT,AColRe,AColIm,BV11) \ + GEMM_FMLX2(C4Im,C4Re,PT,AColRe,AColIm,BV0) \ + GEMM_FMLX2(C5Im,C5Re,PT,AColRe,AColIm,BV1) \ + GEMM_FMLX2(C6Im,C6Re,PT,AColRe,AColIm,BV2) \ + GEMM_FMLX2(C7Im,C7Re,PT,AColRe,AColIm,BV3) + +#define GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_3_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BV8,BV9,BV10,BV11,BAddr,BRSBit) \ + GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV8,BV9,BV10,BV11,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BAddr,BRSBit) + +#define CLEAR_COL16(Z00,Z01,Z02,Z03,Z04,Z05,Z06,Z07,Z08,Z09,Z10,Z11,Z12,Z13,Z14,Z15) \ + CLEAR_COL4(Z00,Z01,Z02,Z03) \ + CLEAR_COL4(Z04,Z05,Z06,Z07) \ + CLEAR_COL4(Z08,Z09,Z10,Z11) \ + CLEAR_COL4(Z12,Z13,Z14,Z15) + +#define GEMM_FMULCMPLX_COL2(ZD0Re,ZD0Im,ZD1Re,ZD1Im,PT,Z0Re,Z0Im,Z1Re,Z1Im,ZFactorRe,ZFactorIm) \ + FMUL_COL2(ZD0Re,ZD0Im,Z0Re,Z0Im,ZFactorRe) \ + FMUL_COL2(ZD1Re,ZD1Im,Z1Re,Z1Im,ZFactorRe) \ + GEMM_FMLX2(ZD0Im,ZD0Re,PT,Z0Re,Z0Im,ZFactorIm) \ + GEMM_FMLX2(ZD1Im,ZD1Re,PT,Z1Re,Z1Im,ZFactorIm) + +#define GEMM_FMLACMPLX_COL2(ZD0Re,ZD0Im,ZD1Re,ZD1Im,PT,Z0Re,Z0Im,Z1Re,Z1Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD0Re,ZD0Im,PT,Z0Re,Z0Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD1Re,ZD1Im,PT,Z1Re,Z1Im,ZFactorRe,ZFactorIm) + +#define GEMM_CCMPLX_LOAD_COL2_C(Z0Re,Z0Im,Z1Re,Z1Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z0Re,Z0Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z1Re,Z1Im,PT,CAddr,CCS) + +#define GEMM_CCMPLX_STORE_COL2_C(Z0Re,Z0Im,Z1Re,Z1Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z0Re,Z0Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z1Re,Z1Im,PT,CAddr,CCS) + +#define GEMM_CCMPLX_LOAD_COL2_G(Z0Re,Z0Im,Z1Re,Z1Im,PT,ZIndex,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z0Re,Z0Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z1Re,Z1Im,ZIndex,PT,PT,CAddr,CCS,CTemp) + +#define GEMM_CCMPLX_STORE_COL2_G(Z0Re,Z0Im,Z1Re,Z1Im,PT,ZIndex,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z0Re,Z0Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z1Re,Z1Im,ZIndex,PT,PT,CAddr,CCS,CTemp) + diff --git a/kernels/armsve/3/old/armsve_asm_macros_half.h b/kernels/armsve/3/old/armsve_asm_macros_half.h new file mode 100644 index 0000000000..9a46763ef2 --- /dev/null +++ b/kernels/armsve/3/old/armsve_asm_macros_half.h @@ -0,0 +1,46 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +// Specify to use half precision. +#define DT "h" +#define LD1 "ld1h" +#define ST1 "st1h" +#define LD1R "ld1rh" +#define PRFG "prfh" +#define SZ "2" +// #define OFFS UNSUPPORTED +// Include macros. +#include "armsve_asm_macros.h" + diff --git a/kernels/armsve/3/old/bli_gemm_armsve256_asm_d8x8.c b/kernels/armsve/3/old/bli_gemm_armsve256_asm_d8x8.c new file mode 100644 index 0000000000..01bb644b12 --- /dev/null +++ b/kernels/armsve/3/old/bli_gemm_armsve256_asm_d8x8.c @@ -0,0 +1,809 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Linaro Limited + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" + +/* + o 8x8 Double precision micro-kernel + o Runnable on ARMv8a with SVE 256 feature, compiled with aarch64 GCC. + o Tested on qemu-aarch64 and armie for SVE. + + Preconditions: + - to use this kernel, SVE with vector length of 256 bits is a must. + + April 2020. +*/ +void bli_dgemm_armsve256_asm_8x8 + ( + dim_t k0, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +__asm__ volatile +( +" \n\t" +" ldr x0,%[aaddr] \n\t" // Load address of A +" ldr x1,%[baddr] \n\t" // Load address of B +" ldr x2,%[caddr] \n\t" // Load address of C +" \n\t" +" ldr x3,%[a_next] \n\t" // Move pointer +" ldr x4,%[b_next] \n\t" // Move pointer +" \n\t" +" ldr x5,%[k_iter] \n\t" // Init guard (k_iter) +" ldr x6,%[k_left] \n\t" // Init guard (k_iter) +" \n\t" +" ldr x7,%[alpha] \n\t" // Alpha address +" ldr x8,%[beta] \n\t" // Beta address +" \n\t" +" ldr x9,%[cs_c] \n\t" // Load cs_c +" lsl x10,x9,#3 \n\t" // cs_c * sizeof(double) +" \n\t" +" ldr x13,%[rs_c] \n\t" // Load rs_c. +" lsl x14,x13,#3 \n\t" // rs_c * sizeof(double). +" \n\t" +" add x20,x2,x10 \n\t" //Load address Column 1 of C +" add x21,x20,x10 \n\t" //Load address Column 2 of C +" add x22,x21,x10 \n\t" //Load address Column 3 of C +" add x23,x22,x10 \n\t" //Load address Column 4 of C +" add x24,x23,x10 \n\t" //Load address Column 5 of C +" add x25,x24,x10 \n\t" //Load address Column 6 of C +" add x26,x25,x10 \n\t" //Load address Column 7 of C +" \n\t" +" prfm pldl1keep,[x2] \n\t" // Prefetch c. +" prfm pldl1keep,[x20] \n\t" // Prefetch c. +" prfm pldl1keep,[x21] \n\t" // Prefetch c. +" prfm pldl1keep,[x22] \n\t" // Prefetch c. +" prfm pldl1keep,[x23] \n\t" // Prefetch c. +" prfm pldl1keep,[x24] \n\t" // Prefetch c. +" prfm pldl1keep,[x25] \n\t" // Prefetch c. +" prfm pldl1keep,[x26] \n\t" // Prefetch c. +" \n\t" +" ldr z0, [x0] \n\t" // Load a +" ldr z1, [x0, #1, MUL VL] \n\t" +" \n\t" +" ptrue p0.d, all \n\t" +" ld1rqd {z2.d}, p0/z, [x1] \n\t" // load b( l,0:1 ) +" ld1rqd {z3.d}, p0/z, [x1, #16] \n\t" // load b( l,2:3 ) +" ld1rqd {z4.d}, p0/z, [x1, #32] \n\t" // load b( l,4:5 ) +" ld1rqd {z5.d}, p0/z, [x1, #48] \n\t" // load b( l,6:7 ) +" \n\t" +" \n\t" // PRFM, the following prefetch on [x1] and [x0] +" \n\t" // is for b rows 4..7 and a columns 4..7. +" \n\t" // both of them will be used in next iteration +" \n\t" // of k_iter (unrolled per 4 loops) +" \n\t" +" dup z16.d, #0 \n\t" // Vector for accummulating column 0 +" prfm PLDL1KEEP, [x1, #256] \n\t" // prefetch b row no.4 +" dup z17.d, #0 \n\t" // Vector for accummulating column 0 +" prfm PLDL1KEEP, [x1, #320] \n\t" // prefetch b row no.5 +" dup z18.d, #0 \n\t" // Vector for accummulating column 1 +" prfm PLDL1KEEP, [x1, #384] \n\t" // prefetch b row no.6 +" dup z19.d, #0 \n\t" // Vector for accummulating column 1 +" prfm PLDL1KEEP, [x1, #448] \n\t" // preftech b row no.7 +" dup z20.d, #0 \n\t" // Vector for accummulating column 2 +" dup z21.d, #0 \n\t" // Vector for accummulating column 2 +" \n\t" +" dup z22.d, #0 \n\t" // Vector for accummulating column 3 +" prfm PLDL1KEEP, [x0, #256] \n\t" // prefetch a col. no.4 +" dup z23.d, #0 \n\t" // Vector for accummulating column 3 +" prfm PLDL1KEEP, [x0, #320] \n\t" // prefetch a col. no.5 +" dup z24.d, #0 \n\t" // Vector for accummulating column 4 +" prfm PLDL1KEEP, [x0, #384] \n\t" // prefetch a col. no.6 +" dup z25.d, #0 \n\t" // Vector for accummulating column 4 +" prfm PLDL1KEEP, [x0, #448] \n\t" // prefetch a col. no.7 +" dup z26.d, #0 \n\t" // Vector for accummulating column 5 +" dup z27.d, #0 \n\t" // Vector for accummulating column 5 +" \n\t" +" dup z28.d, #0 \n\t" // Vector for accummulating column 6 +" dup z29.d, #0 \n\t" // Vector for accummulating column 6 +" dup z30.d, #0 \n\t" // Vector for accummulating column 7 +" dup z31.d, #0 \n\t" // Vector for accummulating column 7 +" \n\t" +" \n\t" +" cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. +" beq .DCONSIDERKLEFT \n\t" +" \n\t" +" add x0, x0, #64 \n\t" //update address of A +" add x1, x1, #64 \n\t" //update address of B +" \n\t" +" cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. +" beq .DLASTITER \n\t" // (as loop is do-while-like). +" \n\t" +" DLOOP: \n\t" // Body +" \n\t" +" fmla z16.d, z0.d, z2.d[0] \n\t" // Accummulate c(0:3,0)+=a(0:3,l)*b(l,0) +" prfm PLDL1KEEP, [x1, #448] \n\t" // prefetch b row no.8, 512-64=448 +" fmla z17.d, z1.d, z2.d[0] \n\t" // Accummulate c(4:7,0)+=a(4:7,l)*b(l,0) +" prfm PLDL1KEEP, [x1, #512] \n\t" // prefetch b row no.9 +" fmla z18.d, z0.d, z2.d[1] \n\t" // Accummulate c(0:3,1)+=a(0:3,l)*b(l,1) +" prfm PLDL1KEEP, [x1, #576] \n\t" // prefetch b row no.10 +" \n\t" +" fmla z19.d, z1.d, z2.d[1] \n\t" // Accummulate c(4:7,1)+=a(4:7,l)*b(l,1) +" fmla z20.d, z0.d, z3.d[0] \n\t" // Accummulate c(0:3,2)+=a(0:3,l)*b(l,2) +" ldr z6, [x0] \n\t" // Load a( 0:3,l ) +" \n\t" +" fmla z21.d, z1.d, z3.d[0] \n\t" // Accummulate c(4:7,2)+=a(4:7,l)*b(l,2) +" fmla z22.d, z0.d, z3.d[1] \n\t" // Accummulate c(0:3,3)+=a(0:3,l)*b(l,3) +" ldr z7, [x0, #1, MUL VL] \n\t" // load a( 4:7,l ) +" \n\t" +" fmla z23.d, z1.d, z3.d[1] \n\t" // Accummulate c(4:7,3)+=a(4:7,l)*b(l,3) +" fmla z24.d, z0.d, z4.d[0] \n\t" // Accummulate c(0:3,4)+=a(0:3,l)*b(l,4) +" ld1rqd {z2.d}, p0/z, [x1] \n\t" // load b( l,0:1 ) +" \n\t" +" fmla z25.d, z1.d, z4.d[0] \n\t" // Accummulate c(4:7,4)+=a(4:7,l)*b(l,4) +" fmla z26.d, z0.d, z4.d[1] \n\t" // Accummulate c(0:3,5)+=a(0:3,l)*b(l,5) +" fmla z27.d, z1.d, z4.d[1] \n\t" // Accummulate c(4:7,5)+=a(0:3,l)*b(l,5) +" ld1rqd {z3.d}, p0/z, [x1, #16] \n\t" // load b( l,2:3 ) +" \n\t" +" fmla z28.d, z0.d, z5.d[0] \n\t" // Accummulate c(0:3,6)+=a(0:3,l)*b(l,6) +" fmla z29.d, z1.d, z5.d[0] \n\t" // Accummulate c(4:7,6)+=a(0:3,l)*b(l,6) +" ld1rqd {z4.d}, p0/z, [x1, #32] \n\t" // load b( l,4:5 ) +" \n\t" +" fmla z30.d, z0.d, z5.d[1] \n\t" // Accummulate c(0:3,7)+=a(0:3,l)*b(l,7) +" fmla z31.d, z1.d, z5.d[1] \n\t" // Accummulate c(4:7,7)+=a(0:3,l)*b(l,7) +" ld1rqd {z5.d}, p0/z, [x1, #48] \n\t" // load b( l,6:7 ) +" \n\t" +" \n\t" // End it 1 +" \n\t" +" fmla z16.d, z6.d, z2.d[0] \n\t" // Accummulate c(0:3,0)+=a(0:3,l)*b(l,0) +" prfm PLDL1KEEP, [x1, #640] \n\t" // prefetch b row no.11 +" fmla z17.d, z7.d, z2.d[0] \n\t" // Accummulate c(4:7,0)+=a(4:7,l)*b(l,0) +" prfm PLDL1KEEP, [x0, #448] \n\t" // prefetch a col. no.8 +" fmla z18.d, z6.d, z2.d[1] \n\t" // Accummulate c(0:3,1)+=a(0:3,l)*b(l,1) +" prfm PLDL1KEEP, [x0, #512] \n\t" // prefetch a col. no.9 +" \n\t" +" fmla z19.d, z7.d, z2.d[1] \n\t" // Accummulate c(4:7,1)+=a(4:7,l)*b(l,1) +" fmla z20.d, z6.d, z3.d[0] \n\t" // Accummulate c(0:3,2)+=a(0:3,l)*b(l,2) +" ldr z0, [x0, #2, MUL VL] \n\t" // Load a( 0:3,l ) +" \n\t" +" fmla z21.d, z7.d, z3.d[0] \n\t" // Accummulate c(4:7,2)+=a(4:7,l)*b(l,2) +" fmla z22.d, z6.d, z3.d[1] \n\t" // Accummulate c(0:3,3)+=a(0:3,l)*b(l,3) +" ldr z1, [x0, #3, MUL VL] \n\t" // load a( 4:7,l ) +" \n\t" +" fmla z23.d, z7.d, z3.d[1] \n\t" // Accummulate c(4:7,3)+=a(4:7,l)*b(l,3) +" fmla z24.d, z6.d, z4.d[0] \n\t" // Accummulate c(0:3,4)+=a(0:3,l)*b(l,4) +" ld1rqd {z2.d}, p0/z, [x1, #64] \n\t" // load b( l,0:1 ) +" \n\t" +" fmla z25.d, z7.d, z4.d[0] \n\t" // Accummulate c(4:7,4)+=a(4:7,l)*b(l,4) +" fmla z26.d, z6.d, z4.d[1] \n\t" // Accummulate c(0:3,5)+=a(0:3,l)*b(l,5) +" fmla z27.d, z7.d, z4.d[1] \n\t" // Accummulate c(4:7,5)+=a(0:3,l)*b(l,5) +" ld1rqd {z3.d}, p0/z, [x1, #80] \n\t" // load b( l,2:3 ) +" \n\t" +" fmla z28.d, z6.d, z5.d[0] \n\t" // Accummulate c(0:3,6)+=a(0:3,l)*b(l,6) +" fmla z29.d, z7.d, z5.d[0] \n\t" // Accummulate c(4:7,6)+=a(0:3,l)*b(l,6) +" ld1rqd {z4.d}, p0/z, [x1, #96] \n\t" // load b( l,4:5 ) +" \n\t" +" fmla z30.d, z6.d, z5.d[1] \n\t" // Accummulate c(0:3,7)+=a(0:3,l)*b(l,7) +" fmla z31.d, z7.d, z5.d[1] \n\t" // Accummulate c(4:7,7)+=a(0:3,l)*b(l,7) +" ld1rqd {z5.d}, p0/z, [x1, #112] \n\t" // load b( l,6:7 ) +" \n\t" +" \n\t" +" \n\t" //End it 2 +" \n\t" +" fmla z16.d, z0.d, z2.d[0] \n\t" // Accummulate c(0:3,0)+=a(0:3,l)*b(l,0) +" prfm PLDL1KEEP, [x0, #576] \n\t" // prefetch a col. no.10 +" fmla z17.d, z1.d, z2.d[0] \n\t" // Accummulate c(4:7,0)+=a(4:7,l)*b(l,0) +" prfm PLDL1KEEP, [x0, #640] \n\t" // prefetch a col. no.11 +" \n\t" +" fmla z18.d, z0.d, z2.d[1] \n\t" // Accummulate c(0:3,1)+=a(0:3,l)*b(l,1) +" \n\t" +" add x1, x1, #128 \n\t" // because immediate in 'ldr1rqd' must be +" \n\t" // in range -128 to 112 +" \n\t" +" fmla z19.d, z1.d, z2.d[1] \n\t" // Accummulate c(4:7,1)+=a(4:7,l)*b(l,1) +" fmla z20.d, z0.d, z3.d[0] \n\t" // Accummulate c(0:3,2)+=a(0:3,l)*b(l,2) +" ldr z6, [x0, #4, MUL VL] \n\t" // Load a( 0:3,l ) +" \n\t" +" fmla z21.d, z1.d, z3.d[0] \n\t" // Accummulate c(4:7,2)+=a(4:7,l)*b(l,2) +" fmla z22.d, z0.d, z3.d[1] \n\t" // Accummulate c(0:3,3)+=a(0:3,l)*b(l,3) +" ldr z7, [x0, #5, MUL VL] \n\t" // load a( 4:7,l ) +" \n\t" +" fmla z23.d, z1.d, z3.d[1] \n\t" // Accummulate c(4:7,3)+=a(4:7,l)*b(l,3) +" fmla z24.d, z0.d, z4.d[0] \n\t" // Accummulate c(0:3,4)+=a(0:3,l)*b(l,4) +" ld1rqd {z2.d}, p0/z, [x1, #0] \n\t" // load b( l,0:1 ) +" \n\t" +" fmla z25.d, z1.d, z4.d[0] \n\t" // Accummulate c(4:7,4)+=a(4:7,l)*b(l,4) +" fmla z26.d, z0.d, z4.d[1] \n\t" // Accummulate c(0:3,5)+=a(0:3,l)*b(l,5) +" fmla z27.d, z1.d, z4.d[1] \n\t" // Accummulate c(4:7,5)+=a(0:3,l)*b(l,5) +" ld1rqd {z3.d}, p0/z, [x1, #16] \n\t" // load b( l,2:3 ) +" \n\t" +" fmla z28.d, z0.d, z5.d[0] \n\t" // Accummulate c(0:3,6)+=a(0:3,l)*b(l,6) +" fmla z29.d, z1.d, z5.d[0] \n\t" // Accummulate c(4:7,6)+=a(0:3,l)*b(l,6) +" ld1rqd {z4.d}, p0/z, [x1, #32] \n\t" // load b( l,4:5 ) +" \n\t" +" fmla z30.d, z0.d, z5.d[1] \n\t" // Accummulate c(0:3,7)+=a(0:3,l)*b(l,7) +" fmla z31.d, z1.d, z5.d[1] \n\t" // Accummulate c(4:7,7)+=a(0:3,l)*b(l,7) +" ld1rqd {z5.d}, p0/z, [x1, #48] \n\t" // load b( l,6:7 ) +" \n\t" +" \n\t" // End it 3 +" \n\t" +" fmla z16.d, z6.d, z2.d[0] \n\t" // Accummulate c(0:3,0)+=a(0:3,l)*b(l,0) +" fmla z17.d, z7.d, z2.d[0] \n\t" // Accummulate c(4:7,0)+=a(4:7,l)*b(l,0) +" fmla z18.d, z6.d, z2.d[1] \n\t" // Accummulate c(0:3,1)+=a(0:3,l)*b(l,1) +" ldr z0, [x0, #6, MUL VL] \n\t" // Load a( 0:3,l ) +" \n\t" +" fmla z19.d, z7.d, z2.d[1] \n\t" // Accummulate c(4:7,1)+=a(4:7,l)*b(l,1) +" fmla z20.d, z6.d, z3.d[0] \n\t" // Accummulate c(0:3,2)+=a(0:3,l)*b(l,2) +" fmla z21.d, z7.d, z3.d[0] \n\t" // Accummulate c(4:7,2)+=a(4:7,l)*b(l,2) +" ldr z1, [x0, #7, MUL VL] \n\t" // load a( 4:7,l ) +" \n\t" +" fmla z22.d, z6.d, z3.d[1] \n\t" // Accummulate c(0:3,3)+=a(0:3,l)*b(l,3) +" fmla z23.d, z7.d, z3.d[1] \n\t" // Accummulate c(4:7,3)+=a(4:7,l)*b(l,3) +" fmla z24.d, z6.d, z4.d[0] \n\t" // Accummulate c(0:3,4)+=a(0:3,l)*b(l,4) +" ld1rqd {z2.d}, p0/z, [x1, #64] \n\t" // load b( l,0:1 ) +" \n\t" +" fmla z25.d, z7.d, z4.d[0] \n\t" // Accummulate c(4:7,4)+=a(4:7,l)*b(l,4) +" fmla z26.d, z6.d, z4.d[1] \n\t" // Accummulate c(0:3,5)+=a(0:3,l)*b(l,5) +" fmla z27.d, z7.d, z4.d[1] \n\t" // Accummulate c(4:7,5)+=a(0:3,l)*b(l,5) +" ld1rqd {z3.d}, p0/z, [x1, #80] \n\t" // load b( l,2:3 ) +" \n\t" +" fmla z28.d, z6.d, z5.d[0] \n\t" // Accummulate c(0:3,6)+=a(0:3,l)*b(l,6) +" fmla z29.d, z7.d, z5.d[0] \n\t" // Accummulate c(4:7,6)+=a(0:3,l)*b(l,6) +" ld1rqd {z4.d}, p0/z, [x1, #96] \n\t" // load b( l,4:5 ) +" \n\t" +" fmla z30.d, z6.d, z5.d[1] \n\t" // Accummulate c(0:3,7)+=a(0:3,l)*b(l,7) +" fmla z31.d, z7.d, z5.d[1] \n\t" // Accummulate c(4:7,7)+=a(0:3,l)*b(l,7) +" ld1rqd {z5.d}, p0/z, [x1, #112] \n\t" // load b( l,6:7 ) +" \n\t" +" \n\t" //End it 4 +" add x0, x0, #256 \n\t" +" add x1, x1, #128 \n\t" +" \n\t" +" sub x5,x5,1 \n\t" // i-=1 +" cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. +" bne DLOOP \n\t" +" \n\t" +".DLASTITER: \n\t" +" \n\t" +" fmla z16.d, z0.d, z2.d[0] \n\t" // Accummulate c(0:3,0)+=a(0:3,l)*b(l,0) +" fmla z17.d, z1.d, z2.d[0] \n\t" // Accummulate c(4:7,0)+=a(4:7,l)*b(l,0) +" fmla z18.d, z0.d, z2.d[1] \n\t" // Accummulate c(0:3,1)+=a(0:3,l)*b(l,1) +" ldr z6, [x0] \n\t" // Load a( 0:3,l ) +" \n\t" +" fmla z19.d, z1.d, z2.d[1] \n\t" // Accummulate c(4:7,1)+=a(4:7,l)*b(l,1) +" fmla z20.d, z0.d, z3.d[0] \n\t" // Accummulate c(0:3,2)+=a(0:3,l)*b(l,2) +" fmla z21.d, z1.d, z3.d[0] \n\t" // Accummulate c(4:7,2)+=a(4:7,l)*b(l,2) +" ldr z7, [x0, #1, MUL VL] \n\t" // load a( 4:7,l ) +" \n\t" +" fmla z22.d, z0.d, z3.d[1] \n\t" // Accummulate c(0:3,3)+=a(0:3,l)*b(l,3) +" fmla z23.d, z1.d, z3.d[1] \n\t" // Accummulate c(4:7,3)+=a(4:7,l)*b(l,3) +" fmla z24.d, z0.d, z4.d[0] \n\t" // Accummulate c(0:3,4)+=a(0:3,l)*b(l,4) +" ld1rqd {z2.d}, p0/z, [x1] \n\t" // load b( l,0:1 ) +" \n\t" +" fmla z25.d, z1.d, z4.d[0] \n\t" // Accummulate c(4:7,4)+=a(4:7,l)*b(l,4) +" fmla z26.d, z0.d, z4.d[1] \n\t" // Accummulate c(0:3,5)+=a(0:3,l)*b(l,5) +" fmla z27.d, z1.d, z4.d[1] \n\t" // Accummulate c(4:7,5)+=a(0:3,l)*b(l,5) +" ld1rqd {z3.d}, p0/z, [x1, #16] \n\t" // load b( l,2:3 ) +" \n\t" +" fmla z28.d, z0.d, z5.d[0] \n\t" // Accummulate c(0:3,6)+=a(0:3,l)*b(l,6) +" fmla z29.d, z1.d, z5.d[0] \n\t" // Accummulate c(4:7,6)+=a(0:3,l)*b(l,6) +" ld1rqd {z4.d}, p0/z, [x1, #32] \n\t" // load b( l,4:5 ) +" \n\t" +" fmla z30.d, z0.d, z5.d[1] \n\t" // Accummulate c(0:3,7)+=a(0:3,l)*b(l,7) +" fmla z31.d, z1.d, z5.d[1] \n\t" // Accummulate c(4:7,7)+=a(0:3,l)*b(l,7) +" ld1rqd {z5.d}, p0/z, [x1, #48] \n\t" // load b( l,6:7 ) +" \n\t" +" \n\t" // End it 1 +" \n\t" +" fmla z16.d, z6.d, z2.d[0] \n\t" // Accummulate c(0:3,0)+=a(0:3,l)*b(l,0) +" fmla z17.d, z7.d, z2.d[0] \n\t" // Accummulate c(4:7,0)+=a(4:7,l)*b(l,0) +" fmla z18.d, z6.d, z2.d[1] \n\t" // Accummulate c(0:3,1)+=a(0:3,l)*b(l,1) +" ldr z0, [x0, #2, MUL VL] \n\t" // Load a( 0:3,l ) +" \n\t" +" fmla z19.d, z7.d, z2.d[1] \n\t" // Accummulate c(4:7,1)+=a(4:7,l)*b(l,1) +" fmla z20.d, z6.d, z3.d[0] \n\t" // Accummulate c(0:3,2)+=a(0:3,l)*b(l,2) +" fmla z21.d, z7.d, z3.d[0] \n\t" // Accummulate c(4:7,2)+=a(4:7,l)*b(l,2) +" ldr z1, [x0, #3, MUL VL] \n\t" // load a( 4:7,l ) +" \n\t" +" fmla z22.d, z6.d, z3.d[1] \n\t" // Accummulate c(0:3,3)+=a(0:3,l)*b(l,3) +" fmla z23.d, z7.d, z3.d[1] \n\t" // Accummulate c(4:7,3)+=a(4:7,l)*b(l,3) +" fmla z24.d, z6.d, z4.d[0] \n\t" // Accummulate c(0:3,4)+=a(0:3,l)*b(l,4) +" ld1rqd {z2.d}, p0/z, [x1, #64] \n\t" // load b( l,0:1 ) +" \n\t" +" fmla z25.d, z7.d, z4.d[0] \n\t" // Accummulate c(4:7,4)+=a(4:7,l)*b(l,4) +" fmla z26.d, z6.d, z4.d[1] \n\t" // Accummulate c(0:3,5)+=a(0:3,l)*b(l,5) +" fmla z27.d, z7.d, z4.d[1] \n\t" // Accummulate c(4:7,5)+=a(0:3,l)*b(l,5) +" ld1rqd {z3.d}, p0/z, [x1, #80] \n\t" // load b( l,2:3 ) +" \n\t" +" fmla z28.d, z6.d, z5.d[0] \n\t" // Accummulate c(0:3,6)+=a(0:3,l)*b(l,6) +" fmla z29.d, z7.d, z5.d[0] \n\t" // Accummulate c(4:7,6)+=a(0:3,l)*b(l,6) +" ld1rqd {z4.d}, p0/z, [x1, #96] \n\t" // load b( l,4:5 ) +" \n\t" +" fmla z30.d, z6.d, z5.d[1] \n\t" // Accummulate c(0:3,7)+=a(0:3,l)*b(l,7) +" fmla z31.d, z7.d, z5.d[1] \n\t" // Accummulate c(4:7,7)+=a(0:3,l)*b(l,7) +" ld1rqd {z5.d}, p0/z, [x1, #112] \n\t" // load b( l,6:7 ) +" \n\t" +" \n\t" +" \n\t" //End it 2 +" \n\t" +" fmla z16.d, z0.d, z2.d[0] \n\t" // Accummulate c(0:3,0)+=a(0:3,l)*b(l,0) +" fmla z17.d, z1.d, z2.d[0] \n\t" // Accummulate c(4:7,0)+=a(4:7,l)*b(l,0) +" fmla z18.d, z0.d, z2.d[1] \n\t" // Accummulate c(0:3,1)+=a(0:3,l)*b(l,1) +" ldr z6, [x0, #4, MUL VL] \n\t" // Load a( 0:3,l ) +" \n\t" +" fmla z19.d, z1.d, z2.d[1] \n\t" // Accummulate c(4:7,1)+=a(4:7,l)*b(l,1) +" fmla z20.d, z0.d, z3.d[0] \n\t" // Accummulate c(0:3,2)+=a(0:3,l)*b(l,2) +" fmla z21.d, z1.d, z3.d[0] \n\t" // Accummulate c(4:7,2)+=a(4:7,l)*b(l,2) +" ldr z7, [x0, #5, MUL VL] \n\t" // load a( 4:7,l ) +" \n\t" +" fmla z22.d, z0.d, z3.d[1] \n\t" // Accummulate c(0:3,3)+=a(0:3,l)*b(l,3) +" add x1, x1, #128 \n\t" // because immediate in 'ldr1rqd' must be +" \n\t" // in range -128 to 112 +" fmla z23.d, z1.d, z3.d[1] \n\t" // Accummulate c(4:7,3)+=a(4:7,l)*b(l,3) +" fmla z24.d, z0.d, z4.d[0] \n\t" // Accummulate c(0:3,4)+=a(0:3,l)*b(l,4) +" ld1rqd {z2.d}, p0/z, [x1, #0] \n\t" // load b( l,0:1 ) +" \n\t" +" fmla z25.d, z1.d, z4.d[0] \n\t" // Accummulate c(4:7,4)+=a(4:7,l)*b(l,4) +" fmla z26.d, z0.d, z4.d[1] \n\t" // Accummulate c(0:3,5)+=a(0:3,l)*b(l,5) +" fmla z27.d, z1.d, z4.d[1] \n\t" // Accummulate c(4:7,5)+=a(0:3,l)*b(l,5) +" ld1rqd {z3.d}, p0/z, [x1, #16] \n\t" // load b( l,2:3 ) +" \n\t" +" fmla z28.d, z0.d, z5.d[0] \n\t" // Accummulate c(0:3,6)+=a(0:3,l)*b(l,6) +" fmla z29.d, z1.d, z5.d[0] \n\t" // Accummulate c(4:7,6)+=a(0:3,l)*b(l,6) +" ld1rqd {z4.d}, p0/z, [x1, #32] \n\t" // load b( l,4:5 ) +" \n\t" +" fmla z30.d, z0.d, z5.d[1] \n\t" // Accummulate c(0:3,7)+=a(0:3,l)*b(l,7) +" fmla z31.d, z1.d, z5.d[1] \n\t" // Accummulate c(4:7,7)+=a(0:3,l)*b(l,7) +" ld1rqd {z5.d}, p0/z, [x1, #48] \n\t" // load b( l,6:7 ) +" \n\t" +" \n\t" // End it 3 +" \n\t" +" fmla z16.d, z6.d, z2.d[0] \n\t" // Accummulate c(0:3,0)+=a(0:3,l)*b(l,0) +" fmla z17.d, z7.d, z2.d[0] \n\t" // Accummulate c(4:7,0)+=a(4:7,l)*b(l,0) +" \n\t" +" fmla z18.d, z6.d, z2.d[1] \n\t" // Accummulate c(0:3,1)+=a(0:3,l)*b(l,1) +" fmla z19.d, z7.d, z2.d[1] \n\t" // Accummulate c(4:7,1)+=a(4:7,l)*b(l,1) +" \n\t" +" fmla z20.d, z6.d, z3.d[0] \n\t" // Accummulate c(0:3,2)+=a(0:3,l)*b(l,2) +" fmla z21.d, z7.d, z3.d[0] \n\t" // Accummulate c(4:7,2)+=a(4:7,l)*b(l,2) +" \n\t" +" fmla z22.d, z6.d, z3.d[1] \n\t" // Accummulate c(0:3,3)+=a(0:3,l)*b(l,3) +" fmla z23.d, z7.d, z3.d[1] \n\t" // Accummulate c(4:7,3)+=a(4:7,l)*b(l,3) +" \n\t" +" fmla z24.d, z6.d, z4.d[0] \n\t" // Accummulate c(0:3,4)+=a(0:3,l)*b(l,4) +" fmla z25.d, z7.d, z4.d[0] \n\t" // Accummulate c(4:7,4)+=a(4:7,l)*b(l,4) +" \n\t" +" fmla z26.d, z6.d, z4.d[1] \n\t" // Accummulate c(0:3,5)+=a(0:3,l)*b(l,5) +" fmla z27.d, z7.d, z4.d[1] \n\t" // Accummulate c(4:7,5)+=a(0:3,l)*b(l,5) +" add x1, x1, #64 \n\t" +" \n\t" +" fmla z28.d, z6.d, z5.d[0] \n\t" // Accummulate c(0:3,6)+=a(0:3,l)*b(l,6) +" fmla z29.d, z7.d, z5.d[0] \n\t" // Accummulate c(4:7,6)+=a(0:3,l)*b(l,6) +" \n\t" +" fmla z30.d, z6.d, z5.d[1] \n\t" // Accummulate c(0:3,7)+=a(0:3,l)*b(l,7) +" fmla z31.d, z7.d, z5.d[1] \n\t" // Accummulate c(4:7,7)+=a(0:3,l)*b(l,7) +" \n\t" +" \n\t" //End it 4 +" add x0, x0, #192 \n\t" +" \n\t" +" .DCONSIDERKLEFT: \n\t" +" cmp x6,0 \n\t" // If k_left == 0, we are done. +" beq .DPOSTACCUM \n\t" // else, we enter the k_left loop. +" \n\t" +".DLOOPKLEFT: \n\t" +" \n\t" +" ldr z0, [x0] \n\t" // Load a +" ldr z1, [x0, #1, MUL VL] \n\t" +" add x0, x0, #64 \n\t" +" \n\t" +" ld1rqd {z2.d}, p0/z, [x1] \n\t" // load b( l,0:1 ) +" ld1rqd {z3.d}, p0/z, [x1, #16] \n\t" // load b( l,2:3 ) +" ld1rqd {z4.d}, p0/z, [x1, #32] \n\t" // load b( l,4:5 ) +" ld1rqd {z5.d}, p0/z, [x1, #48] \n\t" // load b( l,6:7 ) +" add x1, x1, #64 \n\t" +" \n\t" +" sub x6,x6,1 \n\t" +" \n\t" +" fmla z16.d, z0.d, z2.d[0] \n\t" // Accummulate c(0:3,0)+=a(0:3,l)*b(l,0) +" fmla z17.d, z1.d, z2.d[0] \n\t" // Accummulate c(4:7,0)+=a(4:7,l)*b(l,0) +" \n\t" +" fmla z18.d, z0.d, z2.d[1] \n\t" // Accummulate c(0:3,1)+=a(0:3,l)*b(l,1) +" fmla z19.d, z1.d, z2.d[1] \n\t" // Accummulate c(4:7,1)+=a(4:7,l)*b(l,1) +" \n\t" +" fmla z20.d, z0.d, z3.d[0] \n\t" // Accummulate c(0:3,2)+=a(0:3,l)*b(l,2) +" fmla z21.d, z1.d, z3.d[0] \n\t" // Accummulate c(4:7,2)+=a(4:7,l)*b(l,2) +" \n\t" +" fmla z22.d, z0.d, z3.d[1] \n\t" // Accummulate c(0:3,3)+=a(0:3,l)*b(l,3) +" fmla z23.d, z1.d, z3.d[1] \n\t" // Accummulate c(4:7,3)+=a(4:7,l)*b(l,3) +" \n\t" +" fmla z24.d, z0.d, z4.d[0] \n\t" // Accummulate c(0:3,4)+=a(0:3,l)*b(l,4) +" fmla z25.d, z1.d, z4.d[0] \n\t" // Accummulate c(4:7,4)+=a(4:7,l)*b(l,4) +" \n\t" +" fmla z26.d, z0.d, z4.d[1] \n\t" // Accummulate c(0:3,5)+=a(0:3,l)*b(l,5) +" fmla z27.d, z1.d, z4.d[1] \n\t" // Accummulate c(4:7,5)+=a(0:3,l)*b(l,5) +" \n\t" +" fmla z28.d, z0.d, z5.d[0] \n\t" // Accummulate c(0:3,6)+=a(0:3,l)*b(l,6) +" fmla z29.d, z1.d, z5.d[0] \n\t" // Accummulate c(4:7,6)+=a(0:3,l)*b(l,6) +" \n\t" +" fmla z30.d, z0.d, z5.d[1] \n\t" // Accummulate c(0:3,7)+=a(0:3,l)*b(l,7) +" fmla z31.d, z1.d, z5.d[1] \n\t" // Accummulate c(4:7,7)+=a(0:3,l)*b(l,7) +" \n\t" +" cmp x6,0 \n\t" // Iterate again. +" bne .DLOOPKLEFT \n\t" // if i!=0. +" \n\t" +" .DPOSTACCUM: \n\t" +" \n\t" +" ld1rd {z6.d}, p0/z, [x7] \n\t" // Load alpha. +" ld1rd {z7.d}, p0/z, [x8] \n\t" // Load beta +" \n\t" +" cmp x13,#1 \n\t" // If rs_c != 1 (column-major) +" bne .DGENSTORED \n\t" +" \n\t" +" .DCOLSTORED: \n\t" // C is column-major. +" \n\t" +" dup z0.d, #0 \n\t" +" dup z1.d, #0 \n\t" +" dup z2.d, #0 \n\t" +" dup z3.d, #0 \n\t" +" \n\t" +" fcmp d7,#0.0 \n\t" +" beq .DBETAZEROCOLSTOREDS1 \n\t" // Taking care of the beta==0 case. +" \n\t" +" ldr z0, [x2] \n\t" //Load column 0 of C +" ldr z1, [x2, #1, MUL VL] \n\t" +" \n\t" +" ldr z2, [x20] \n\t" //Load column 1 of C +" ldr z3, [x20, #1, MUL VL] \n\t" +" \n\t" +" fmul z0.d, z0.d, z7.d \n\t" // Scale by beta +" fmul z1.d, z1.d, z7.d \n\t" // Scale by beta +" fmul z2.d, z2.d, z7.d \n\t" // Scale by beta +" fmul z3.d, z3.d, z7.d \n\t" // Scale by beta +" \n\t" +" .DBETAZEROCOLSTOREDS1: \n\t" +" \n\t" +" fmla z0.d, z16.d, z6.d[0] \n\t" // Scale by alpha +" fmla z1.d, z17.d, z6.d[0] \n\t" // Scale by alpha +" fmla z2.d, z18.d, z6.d[0] \n\t" // Scale by alpha +" fmla z3.d, z19.d, z6.d[0] \n\t" // Scale by alpha +" \n\t" +" str z0, [x2] \n\t" //Store column 0 of C +" str z1, [x2, #1, MUL VL] \n\t" +" \n\t" +" str z2, [x20] \n\t" //Store column 1 of C +" str z3, [x20, #1, MUL VL] \n\t" +" \n\t" +" dup z8.d, #0 \n\t" +" dup z9.d, #0 \n\t" +" dup z10.d, #0 \n\t" +" dup z11.d, #0 \n\t" +" \n\t" +" fcmp d7,#0.0 \n\t" +" beq .DBETAZEROCOLSTOREDS2 \n\t" // Taking care of the beta==0 case. +" \n\t" +" ldr z8, [x21] \n\t" //Load column 2 of C +" ldr z9, [x21, #1, MUL VL] \n\t" +" \n\t" +" ldr z10, [x22] \n\t" //Load column 3 of C +" ldr z11, [x22, #1, MUL VL] \n\t" +" \n\t" +" fmul z8.d, z8.d, z7.d \n\t" // Scale by beta +" fmul z9.d, z9.d, z7.d \n\t" // Scale by beta +" fmul z10.d, z10.d, z7.d \n\t" // Scale by beta +" fmul z11.d, z11.d, z7.d \n\t" // Scale by beta +" \n\t" +" .DBETAZEROCOLSTOREDS2: \n\t" +" \n\t" +" fmla z8.d, z20.d, z6.d[0] \n\t" // Scale by alpha +" fmla z9.d, z21.d, z6.d[0] \n\t" // Scale by alpha +" fmla z10.d, z22.d, z6.d[0] \n\t" // Scale by alpha +" fmla z11.d, z23.d, z6.d[0] \n\t" // Scale by alpha +" \n\t" +" str z8, [x21] \n\t" //Store column 2 of C +" str z9, [x21, #1, MUL VL] \n\t" +" \n\t" +" str z10, [x22] \n\t" //Store column 3 of C +" str z11, [x22, #1, MUL VL] \n\t" +" \n\t" +" dup z0.d, #0 \n\t" +" dup z1.d, #0 \n\t" +" dup z2.d, #0 \n\t" +" dup z3.d, #0 \n\t" +" \n\t" +" fcmp d7,#0.0 \n\t" +" beq .DBETAZEROCOLSTOREDS3 \n\t" // Taking care of the beta==0 case. +" \n\t" +" ldr z0, [x23] \n\t" //Load column 4 of C +" ldr z1, [x23, #1, MUL VL] \n\t" +" \n\t" +" ldr z2, [x24] \n\t" //Load column 5 of C +" ldr z3, [x24, #1, MUL VL] \n\t" +" \n\t" +" fmul z0.d, z0.d, z7.d \n\t" // Scale by beta +" fmul z1.d, z1.d, z7.d \n\t" // Scale by beta +" fmul z2.d, z2.d, z7.d \n\t" // Scale by beta +" fmul z3.d, z3.d, z7.d \n\t" // Scale by beta +" \n\t" +" .DBETAZEROCOLSTOREDS3: \n\t" +" \n\t" +" fmla z0.d, z24.d, z6.d[0] \n\t" // Scale by alpha +" fmla z1.d, z25.d, z6.d[0] \n\t" // Scale by alpha +" fmla z2.d, z26.d, z6.d[0] \n\t" // Scale by alpha +" fmla z3.d, z27.d, z6.d[0] \n\t" // Scale by alpha +" \n\t" +" str z0, [x23] \n\t" //Store column 4 of C +" str z1, [x23, #1, MUL VL] \n\t" +" \n\t" +" str z2, [x24] \n\t" //Store column 5 of C +" str z3, [x24, #1, MUL VL] \n\t" +" \n\t" +" dup z8.d, #0 \n\t" +" dup z9.d, #0 \n\t" +" dup z10.d, #0 \n\t" +" dup z11.d, #0 \n\t" +" \n\t" +" fcmp d7,#0.0 \n\t" +" beq .DBETAZEROCOLSTOREDS4 \n\t" // Taking care of the beta==0 case. +" \n\t" +" ldr z8, [x25] \n\t" //Load column 6 of C +" ldr z9, [x25, #1, MUL VL] \n\t" +" \n\t" +" ldr z10, [x26] \n\t" //Load column 7 of C +" ldr z11, [x26, #1, MUL VL] \n\t" +" \n\t" +" fmul z8.d, z8.d, z7.d \n\t" // Scale by beta +" fmul z9.d, z9.d, z7.d \n\t" // Scale by beta +" fmul z10.d, z10.d, z7.d \n\t" // Scale by beta +" fmul z11.d, z11.d, z7.d \n\t" // Scale by beta +" \n\t" +" .DBETAZEROCOLSTOREDS4: \n\t" +" \n\t" +" prfm pldl2keep,[x3] \n\t" +" prfm pldl2keep,[x4] \n\t" +" \n\t" +" fmla z8.d, z28.d, z6.d[0] \n\t" // Scale by alpha +" fmla z9.d, z29.d, z6.d[0] \n\t" // Scale by alpha +" fmla z10.d, z30.d, z6.d[0] \n\t" // Scale by alpha +" fmla z11.d, z31.d, z6.d[0] \n\t" // Scale by alpha +" \n\t" +" str z8, [x25] \n\t" //Store column 6 of C +" str z9, [x25, #1, MUL VL] \n\t" +" \n\t" +" str z10, [x26] \n\t" //Store column 7 of C +" str z11, [x26, #1, MUL VL] \n\t" +" \n\t" +" b .DEND \n\t" +" \n\t" +" .DGENSTORED: \n\t" // C is general-stride stored. +" \n\t" +" \n\t" // x14 is row-stride in number of bytes. +" lsl x15,x14,#2 \n\t" // x15 is 4-row-stride, which is the address offset +" \n\t" // btw c(4,*) and c(0,*) +" index z4.d, xzr, x14 \n\t" // z4 is address offsets of four contiguous elements +" \n\t" // in a column. such as c( 0:3,* ). +" \n\t" // z4 is used as vector index for gather/scatter +" \n\t" // loading/storing from column of *c +" \n\t" +" \n\t" // C's each column's address: +" \n\t" // x2, x20, x21, x22, x23, x24, x25, x26: are addresses of c(0,0:7) +" \n\t" // x5, x6, x7, x8, x16, x17, x18, x19: are addresses of c(4,0:7) +" add x5, x15, x2 \n\t" // x5 is address of c(4,0) +" add x6, x15, x20 \n\t" // x6 is address of c(4,1) +" add x7, x15, x21 \n\t" // x7 is address of c(4,2) +" add x8, x15, x22 \n\t" // x8 is address of c(4,3) +" add x16, x15, x23 \n\t" // x16 is address of c(4,4) +" add x17, x15, x24 \n\t" // x17 is address of c(4,5) +" add x18, x15, x25 \n\t" // x18 is address of c(4,6) +" add x19, x15, x26 \n\t" // x19 is address of c(4,7) +" \n\t" +" dup z0.d, #0 \n\t" // C column 0, 1 +" dup z1.d, #0 \n\t" +" dup z2.d, #0 \n\t" +" dup z3.d, #0 \n\t" +" \n\t" +" fcmp d7,#0.0 \n\t" +" beq .DBETAZEROGENSTOREDS1 \n\t" // Taking care of the beta==0 case. +" \n\t" +" \n\t" // x2 is address of c(0,0) +" \n\t" // x5 is address of c(4,0) +" \n\t" // x20 is address of c(0,1) +" \n\t" // x6 is address of c(4,1) +" ld1d {z0.d}, p0/z, [x2, z4.d] \n\t" // Load c( 0:3,0 ) into z0 +" ld1d {z1.d}, p0/z, [x5, z4.d] \n\t" // Load c( 4:7,0 ) into z1 +" ld1d {z2.d}, p0/z, [x20, z4.d] \n\t" // Load c( 0:3,1 ) into z2 +" ld1d {z3.d}, p0/z, [x6 , z4.d] \n\t" // Load c( 4:7,1 ) into z3 +" \n\t" +" fmul z0.d, z0.d, z7.d \n\t" // Scale by beta +" fmul z1.d, z1.d, z7.d \n\t" // Scale by beta +" fmul z2.d, z2.d, z7.d \n\t" // Scale by beta +" fmul z3.d, z3.d, z7.d \n\t" // Scale by beta +" \n\t" +" .DBETAZEROGENSTOREDS1: \n\t" +" \n\t" +" fmla z0.d, z16.d, z6.d[0] \n\t" // Scale by alpha +" fmla z1.d, z17.d, z6.d[0] \n\t" // Scale by alpha +" fmla z2.d, z18.d, z6.d[0] \n\t" // Scale by alpha +" fmla z3.d, z19.d, z6.d[0] \n\t" // Scale by alpha +" \n\t" +" st1d {z0.d}, p0, [x2 , z4.d] \n\t" // Store c( 0:3,0 ) <- z0 +" st1d {z1.d}, p0, [x5 , z4.d] \n\t" // Store c( 4:7,0 ) <- z1 +" st1d {z2.d}, p0, [x20, z4.d] \n\t" // Store c( 0:3,1 ) <- z2 +" st1d {z3.d}, p0, [x6 , z4.d] \n\t" // Store c( 4:7,1 ) <- z3 +" \n\t" +" \n\t" +" \n\t" +" dup z8.d, #0 \n\t" // C column 2, 3 +" dup z9.d, #0 \n\t" +" dup z10.d, #0 \n\t" +" dup z11.d, #0 \n\t" +" \n\t" +" fcmp d7,#0.0 \n\t" +" beq .DBETAZEROGENSTOREDS2 \n\t" // Taking care of the beta==0 case. +" \n\t" +" \n\t" // x21 is address of c(0,2) +" \n\t" // x7 is address of c(4,2) +" \n\t" // x22 is address of c(0,3) +" \n\t" // x8 is address of c(4,3) +" ld1d {z8.d}, p0/z, [x21, z4.d] \n\t" // Load c( 0:3,2 ) into z8 +" ld1d {z9.d}, p0/z, [x7 , z4.d] \n\t" // Load c( 4:7,2 ) into z9 +" ld1d {z10.d}, p0/z, [x22, z4.d] \n\t" // Load c( 0:3,3 ) into z10 +" ld1d {z11.d}, p0/z, [x8 , z4.d] \n\t" // Load c( 4:7,3 ) into z11 +" \n\t" +" fmul z8.d, z8.d, z7.d \n\t" // Scale by beta +" fmul z9.d, z9.d, z7.d \n\t" // Scale by beta +" fmul z10.d, z10.d, z7.d \n\t" // Scale by beta +" fmul z11.d, z11.d, z7.d \n\t" // Scale by beta +" \n\t" +" .DBETAZEROGENSTOREDS2: \n\t" +" \n\t" +" fmla z8.d, z20.d, z6.d[0] \n\t" // Scale by alpha +" fmla z9.d, z21.d, z6.d[0] \n\t" // Scale by alpha +" fmla z10.d, z22.d, z6.d[0] \n\t" // Scale by alpha +" fmla z11.d, z23.d, z6.d[0] \n\t" // Scale by alpha +" \n\t" +" st1d {z8.d}, p0, [x21, z4.d] \n\t" // Store c( 0:3,2 ) <- z8 +" st1d {z9.d}, p0, [x7 , z4.d] \n\t" // Store c( 4:7,2 ) <- z9 +" st1d {z10.d}, p0, [x22, z4.d] \n\t" // Store c( 0:3,3 ) <- z10 +" st1d {z11.d}, p0, [x8 , z4.d] \n\t" // Store c( 4:7,3 ) <- z11 +" \n\t" +" dup z0.d, #0 \n\t" // C column 4, 5 +" dup z1.d, #0 \n\t" +" dup z2.d, #0 \n\t" +" dup z3.d, #0 \n\t" +" \n\t" +" fcmp d7,#0.0 \n\t" +" beq .DBETAZEROGENSTOREDS3 \n\t" // Taking care of the beta==0 case. +" \n\t" +" \n\t" // x23 is address of c(0,4) +" \n\t" // x16 is address of c(4,4) +" \n\t" // x24 is address of c(0,5) +" \n\t" // x17 is address of c(4,5) +" ld1d {z0.d}, p0/z, [x23, z4.d] \n\t" // Load c( 0:3,4 ) into z0 +" ld1d {z1.d}, p0/z, [x16, z4.d] \n\t" // Load c( 4:7,4 ) into z1 +" ld1d {z2.d}, p0/z, [x24, z4.d] \n\t" // Load c( 0:3,5 ) into z2 +" ld1d {z3.d}, p0/z, [x17, z4.d] \n\t" // Load c( 4:7,5 ) into z3 +" \n\t" +" fmul z0.d, z0.d, z7.d \n\t" // Scale by beta +" fmul z1.d, z1.d, z7.d \n\t" // Scale by beta +" fmul z2.d, z2.d, z7.d \n\t" // Scale by beta +" fmul z3.d, z3.d, z7.d \n\t" // Scale by beta +" \n\t" +" .DBETAZEROGENSTOREDS3: \n\t" +" \n\t" +" fmla z0.d, z24.d, z6.d[0] \n\t" // Scale by alpha +" fmla z1.d, z25.d, z6.d[0] \n\t" // Scale by alpha +" fmla z2.d, z26.d, z6.d[0] \n\t" // Scale by alpha +" fmla z3.d, z27.d, z6.d[0] \n\t" // Scale by alpha +" \n\t" +" st1d {z0.d}, p0, [x23, z4.d] \n\t" // Store c( 0:3,4 ) <- z0 +" st1d {z1.d}, p0, [x16, z4.d] \n\t" // Store c( 4:7,4 ) <- z1 +" st1d {z2.d}, p0, [x24, z4.d] \n\t" // Store c( 0:3,5 ) <- z2 +" st1d {z3.d}, p0, [x17, z4.d] \n\t" // Store c( 4:7,5 ) <- z3 +" \n\t" +" dup z8.d, #0 \n\t" // C column 6, 7 +" dup z9.d, #0 \n\t" +" dup z10.d, #0 \n\t" +" dup z11.d, #0 \n\t" +" \n\t" +" fcmp d7,#0.0 \n\t" +" beq .DBETAZEROGENSTOREDS4 \n\t" // Taking care of the beta==0 case. +" \n\t" +" \n\t" // x25 is address of c(0,6) +" \n\t" // x18 is address of c(4,6) +" \n\t" // x26 is address of c(0,7) +" \n\t" // x19 is address of c(4,7) +" ld1d {z8.d}, p0/z, [x25, z4.d] \n\t" // Load c( 0:3,6 ) into z8 +" ld1d {z9.d}, p0/z, [x18, z4.d] \n\t" // Load c( 4:7,6 ) into z9 +" ld1d {z10.d}, p0/z, [x26, z4.d] \n\t" // Load c( 0:3,7 ) into z10 +" ld1d {z11.d}, p0/z, [x19, z4.d] \n\t" // Load c( 4:7,7 ) into z11 +" \n\t" +" fmul z8.d, z8.d, z7.d \n\t" // Scale by beta +" fmul z9.d, z9.d, z7.d \n\t" // Scale by beta +" fmul z10.d, z10.d, z7.d \n\t" // Scale by beta +" fmul z11.d, z11.d, z7.d \n\t" // Scale by beta +" \n\t" +" .DBETAZEROGENSTOREDS4: \n\t" +" \n\t" +" fmla z8.d, z28.d, z6.d[0] \n\t" // Scale by alpha +" fmla z9.d, z29.d, z6.d[0] \n\t" // Scale by alpha +" fmla z10.d, z30.d, z6.d[0] \n\t" // Scale by alpha +" fmla z11.d, z31.d, z6.d[0] \n\t" // Scale by alpha +" \n\t" +" st1d {z8.d}, p0, [x25, z4.d] \n\t" // Store c( 0:3,6 ) <- z8 +" st1d {z9.d}, p0, [x18, z4.d] \n\t" // Store c( 4:7,6 ) <- z9 +" st1d {z10.d}, p0, [x26, z4.d] \n\t" // Store c( 0:3,7 ) <- z10 +" st1d {z11.d}, p0, [x19, z4.d] \n\t" // Store c( 4:7,7 ) <- z11 +" \n\t" +" .DEND: \n\t" // Done! +" \n\t" +:// output operands (none) +:// input operands + [aaddr] "m" (a), // 0 + [baddr] "m" (b), // 1 + [caddr] "m" (c), // 2 + [k_iter] "m" (k_iter), // 3 + [k_left] "m" (k_left), // 4 + [alpha] "m" (alpha), // 5 + [beta] "m" (beta), // 6 + [rs_c] "m" (rs_c), // 6 + [cs_c] "m" (cs_c), // 7 + [a_next] "m" (a_next), // 8 + [b_next] "m" (b_next) // 9 +:// Register clobber list + "x0","x1","x2","x3", + "x4","x5","x6", + "x7","x8","x9", + "x10","x11","x12","x13","x14","x15","x16","x17","x18","x19", + "x20","x21","x22","x23","x24","x25","x26", + "x27", + "v0","v1","v2", + "v3","v4","v5", + "v6","v7","v8", + "v9","v10","v11", + "v12","v13","v14", + "v15","v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" +); + +} diff --git a/kernels/armsve/3/old/bli_gemm_armsve_asm_sh2vx10_unindexed.c b/kernels/armsve/3/old/bli_gemm_armsve_asm_sh2vx10_unindexed.c new file mode 100644 index 0000000000..817153bfe9 --- /dev/null +++ b/kernels/armsve/3/old/bli_gemm_armsve_asm_sh2vx10_unindexed.c @@ -0,0 +1,343 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + Copyright (C) 2019, Forschunszentrum Juelich + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Half-precision composite instructions. +#include "armsve_asm_macros_half.h" + +// 2vx10 microkernels. +#include "armsve_asm_2vx10.h" + +// Gather-load / scatter-store instruction for half-precision +// needs being defined separately. +#undef GEMM_CCOL_GATHER_LOAD_FWD +#undef GEMM_CCOL_SCATTER_STORE_FWD + +#define GEMM_CCOL_GATHER_LOAD_FWD(ZFH,ZLH,ZIDX2,PT,CRS2,CADDR,CCS,CVSKIP,CTEMP) \ +" add x28, "#CADDR", "#CRS2" \n\t" \ +" ld1h z31.s, "#PT"/z, ["#CADDR", "#ZIDX2".s, uxtw #1] \n\t" \ +" ld1h "#ZFH".s, "#PT"/z, [x28, "#ZIDX2".s, uxtw #1] \n\t" \ +" revh "#ZFH".s, "#PT"/m, "#ZFH".s \n\t" \ +" fadd "#ZFH".h, "#ZFH".h, z31.h \n\t" \ +" add "#CTEMP", "#CADDR", "#CVSKIP" \n\t" \ +" add x28, "#CTEMP", "#CRS2" \n\t" \ +" ld1h z31.s, "#PT"/z, ["#CTEMP", "#ZIDX2".s, uxtw #1] \n\t" \ +" ld1h "#ZLH".s, "#PT"/z, [x28, "#ZIDX2".s, uxtw #1] \n\t" \ +" revh "#ZLH".s, "#PT"/m, "#ZLH".s \n\t" \ +" fadd "#ZLH".h, "#ZLH".h, z31.h \n\t" \ +" add "#CADDR", "#CADDR", "#CCS" \n\t" + +#define GEMM_CCOL_SCATTER_STORE_FWD(ZFH,ZLH,ZIDX2,PT,CRS2,CADDR,CCS,CVSKIP,CTEMP) \ +" add x28, "#CADDR", "#CRS2" \n\t" \ +" st1h "#ZFH".s, "#PT", ["#CADDR", "#ZIDX2".s, uxtw #1] \n\t" \ +" revh "#ZFH".s, "#PT"/m, "#ZFH".s \n\t" \ +" st1h "#ZFH".s, "#PT", [x28, "#ZIDX2".s, uxtw #1] \n\t" \ +" add "#CTEMP", "#CADDR", "#CVSKIP" \n\t" \ +" add x28, "#CTEMP", "#CRS2" \n\t" \ +" st1h "#ZLH".s, "#PT", ["#CTEMP", "#ZIDX2".s, uxtw #1] \n\t" \ +" revh "#ZLH".s, "#PT"/m, "#ZLH".s \n\t" \ +" st1h "#ZLH".s, "#PT", [x28, "#ZIDX2".s, uxtw #1] \n\t" \ +" add "#CADDR", "#CADDR", "#CCS" \n\t" + + +void bli_shgemm_armsve_asm_2vx10_unindexed + ( + dim_t k0, + void* restrict alpha, + void* restrict a, + void* restrict b, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + __asm__ volatile ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" inch x2, ALL, MUL #2 \n\t" // Column-skip of A. +" mov x3, #10 \n\t" // Row-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x8, 0x3 \n\t" // Tag C address. +" lsl x8, x8, #56 \n\t" +" orr x5, x5, x8 \n\t" +" mov x8, 0x2 \n\t" // Tag B address. +" lsl x8, x8, #56 \n\t" +" orr x1, x1, x8 \n\t" +" mov x8, 0x1 \n\t" // Tag A address. +" lsl x8, x8, #56 \n\t" +" orr x0, x0, x8 \n\t" +#endif +" \n\t" +" mov x8, #2 \n\t" // Multiply some address skips by sizeof(float16_t). +" madd x2, x8, x2, xzr \n\t" // cs_a +" madd x3, x8, x3, xzr \n\t" // rs_b +" madd x7, x8, x7, xzr \n\t" // cs_c +" ptrue p0.b \n\t" +" \n\t" +" ldr x4, %[k_mker] \n\t" // Number of loops. +" ldr x8, %[k_left] \n\t" +" \n\t" +" LOAD_ABC: \n\t" +" cmp x4, #0 \n\t" // Don't preload if no microkernel there. +" b.eq END_CCOL_PRFM \n\t" + +" ld1rh z20.h, p0/z, [x1] \n\t" // Load 8/10 of first B row. +" ld1rh z21.h, p0/z, [x1, 2] \n\t" +" ld1rh z22.h, p0/z, [x1, 4] \n\t" +" ld1rh z23.h, p0/z, [x1, 6] \n\t" +" ld1rh z24.h, p0/z, [x1, 8] \n\t" +" ld1rh z25.h, p0/z, [x1, 10] \n\t" +" ld1rh z26.h, p0/z, [x1, 12] \n\t" +" ld1rh z27.h, p0/z, [x1, 14] \n\t" +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0) +" \n\t" +" CCOL_PRFM: \n\t" +" cmp x6, #1 \n\t" +" b.ne END_CCOL_PRFM \n\t" // Do not prefetch for generic C storage. +" mov x16, x5 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" END_CCOL_PRFM: \n\t" +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp x4, #0 \n\t" // If no 4-microkernel can be applied +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" subs x4, x4, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_4(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX10_MKER_LOOP_PLAIN_C_4_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" add x0, x0, x2 \n\t" // Forward A to fill the blank. +" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp x8, #0 \n\t" // End of execution. +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0) +" ld1rh z20.h, p0/z, [x1] \n\t" // Load 8/10 of first B row. +" ld1rh z21.h, p0/z, [x1, 2] \n\t" +" ld1rh z22.h, p0/z, [x1, 4] \n\t" +" ld1rh z23.h, p0/z, [x1, 6] \n\t" +" ld1rh z24.h, p0/z, [x1, 8] \n\t" +" ld1rh z25.h, p0/z, [x1, 10] \n\t" +" ld1rh z26.h, p0/z, [x1, 12] \n\t" +" ld1rh z27.h, p0/z, [x1, 14] \n\t" +" ld1rh z28.h, p0/z, [x1, 16] \n\t" +" ld1rh z29.h, p0/z, [x1, 18] \n\t" +GEMM_FMLA2(z0,z1,p0,z30,z31,z20) +GEMM_FMLA2(z2,z3,p0,z30,z31,z21) +GEMM_FMLA2(z4,z5,p0,z30,z31,z22) +GEMM_FMLA2(z6,z7,p0,z30,z31,z23) +GEMM_FMLA2(z8,z9,p0,z30,z31,z24) +GEMM_FMLA2(z10,z11,p0,z30,z31,z25) +GEMM_FMLA2(z12,z13,p0,z30,z31,z26) +GEMM_FMLA2(z14,z15,p0,z30,z31,z27) +GEMM_FMLA2(z16,z17,p0,z30,z31,z28) +GEMM_FMLA2(z18,z19,p0,z30,z31,z29) +" add x0, x0, x2 \n\t" // Forward A. +" add x1, x1, x3 \n\t" // Forward B. +" sub x8, x8, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1rh z30.h, p0/z, [x4] \n\t" // Load alpha & beta into vectors. +" ld1rh z31.h, p0/z, [x8] \n\t" +" fmov w4, h28 \n\t" // Copy alpha & beta to GP registers. +" fmov w8, h29 \n\t" +" \n\t" +" PREFETCH_ABNEXT: \n\t" +" ldr x0, %[a_next] \n\t" +" ldr x1, %[b_next] \n\t" +" prfm PLDL2KEEP, [x0] \n\t" +" prfm PLDL2KEEP, [x0, 256*1] \n\t" +" prfm PLDL2KEEP, [x0, 256*2] \n\t" +" prfm PLDL2KEEP, [x0, 256*3] \n\t" +" prfm PLDL2KEEP, [x0, 256*4] \n\t" +" prfm PLDL2KEEP, [x0, 256*5] \n\t" +" prfm PLDL2KEEP, [x0, 256*6] \n\t" +" prfm PLDL2KEEP, [x0, 256*7] \n\t" +" prfm PLDL2KEEP, [x0, 256*8] \n\t" +" prfm PLDL2KEEP, [x0, 256*9] \n\t" +" prfm PLDL2KEEP, [x0, 256*10] \n\t" +" prfm PLDL2KEEP, [x0, 256*11] \n\t" +" prfm PLDL2KEEP, [x0, 256*12] \n\t" +" prfm PLDL2KEEP, [x0, 256*13] \n\t" +" prfm PLDL2KEEP, [x0, 256*14] \n\t" +" prfm PLDL2KEEP, [x0, 256*15] \n\t" +" prfm PLDL2KEEP, [x1] \n\t" +" prfm PLDL2KEEP, [x1, 256*1] \n\t" +" prfm PLDL2KEEP, [x1, 256*2] \n\t" +" prfm PLDL2KEEP, [x1, 256*3] \n\t" +" prfm PLDL2KEEP, [x1, 256*4] \n\t" +" prfm PLDL2KEEP, [x1, 256*5] \n\t" +" prfm PLDL2KEEP, [x1, 256*6] \n\t" +" prfm PLDL2KEEP, [x1, 256*7] \n\t" +" prfm PLDL2KEEP, [x1, 256*8] \n\t" +" prfm PLDL2KEEP, [x1, 256*9] \n\t" +" \n\t" +" WRITE_MEM: \n\t" +" \n\t" +" fmov h28, #1.0 \n\t" +" fmov w16, h28 \n\t" +" cmp w16, w4 \n\t" +" b.eq UNIT_ALPHA \n\t" +" \n\t" +SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19,z30) +" \n\t" +" UNIT_ALPHA: \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x6, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-29]. +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x9,x7) +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x9,x7) +" \n\t" +GEMM_C_STORE_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x5,x7) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x5,x7) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-30] - Z30 as index. +" mov x10, xzr \n\t" +" incb x10 \n\t" +" madd x10, x10, x6, xzr \n\t" // C-column's logical 1-vector skip. +" mov x28, #2 \n\t" +" madd x6, x28, x6, xzr \n\t" // Double index skip for half-precision case. +" index z30.s, wzr, w6 \n\t" // Skips passed to index is not multiplied by 8. +GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,x6,x9,x7,x10,x16) +" dup z31.h, w8 \n\t" // Restore beta destroyed by loading. +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,x6,x9,x7,x10,x16) +" \n\t" +" dup z31.h, w8 \n\t" // Restore beta destroyed by loading. +GEMM_C_STORE_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,x6,x5,x7,x10,x16) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,x6,x5,x7,x10,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" b END_EXEC \n\t" +" \n\t" +" END_ERROR: \n\t" +" mov x0, #1 \n\t" // Return error. +" END_EXEC: \n\t" +" mov x0, #0 \n\t" // Return normal. +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x16","x10","x28", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); +} + diff --git a/kernels/armsve/3/old/bli_gemm_armsve_asm_z2vx7_unindexed.c b/kernels/armsve/3/old/bli_gemm_armsve_asm_z2vx7_unindexed.c new file mode 100644 index 0000000000..ca62f9db11 --- /dev/null +++ b/kernels/armsve/3/old/bli_gemm_armsve_asm_z2vx7_unindexed.c @@ -0,0 +1,274 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Double-precision composite instructions. +#include "armsve_asm_macros_dcomplex.h" + +// 2vx7 microkernels. +#include "armsve_asm_2vx7cmplx.h" + + +void bli_zgemm_armsve_asm_2vx7_unindexed + ( + dim_t m, + dim_t n, + dim_t k, + dcomplex* restrict alpha, + dcomplex* restrict a, + dcomplex* restrict b, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t info = 0; + + uint64_t mr = bli_vl_bytes_armsve() * 2 / 16; + GEMM_UKR_SETUP_CT( z, mr, 7, false ); + + __asm__ volatile ( +// " ldr x0, %[a] \n\t" +// " ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" incd x2, ALL, MUL #1 \n\t" // Column-skip of A. +" mov x3, #7 \n\t" // Row-skip of B. +" \n\t" +// " ldr x2, %[c] \n\t" +// " ldr x3, %[rs_c] \n\t" // Row-skip of C. +// " ldr x4, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %0, %0, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %1, %1, x16 \n\t" +" mov x16, 0x3 \n\t" // Tag C address. +" lsl x16, x16, #56 \n\t" +" orr %2, %2, x16 \n\t" +#endif +" \n\t" +" mov x16, #16 \n\t" // Multiply some address skips by sizeof(dcomplex). +" madd x2, x16, x2, xzr \n\t" // cs_a +" madd x3, x16, x3, xzr \n\t" // rs_b +" madd %4, x16, %4, xzr \n\t" // cs_c +" ptrue p0.d \n\t" +" \n\t" +// " ldr x5, %[k_mker] \n\t" // Number of loops. +// " ldr x6, %[k_left] \n\t" +" \n\t" +" LOAD_ABC: \n\t" +" cmp %5, #0 \n\t" // Don't preload if no microkernel there. +" b.eq END_CCOL_PRFM \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +" \n\t" +" ld1rd z14.d, p0/z, [%1, 8*0] \n\t" // Load B's real & imaginary. +" ld1rd z15.d, p0/z, [%1, 8*2] \n\t" +" ld1rd z16.d, p0/z, [%1, 8*4] \n\t" +" ld1rd z17.d, p0/z, [%1, 8*6] \n\t" +" ld1rd z18.d, p0/z, [%1, 8*8] \n\t" +" ld1rd z19.d, p0/z, [%1, 8*10] \n\t" +" ld1rd z20.d, p0/z, [%1, 8*12] \n\t" +" ld1rd z21.d, p0/z, [%1, 8*1] \n\t" +" ld1rd z22.d, p0/z, [%1, 8*3] \n\t" +" ld1rd z23.d, p0/z, [%1, 8*5] \n\t" +" ld1rd z24.d, p0/z, [%1, 8*7] \n\t" +" ld1rd z25.d, p0/z, [%1, 8*9] \n\t" +" ld1rd z26.d, p0/z, [%1, 8*11] \n\t" +" ld1rd z27.d, p0/z, [%1, 8*13] \n\t" +" add %1, %1, x3 \n\t" +" \n\t" +" CCOL_PRFM: \n\t" +" cmp %3, #1 \n\t" +" b.ne END_CCOL_PRFM \n\t" // Do not prefetch for generic C storage. +" mov x16, %2 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" END_CCOL_PRFM: \n\t" +" \n\t" +CLEAR_COL14(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13) +" \n\t" +" cmp %5, #0 \n\t" // If no 4-microkernel can be applied +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z30,z31,p0,%0,x2) +GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C(z0,z2,z4,z6,z8,z10,z12,z1,z3,z5,z7,z9,z11,z13,p0,z28,z29,z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C(z0,z2,z4,z6,z8,z10,z12,z1,z3,z5,z7,z9,z11,z13,p0,z30,z31,z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z30,z31,p0,%0,x2) +GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C(z0,z2,z4,z6,z8,z10,z12,z1,z3,z5,z7,z9,z11,z13,p0,z28,z29,z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +" subs %5, %5, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C(z0,z2,z4,z6,z8,z10,z12,z1,z3,z5,z7,z9,z11,z13,p0,z30,z31,z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z1,z3,z5,z7,z9,z11,z13,p0,z30,z31,z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp %6, #0 \n\t" // End of execution. +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +" ld1rd z14.d, p0/z, [%1, 8*0] \n\t" +" ld1rd z15.d, p0/z, [%1, 8*2] \n\t" +" ld1rd z16.d, p0/z, [%1, 8*4] \n\t" +" ld1rd z17.d, p0/z, [%1, 8*6] \n\t" +" ld1rd z18.d, p0/z, [%1, 8*8] \n\t" +" ld1rd z19.d, p0/z, [%1, 8*10] \n\t" +" ld1rd z20.d, p0/z, [%1, 8*12] \n\t" +" ld1rd z21.d, p0/z, [%1, 8*1] \n\t" +" ld1rd z22.d, p0/z, [%1, 8*3] \n\t" +" ld1rd z23.d, p0/z, [%1, 8*5] \n\t" +" ld1rd z24.d, p0/z, [%1, 8*7] \n\t" +" ld1rd z25.d, p0/z, [%1, 8*9] \n\t" +" ld1rd z26.d, p0/z, [%1, 8*11] \n\t" +" ld1rd z27.d, p0/z, [%1, 8*13] \n\t" +" add %1, %1, x3 \n\t" +GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z1,z3,z5,z7,z9,z11,z13,p0,z28,z29,z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" sub %6, %6, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +// " ldr x7, %[alpha] \n\t" // Load alpha & beta (address). +// " ldr x8, %[beta] \n\t" +" ld1rd z28.d, p0/z, [%7] \n\t" // Real(alpha). +" ld1rd z29.d, p0/z, [%7, 8] \n\t" // Imag(alpha). +" ld1rd z30.d, p0/z, [%8] \n\t" // Real(beta). +" ld1rd z31.d, p0/z, [%8, 8] \n\t" // Imag(beta). +" \n\t" +" PREFETCH_ABNEXT: \n\t" +// " ldr x9, %[a_next] \n\t" +// " ldr x10, %[b_next] \n\t" +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %9, %9, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %10, %10, x16 \n\t" +#endif +" prfm PLDL1STRM, [%9] \n\t" +" prfm PLDL1STRM, [%9, 256*1] \n\t" +" prfm PLDL1STRM, [%10] \n\t" +" prfm PLDL1STRM, [%10, 256*1] \n\t" +" \n\t" +" WRITE_MEM: \n\t" +" \n\t" +GEMM_FMULCMPLX_COL7(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,p0,z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z28,z29) +" \n\t" +" UNIT_ALPHA: \n\t" +" mov x9, %2 \n\t" // C address for loading. +" \n\t" // C address for storing is %2 itself. +" cmp %3, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" +GEMM_CCMPLX_LOAD_COL7_C(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,p0,x9,%4) +GEMM_FMLACMPLX_COL7(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,p0,z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z30,z31) +GEMM_CCMPLX_STORE_COL7_C(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,p0,%2,%4) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" +" add %3, %3, %3 \n\t" // Skips passed to index is multiplied by 2, +" index z28.d, xzr, %3 \n\t" // s.t. 2*sizeof(double) = 2*8 = 16. +GEMM_CCMPLX_LOAD_COL7_G(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,p0,z28,x9,%4,x16) +GEMM_FMLACMPLX_COL7(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,p0,z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z30,z31) +GEMM_CCMPLX_STORE_COL7_G(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,p0,z28,%2,%4,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" b END_EXEC \n\t" +" \n\t" +" END_EXEC: \n\t" +" mov %11, #0 \n\t" // Return normal. +: "+r" (a), // %0 + "+r" (b), // %1 + "+r" (c), // %2 + "+r" (rs_c), // %3 + "+r" (cs_c), // %4 + "+r" (k_mker), // %5 + "+r" (k_left), // %6 + "+r" (alpha), // %7 + "+r" (beta), // %8 + "+r" (a_next), // %9 + "+r" (b_next), // %10 + "=r" (info) // %11 +: +: "x2","x3","x9","x16", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); + + GEMM_UKR_FLUSH_CT( z ); +} + + diff --git a/kernels/armsve/3/old/bli_gemm_armsve_asm_z2vx8_unindexed.c b/kernels/armsve/3/old/bli_gemm_armsve_asm_z2vx8_unindexed.c new file mode 100644 index 0000000000..4a910baace --- /dev/null +++ b/kernels/armsve/3/old/bli_gemm_armsve_asm_z2vx8_unindexed.c @@ -0,0 +1,297 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Double-precision composite instructions. +#include "armsve_asm_macros_dcomplex.h" + +// 2vx8 microkernels. +#include "armsve_asm_2vx8cmplx.h" + +void bli_zgemm_armsve_asm_2vx8_unindexed + ( + dim_t m, + dim_t n, + dim_t k, + dcomplex* restrict alpha, + dcomplex* restrict a, + dcomplex* restrict b, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k / 6; + uint64_t k_left = k % 6; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t info = 0; + + uint64_t mr = bli_vl_bytes_armsve() * 2 / 16; + GEMM_UKR_SETUP_CT( z, mr, 8, false ); + + __asm__ volatile ( +// " ldr x0, %[a] \n\t" +// " ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" incd x2, ALL, MUL #1 \n\t" // Column-skip of A. +" mov x3, #8 \n\t" // Row-skip of B. +" \n\t" +// " ldr x2, %[c] \n\t" +// " ldr x3, %[rs_c] \n\t" // Row-skip of C. +// " ldr x4, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %0, %0, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %1, %1, x16 \n\t" +" mov x16, 0x3 \n\t" // Tag C address. +" lsl x16, x16, #56 \n\t" +" orr %2, %2, x16 \n\t" +#endif +" \n\t" +" mov x16, #16 \n\t" // Multiply some address skips by sizeof(dcomplex). +" madd x2, x16, x2, xzr \n\t" // cs_a +" madd x3, x16, x3, xzr \n\t" // rs_b +" madd %4, x16, %4, xzr \n\t" // cs_c +" ptrue p0.d \n\t" +" \n\t" +// " ldr x5, %[k_mker] \n\t" // Number of loops. +// " ldr x6, %[k_left] \n\t" +" \n\t" +" LOAD_ABC: \n\t" +" cmp %5, #0 \n\t" // Don't preload if no microkernel there. +" b.eq END_CCOL_PRFM \n\t" +" \n\t" +" ld1rd z20.d, p0/z, [%1, 8*0] \n\t" // Load B's real & half of imaginary. +" ld1rd z21.d, p0/z, [%1, 8*2] \n\t" +" ld1rd z22.d, p0/z, [%1, 8*4] \n\t" +" ld1rd z23.d, p0/z, [%1, 8*6] \n\t" +" ld1rd z24.d, p0/z, [%1, 8*8] \n\t" +" ld1rd z25.d, p0/z, [%1, 8*10] \n\t" +" ld1rd z26.d, p0/z, [%1, 8*12] \n\t" +" ld1rd z27.d, p0/z, [%1, 8*14] \n\t" +" ld1rd z28.d, p0/z, [%1, 8*1] \n\t" +" ld1rd z29.d, p0/z, [%1, 8*3] \n\t" +" ld1rd z30.d, p0/z, [%1, 8*5] \n\t" +" ld1rd z31.d, p0/z, [%1, 8*7] \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z16,z17,p0,%0,x2) +" \n\t" +" CCOL_PRFM: \n\t" +" cmp %3, #1 \n\t" +" b.ne END_CCOL_PRFM \n\t" // Do not prefetch for generic C storage. +" mov x16, %2 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" END_CCOL_PRFM: \n\t" +" \n\t" +CLEAR_COL16(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15) +" \n\t" +" cmp %5, #0 \n\t" // If no 6-microkernel can be applied +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z18,z19,p0,%0,x2) +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z16,z17,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z16,z17,p0,%0,x2) +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z18,z19,p0,%0,x2) +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z16,z17,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z16,z17,p0,%0,x2) +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z18,z19,p0,%0,x2) +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z16,z17,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" \n\t" +" subs %5, %5, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z16,z17,p0,%0,x2) +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_3_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp %6, #0 \n\t" // End of execution. +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z16,z17,p0,%0,x2) +" ld1rd z20.d, p0/z, [%1, 8*0] \n\t" // Reload B's real & half of imaginary. +" ld1rd z21.d, p0/z, [%1, 8*2] \n\t" +" ld1rd z22.d, p0/z, [%1, 8*4] \n\t" +" ld1rd z23.d, p0/z, [%1, 8*6] \n\t" +" ld1rd z24.d, p0/z, [%1, 8*8] \n\t" +" ld1rd z25.d, p0/z, [%1, 8*10] \n\t" +" ld1rd z26.d, p0/z, [%1, 8*12] \n\t" +" ld1rd z27.d, p0/z, [%1, 8*14] \n\t" +" ld1rd z28.d, p0/z, [%1, 8*1] \n\t" +" ld1rd z29.d, p0/z, [%1, 8*3] \n\t" +" ld1rd z30.d, p0/z, [%1, 8*5] \n\t" +" ld1rd z31.d, p0/z, [%1, 8*7] \n\t" +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z16,z17,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" sub %6, %6, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +// " ldr x7, %[alpha] \n\t" // Load alpha & beta (address). +// " ldr x8, %[beta] \n\t" +" ld1rd z16.d, p0/z, [%7] \n\t" // Real(alpha). +" ld1rd z17.d, p0/z, [%7, 8] \n\t" // Imag(alpha). +" ld1rd z18.d, p0/z, [%8] \n\t" // Real(beta). +" ld1rd z19.d, p0/z, [%8, 8] \n\t" // Imag(beta). +" \n\t" +" PREFETCH_ABNEXT: \n\t" +// " ldr x9, %[a_next] \n\t" +// " ldr x10, %[b_next] \n\t" +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %9, %9, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %10, %10, x16 \n\t" +#endif +" prfm PLDL1STRM, [%9] \n\t" +" prfm PLDL1STRM, [%9, 256*1] \n\t" +" prfm PLDL1STRM, [%10] \n\t" +" prfm PLDL1STRM, [%10, 256*1] \n\t" +" \n\t" +" WRITE_MEM: \n\t" +" \n\t" +GEMM_FMULCMPLX_COL2(z20,z21,z22,z23,p0,z0 ,z1 ,z2 ,z3 ,z16,z17) +GEMM_FMULCMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z4 ,z5 ,z6 ,z7 ,z16,z17) +GEMM_FMULCMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z8 ,z9 ,z10,z11,z16,z17) +GEMM_FMULCMPLX_COL2(z8 ,z9 ,z10,z11,p0,z12,z13,z14,z15,z16,z17) +" \n\t" +" UNIT_ALPHA: \n\t" +" mov x9, %2 \n\t" // C address for loading. +" \n\t" // C address for storing is %2 itself. +" cmp %3, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" +GEMM_CCMPLX_LOAD_COL2_C(z12,z13,z14,z15,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z24,z25,z26,z27,p0,x9,%4) +GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z18,z19) +GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z24,z25,z26,z27,z18,z19) +GEMM_CCMPLX_STORE_COL2_C(z20,z21,z22,z23,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z0 ,z1 ,z2 ,z3 ,p0,%2,%4) +" \n\t" +GEMM_CCMPLX_LOAD_COL2_C(z12,z13,z14,z15,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z24,z25,z26,z27,p0,x9,%4) +GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z12,z13,z14,z15,z18,z19) +GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z24,z25,z26,z27,z18,z19) +GEMM_CCMPLX_STORE_COL2_C(z4 ,z5 ,z6 ,z7 ,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z8 ,z9 ,z10,z11,p0,%2,%4) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" +" add %3, %3, %3 \n\t" // Skips passed to index is multiplied by 2, +" index z16.d, xzr, %3 \n\t" // s.t. 2*sizeof(double) = 2*8 = 16. +GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z16,x9,%4,x16) +GEMM_CCMPLX_LOAD_COL2_G(z24,z25,z26,z27,p0,z16,x9,%4,x16) +GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z18,z19) +GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z24,z25,z26,z27,z18,z19) +GEMM_CCMPLX_STORE_COL2_G(z20,z21,z22,z23,p0,z16,%2,%4,x16) +GEMM_CCMPLX_STORE_COL2_G(z0 ,z1 ,z2 ,z3 ,p0,z16,%2,%4,x16) +" \n\t" +GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z16,x9,%4,x16) +GEMM_CCMPLX_LOAD_COL2_G(z24,z25,z26,z27,p0,z16,x9,%4,x16) +GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z12,z13,z14,z15,z18,z19) +GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z24,z25,z26,z27,z18,z19) +GEMM_CCMPLX_STORE_COL2_G(z4 ,z5 ,z6 ,z7 ,p0,z16,%2,%4,x16) +GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z16,%2,%4,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" b END_EXEC \n\t" +" \n\t" +" END_EXEC: \n\t" +" mov %11, #0 \n\t" // Return normal. +: "+r" (a), // %0 + "+r" (b), // %1 + "+r" (c), // %2 + "+r" (rs_c), // %3 + "+r" (cs_c), // %4 + "+r" (k_mker), // %5 + "+r" (k_left), // %6 + "+r" (alpha), // %7 + "+r" (beta), // %8 + "+r" (a_next), // %9 + "+r" (b_next), // %10 + "=r" (info) // %11 +: +: "x2","x3","x9","x16", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); + + GEMM_UKR_FLUSH_CT( z ); +} + diff --git a/kernels/armsve/3/old/sup/bli_gemmsup_armsve_ref.c b/kernels/armsve/3/old/sup/bli_gemmsup_armsve_ref.c new file mode 100644 index 0000000000..ff3a35e7a6 --- /dev/null +++ b/kernels/armsve/3/old/sup/bli_gemmsup_armsve_ref.c @@ -0,0 +1,450 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// Separate instantiation for ArmSVE reference kernels. +// Temporary workaround. Will be removed after upstream has switched to a better way +// of exposing gemmsup interface. + +// +// -- Row storage case --------------------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* NOTE: This microkernel can actually handle arbitrarily large + values of m, n, and k. */ \ +\ + if ( bli_is_noconj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_noconj( conja ) && bli_is_conj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,axpyjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_conj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dotjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_conj( conja ) && bli_is_conj( conjb ) ) */ \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* Conjugate the result to simulate conj(a^T) * conj(b). */ \ + PASTEMAC(ch,conjs)( ab ); \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_r, _armsve, _ref2 ) + +// +// -- Column storage case ------------------------------------------------------ +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* NOTE: This microkernel can actually handle arbitrarily large + values of m, n, and k. */ \ +\ + if ( bli_is_noconj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_noconj( conja ) && bli_is_conj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,axpyjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_conj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dotjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_conj( conja ) && bli_is_conj( conjb ) ) */ \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* Conjugate the result to simulate conj(a^T) * conj(b). */ \ + PASTEMAC(ch,conjs)( ab ); \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_c, _armsve, _ref2 ) + diff --git a/kernels/armsve/3/old/sup/bli_gemmsup_cv_armsve_asm_d2vx10_unindexed.c b/kernels/armsve/3/old/sup/bli_gemmsup_cv_armsve_asm_d2vx10_unindexed.c new file mode 100644 index 0000000000..3341b63d00 --- /dev/null +++ b/kernels/armsve/3/old/sup/bli_gemmsup_cv_armsve_asm_d2vx10_unindexed.c @@ -0,0 +1,528 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" +#include + +// Double-precision composite instructions. +#include "../armsve_asm_macros_double.h" + +// 2vx10 microkernels. +#include "../armsve_asm_2vx10.h" + +// Prototype reference kernel. +GEMMSUP_KER_PROT( double, d, gemmsup_c_armsve_ref2 ) + +void __attribute__ ((noinline,optimize(0))) bli_dgemmsup_cv_armsve_2vx10_unindexed + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + static int called = 0; + if ( !called ) + { + fprintf(stderr, "rv called.\n"); + called = 1; + } + // c*c requires A to be stored in columns. + assert( rs_a0 == 1 ); + + dim_t n0_mker = n0 / 10; + dim_t n0_left = n0 % 10; + + if ( n0_left ) + { + // A[:, ::] + // B[::, n0_mker*10:n0] + // C[: , n0_mker*10:n0] + double *ai = a; + double *bi = b + n0_mker * 10 * cs_b0; + double *ci = c + n0_mker * 10 * cs_c0; + bli_dgemmsup_c_armsve_ref2 + ( + conja, conjb, + m0, n0_left, k0, + alpha, + ai, rs_a0, cs_a0, + bi, rs_b0, cs_b0, + beta, + ci, rs_c0, cs_c0, + data, + cntx + ); + } + // Return if it's a pure edge case. + if ( !n0_mker ) + return; + + // Determine VL. + uint64_t vlen2; + __asm__ ( + " mov x0, xzr \n\t" + " incd x0, ALL, MUL #2 \n\t" + " mov %[vlen2], x0 \n\t" + : [vlen2] "=r" (vlen2) + : + : "x0" + ); + + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // uint64_t rs_a = 1; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t n_mker = n0_mker; + + dim_t m0_mker = m0 / vlen2; + dim_t m0_left = m0 % vlen2; + if ( m0_left ) + { + // Edge case on A side can be handled with one more (predicated) loop. + m0_mker++; + } else + m0_left = vlen2; + // uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_b = bli_auxinfo_ps_b( data ); + + for ( dim_t im0_mker = 0; im0_mker < m0_mker; ++im0_mker ) + { + uint64_t m_curr = vlen2; + if ( im0_mker == m0_mker - 1 ) + { + // Last m-loop. Maybe unnecessary. + m_curr = m0_left; + } + double *ai = a + im0_mker * vlen2 * rs_a0; + double *bi = b; + double *ci = c + im0_mker * vlen2 * rs_c0; + + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + __asm__ volatile ( +" ldr x0, %[bi] \n\t" +" ldr x1, %[rs_b] \n\t" // Row-skip of B. +" ldr x2, %[cs_b] \n\t" // Column-skip of B (element skip of B[l, :]). +" ldr x3, %[ps_b] \n\t" // Panel-skip (10*k) of B. +" ldr x4, %[cs_a] \n\t" // Column-Skip of A. +" \n\t" // Element skip of A[:, l] is guaranteed to be 1. +" ldr x5, %[ci] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag C address. +" lsl x16, x16, #56 \n\t" +" orr x5, x5, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr x0, x0, x16 \n\t" +#endif +" \n\t" +" mov x8, #8 \n\t" // Multiply some address skips by sizeof(double). +" madd x1, x8, x1, xzr \n\t" // rs_b +" madd x2, x8, x2, xzr \n\t" // cs_b +" madd x3, x8, x3, xzr \n\t" // ps_b +" madd x4, x8, x4, xzr \n\t" // cs_a +" madd x7, x8, x7, xzr \n\t" // cs_c +" mov x8, #4 \n\t" +" madd x15, x8, x4, xzr \n\t" // Logical K=4 microkernel skip for A. +" \n\t" +#ifdef _A64FX +" mov x16, 0x20 \n\t" // Higher 6bit for Control#2: +" lsl x16, x16, #58 \n\t" // Valid|Strong|Strong|NoAlloc|Load|Strong +" orr x16, x16, x4 \n\t" // Stride. +" msr S3_3_C11_C6_2, x16 \n\t" // Write system register. +#endif +" \n\t" +" ldr x8, %[m_curr] \n\t" // Size of first dimension. +" mov x9, xzr \n\t" +" incd x9 \n\t" +" ptrue p0.d \n\t" +" whilelo p1.d, xzr, x8 \n\t" +" whilelo p2.d, x9, x8 \n\t" +" \n\t" +" ldr x8, %[n_mker] \n\t" // Number of N-loops. +" \n\t" +" ldr x20, %[ai] \n\t" // Parameters to be reloaded +" ldr x21, %[k_mker] \n\t" // within each millikernel loop. +" ldr x22, %[k_left] \n\t" +" ldr x23, %[alpha] \n\t" +" ldr x24, %[beta] \n\t" +" ldr x25, %[a_next] \n\t" +" ldr x26, %[b_next] \n\t" +" ldr x23, [x23] \n\t" // Directly load alpha and beta. +" ldr x24, [x24] \n\t" +" \n\t" +" MILLIKER_MLOOP: \n\t" +" \n\t" +" mov x11, x0 \n\t" // B's address. +// " ldr x10, %[ai] \n\t" // A's address. +" mov x10, x20 \n\t" +// " ldr x12, %[k_mker] \n\t" +" mov x12, x21 \n\t" +// " ldr x13, %[k_left] \n\t" +" mov x13, x22 \n\t" +#ifdef _A64FX +" mov x16, 0x3 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr x10, x10, x16 \n\t" +" mov x16, 0xa \n\t" // Control#2 for A address. +" lsl x16, x16, #60 \n\t" +" orr x10, x10, x16 \n\t" +#endif +" \n\t" +" cmp x12, #0 \n\t" // Don't preload if no microkernel there. +" b.eq END_CCOL_PRFM \n\t" +" \n\t" +" mov x14, x11 \n\t" +" ld1rd z20.d, p0/z, [x14] \n\t" // Load 8/10 of first B row. +" add x14, x14, x2 \n\t" +" ld1rd z21.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z22.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z23.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z24.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z25.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z26.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z27.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" prfm PLDL1KEEP, [x14] \n\t" // And prefetch the 2/10 left. +" add x14, x14, x2 \n\t" +" prfm PLDL1KEEP, [x14] \n\t" +" sub x14, x14, x2 \n\t" // Restore x14 to load edge. +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p1,p2,x10) +" add x16, x10, x4 \n\t" +" prfm PLDL1STRM, [x16] \n\t" // Prefetch 3/4 of A. +" add x16, x10, x4 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x10, x4 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" \n\t" +" CCOL_PRFM: \n\t" +" cmp x6, #1 \n\t" +" b.ne END_CCOL_PRFM \n\t" // Do not prefetch for generic C storage. +" mov x16, x5 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" END_CCOL_PRFM: \n\t" +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp x12, #0 \n\t" // If no 4-microkernel can be applied +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" +" \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_C(z30,z31,p1,p2,x10,x15,x4,x16,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_G_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x11,x14,x1,x2) +" \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_C(z28,z29,p1,p2,x10,x15,x4,x16,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_G_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x11,x14,x1,x2) +" \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_C(z30,z31,p1,p2,x10,x15,x4,x16,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_G_3(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x11,x14,x1,x2) +" \n\t" +" subs x12, x12, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_C(z28,z29,p1,p2,x10,x15,x4,x16,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_G_4(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x11,x14,x1,x2) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX10_MKER_LOOP_PLAIN_G_4_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x11,x14,x1,x2) +" add x10, x10, x4 \n\t" // Forward A to fill the blank. +" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp x13, #0 \n\t" // End of execution. +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p1,p2,x10) +" mov x14, x11 \n\t" +" ld1rd z20.d, p0/z, [x14] \n\t" // Load 10/10 B. +" add x14, x14, x2 \n\t" +" ld1rd z21.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z22.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z23.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z24.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z25.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z26.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z27.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z28.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z29.d, p0/z, [x14] \n\t" +GEMM_FMLA2(z0,z1,p0,z30,z31,z20) +GEMM_FMLA2(z2,z3,p0,z30,z31,z21) +GEMM_FMLA2(z4,z5,p0,z30,z31,z22) +GEMM_FMLA2(z6,z7,p0,z30,z31,z23) +GEMM_FMLA2(z8,z9,p0,z30,z31,z24) +GEMM_FMLA2(z10,z11,p0,z30,z31,z25) +GEMM_FMLA2(z12,z13,p0,z30,z31,z26) +GEMM_FMLA2(z14,z15,p0,z30,z31,z27) +GEMM_FMLA2(z16,z17,p0,z30,z31,z28) +GEMM_FMLA2(z18,z19,p0,z30,z31,z29) +" add x10, x10, x4 \n\t" // Forward A. +" add x11, x11, x1 \n\t" // Forward B. +" sub x13, x13, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +// " ldr x10, %[ai] \n\t" +" mov x10, x20 \n\t" +" add x11, x0, x3 \n\t" +" dup z30.d, x23 \n\t" // Broadcast alpha & beta into vectors. +" dup z31.d, x24 \n\t" +" \n\t" +" cmp x8, #1 \n\t" +" b.eq PREFETCH_ABNEXT \n\t" +" prfm PLDL1STRM, [x10] \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" b WRITE_MEM \n\t" +" \n\t" +" PREFETCH_ABNEXT: \n\t" +// " ldr x1, %[a_next] \n\t" // Final Millikernel loop, x1 and x2 not needed. +" mov x1, x25 \n\t" +// " ldr x2, %[b_next] \n\t" +" mov x2, x26 \n\t" +" prfm PLDL2KEEP, [x1] \n\t" +" prfm PLDL2KEEP, [x1, 256*1] \n\t" +" prfm PLDL2KEEP, [x1, 256*2] \n\t" +" prfm PLDL2KEEP, [x1, 256*3] \n\t" +" prfm PLDL2KEEP, [x1, 256*4] \n\t" +" prfm PLDL2KEEP, [x1, 256*5] \n\t" +" prfm PLDL2KEEP, [x1, 256*6] \n\t" +" prfm PLDL2KEEP, [x1, 256*7] \n\t" +" prfm PLDL2KEEP, [x1, 256*8] \n\t" +" prfm PLDL2KEEP, [x1, 256*9] \n\t" +" prfm PLDL2KEEP, [x1, 256*10] \n\t" +" prfm PLDL2KEEP, [x1, 256*11] \n\t" +" prfm PLDL2KEEP, [x1, 256*12] \n\t" +" prfm PLDL2KEEP, [x1, 256*13] \n\t" +" prfm PLDL2KEEP, [x1, 256*14] \n\t" +" prfm PLDL2KEEP, [x1, 256*15] \n\t" +" prfm PLDL2KEEP, [x2] \n\t" +" prfm PLDL2KEEP, [x2, 256*1] \n\t" +" prfm PLDL2KEEP, [x2, 256*2] \n\t" +" prfm PLDL2KEEP, [x2, 256*3] \n\t" +" prfm PLDL2KEEP, [x2, 256*4] \n\t" +" prfm PLDL2KEEP, [x2, 256*5] \n\t" +" prfm PLDL2KEEP, [x2, 256*6] \n\t" +" prfm PLDL2KEEP, [x2, 256*7] \n\t" +" prfm PLDL2KEEP, [x2, 256*8] \n\t" +" prfm PLDL2KEEP, [x2, 256*9] \n\t" +" \n\t" +" WRITE_MEM: \n\t" +" \n\t" +" fmov d28, #1.0 \n\t" +" fmov x16, d28 \n\t" +" cmp x16, x23 \n\t" +" b.eq UNIT_ALPHA \n\t" +" \n\t" +SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19,z30) +" \n\t" +" UNIT_ALPHA: \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x6, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-29]. +" mov x13, xzr \n\t" // C-column's physical 1-vector skip. +" incb x13 \n\t" +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,x9,x7) +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,x9,x7) +" \n\t" +GEMM_C_STORE_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,x5,x7) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,x5,x7) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-30] - Z30 as index. +" mov x12, xzr \n\t" +" incb x12 \n\t" +" madd x13, x12, x6, xzr \n\t" // C-column's logical 1-vector skip. +" index z30.d, xzr, x6 \n\t" // Skips passed to index is not multiplied by 8. +GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p1,p2,x9,x7,x13,x16) +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p1,p2,x9,x7,x13,x16) +" \n\t" +GEMM_C_STORE_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p1,p2,x5,x7,x13,x16) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p1,p2,x5,x7,x13,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" subs x8, x8, #1 \n\t" +" b.eq END_EXEC \n\t" +" \n\t" // Address of C already forwarded to next column. +" add x0, x0, x3 \n\t" // Forward B's base address to the next logic panel. +" b MILLIKER_MLOOP \n\t" +" \n\t" +" END_ERROR: \n\t" +" mov x0, #1 \n\t" // Return error. +" END_EXEC: \n\t" +" mov x0, #0 \n\t" // Return normal. +: +: [bi] "m" (bi), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b] "m" (ps_b), + [cs_a] "m" (cs_a), + [ci] "m" (ci), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [m_curr] "m" (m_curr), + [n_mker] "m" (n_mker), + [ai] "m" (ai), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x10","x11","x12","x13","x14","x15","x16","x17", + "x20","x21","x22","x23","x24","x25","x26", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); + } +} + +void bli_dgemmsup_rv_armsve_10x2v_unindexed + ( + conj_t conjat, + conj_t conjbt, + dim_t m0t, + dim_t n0t, + dim_t k0, + double* restrict alpha, + double* restrict at, inc_t rs_at0, inc_t cs_at0, + double* restrict bt, inc_t rs_bt0, inc_t cs_bt0, + double* restrict beta, + double* restrict ct, inc_t rs_ct0, inc_t cs_ct0, + auxinfo_t* restrict datat, + cntx_t* restrict cntx + ) +{ + auxinfo_t data; + bli_auxinfo_set_next_a( bli_auxinfo_next_b( datat ), &data ); + bli_auxinfo_set_next_b( bli_auxinfo_next_a( datat ), &data ); + bli_auxinfo_set_ps_a( bli_auxinfo_ps_b( datat ), &data ); + bli_auxinfo_set_ps_b( bli_auxinfo_ps_a( datat ), &data ); + bli_dgemmsup_cv_armsve_2vx10_unindexed + ( + conjbt, conjat, + n0t, m0t, k0, + alpha, + bt, cs_bt0, rs_bt0, + at, cs_at0, rs_at0, + beta, + ct, cs_ct0, rs_ct0, + &data, + cntx + ); +} + diff --git a/kernels/armsve/3/old/sup/bli_gemmsup_rv_armsve_asm_d2vx10_unindexed.c b/kernels/armsve/3/old/sup/bli_gemmsup_rv_armsve_asm_d2vx10_unindexed.c new file mode 100644 index 0000000000..6bcea73f5d --- /dev/null +++ b/kernels/armsve/3/old/sup/bli_gemmsup_rv_armsve_asm_d2vx10_unindexed.c @@ -0,0 +1,412 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" +#include + +// Double-precision composite instructions. +#include "../armsve_asm_macros_double.h" + +// 2vx10 microkernels. +#include "../armsve_asm_2vx10.h" + +// Prototype reference kernel. +GEMMSUP_KER_PROT( double, d, gemmsup_r_armsve_ref2 ) + +void __attribute__ ((optimize(0))) bli_dgemmsup_rv_armsve_2vx10_unindexed + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + static int called = 0; + if ( !called ) + { + fprintf(stderr, "rv called.\n"); + called = 1; + } + // r*r requires B to be stored in rows. + assert(cs_b0 == 1); + + dim_t n0_mker = n0 / 10; + dim_t n0_left = n0 % 10; + + if ( n0_left ) + { + // A[:, ::] + // B[::, n0_mker*10:n0] + // C[: , n0_mker*10:n0] + double *ai = a; + double *bi = b + n0_mker * 10 * cs_b0; + double *ci = c + n0_mker * 10 * cs_c0; + bli_dgemmsup_r_armsve_ref2 + ( + conja, conjb, + m0, n0_left, k0, + alpha, + ai, rs_a0, cs_a0, + bi, rs_b0, cs_b0, + beta, + ci, rs_c0, cs_c0, + data, + cntx + ); + } + // Return if it's a pure edge case. + if ( !n0_mker ) + return; + + // Determine VL. + uint64_t vlen2; + __asm__ ( + " mov x0, xzr \n\t" + " incd x0, ALL, MUL #2 \n\t" + " mov %[vlen2], x0 \n\t" + : [vlen2] "=r" (vlen2) + : + : "x0" + ); + + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + // uint64_t cs_b = 1; + + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t m_mker = m0 / vlen2; + uint64_t m_left = m0 % vlen2; + if ( m_left ) + { + // Edge case on A side can be handled with one more (predicated) loop. + m_mker++; + } else + m_left = vlen2; + uint64_t ps_a = bli_auxinfo_ps_a( data ); + // uint64_t ps_b = bli_auxinfo_ps_b( data ); + + for ( dim_t in0_mker = 0; in0_mker < n0_mker; ++in0_mker ) + { + double *ai = a; + double *bi = b + in0_mker * 10 * cs_b0; + double *ci = c + in0_mker * 10 * cs_c0; + + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + __asm__ volatile ( +" ldr x0, %[ai] \n\t" +" ldr x1, %[rs_a] \n\t" // Row-skip of A (element skip of A[:, l]). +" ldr x2, %[cs_a] \n\t" // Column-skip of A. +" ldr x3, %[ps_a] \n\t" // Panel-skip (vlen2*k) of A. +" ldr x4, %[rs_b] \n\t" // Row-Skip of B. +" \n\t" // Element skip of B[l, :] is guaranteed to be 1. +" ldr x5, %[ci] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag C address. +" lsl x16, x16, #56 \n\t" +" orr x5, x5, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr x0, x0, x16 \n\t" +#endif +" \n\t" +" mov x8, #8 \n\t" // Multiply some address skips by sizeof(double). +" madd x2, x8, x2, xzr \n\t" // cs_a +" madd x3, x8, x3, xzr \n\t" // ps_a +" madd x4, x8, x4, xzr \n\t" // rs_b +" madd x7, x8, x7, xzr \n\t" // cs_c +" mov x8, xzr \n\t" +" incb x8 \n\t" +" madd x14, x8, x1, xzr \n\t" // A-column's logical 1-vector skip. +" mov x8, #4 \n\t" +" madd x15, x8, x2, xzr \n\t" // Logical K=4 microkernel skip for A. +// " mov x8, #4 \n\t" +// " madd x17, x8, x4, xzr \n\t" // Logical K=4 microkernel skip for B. +" \n\t" +" ldr x8, %[m_mker] \n\t" // Number of M-loops. +" ptrue p0.d \n\t" +" ptrue p1.d \n\t" +" ptrue p2.d \n\t" +" \n\t" +" MILLIKER_MLOOP: \n\t" +" \n\t" +" cmp x8, #1 \n\t" +" b.ne UKER_BEGIN \n\t" +" \n\t" +" ldr x10, %[m_left] \n\t" // Final (incomplete) millikernel loop. +" mov x11, xzr \n\t" +" incd x11 \n\t" +" whilelo p1.d, xzr, x10 \n\t" // Overwrite p1/p2. +" whilelo p2.d, x11, x10 \n\t" +" \n\t" +" UKER_BEGIN: \n\t" +" mov x10, x0 \n\t" // A's address. +" ldr x11, %[bi] \n\t" // B's address. +" ldr x12, %[k_mker] \n\t" +" ldr x13, %[k_left] \n\t" +#ifdef _A64FX +" mov x16, 0x3 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr x11, x11, x16 \n\t" +#endif +" \n\t" +" mov x16, x11 \n\t" // Prefetch first kernel of B. +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" \n\t" +" ld1rd z20.d, p0/z, [x11] \n\t" // (Partial) first B row. +" ld1rd z21.d, p0/z, [x11, #8] \n\t" +" ld1rd z22.d, p0/z, [x11, #16] \n\t" +" ld1rd z23.d, p0/z, [x11, #24] \n\t" +" ld1rd z24.d, p0/z, [x11, #32] \n\t" +" ld1rd z25.d, p0/z, [x11, #40] \n\t" +" ld1rd z26.d, p0/z, [x11, #48] \n\t" +" ld1rd z27.d, p0/z, [x11, #56] \n\t" +" \n\t" +" index z29.d, xzr, x1 \n\t" // First A column. +" \n\t" // Skips passed to index is not multiplied by 8. +GEMM_ACOL_GATHER_LOAD(z28,z29,z29,p1,p2,x10,x14,x16) +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp x12, #0 \n\t" // If no 4-microkernel can be applied +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" // Unroll the 4-loop. +" \n\t" +" index z31.d, xzr, x1 \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_G(z30,z31,z31,p1,p2,x10,x15,x3,x2,x14,x16,noprfm,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x11,x4) +" \n\t" +" index z29.d, xzr, x1 \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_G(z28,z29,z29,p1,p2,x10,x15,x3,x2,x14,x16,noprfm,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x11,x4) +" \n\t" +" index z31.d, xzr, x1 \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_G(z30,z31,z31,p1,p2,x10,x15,x3,x2,x14,x16,noprfm,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x11,x4) +" \n\t" +" subs x12, x12, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +" index z29.d, xzr, x1 \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_G(z28,z29,z29,p1,p2,x10,x15,x3,x2,x14,x16,noprfm,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_C_4(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x11,x4) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX10_MKER_LOOP_PLAIN_C_4_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x11,x4) +" add x10, x10, x2 \n\t" // Forward A to fill the blank. +" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp x13, #0 \n\t" +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +" index z31.d, xzr, x1 \n\t" +GEMM_ACOL_GATHER_LOAD(z30,z31,z31,p1,p2,x10,x14,x16) +" ld1rd z20.d, p0/z, [x11] \n\t" +" ld1rd z21.d, p0/z, [x11, #8] \n\t" +" ld1rd z22.d, p0/z, [x11, #16] \n\t" +" ld1rd z23.d, p0/z, [x11, #24] \n\t" +" ld1rd z24.d, p0/z, [x11, #32] \n\t" +" ld1rd z25.d, p0/z, [x11, #40] \n\t" +" ld1rd z26.d, p0/z, [x11, #48] \n\t" +" ld1rd z27.d, p0/z, [x11, #56] \n\t" +" ld1rd z28.d, p0/z, [x11, #64] \n\t" +" ld1rd z29.d, p0/z, [x11, #72] \n\t" +GEMM_FMLA2(z0,z1,p0,z30,z31,z20) +GEMM_FMLA2(z2,z3,p0,z30,z31,z21) +GEMM_FMLA2(z4,z5,p0,z30,z31,z22) +GEMM_FMLA2(z6,z7,p0,z30,z31,z23) +GEMM_FMLA2(z8,z9,p0,z30,z31,z24) +GEMM_FMLA2(z10,z11,p0,z30,z31,z25) +GEMM_FMLA2(z12,z13,p0,z30,z31,z26) +GEMM_FMLA2(z14,z15,p0,z30,z31,z27) +GEMM_FMLA2(z16,z17,p0,z30,z31,z28) +GEMM_FMLA2(z18,z19,p0,z30,z31,z29) +" add x10, x10, x2 \n\t" // Forward A. +" add x11, x11, x4 \n\t" // Forward B. +" sub x13, x13, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +" ldr x11, %[bi] \n\t" +" ldr x12, %[alpha] \n\t" // Load alpha & beta. +" ldr x13, %[beta] \n\t" +" ld1rd z30.d, p0/z, [x12] \n\t" +" ld1rd z31.d, p0/z, [x13] \n\t" +" ldr x12, [x12] \n\t" +" \n\t" +" cmp x8, #1 \n\t" +" b.eq PREFETCH_ABNEXT \n\t" +" prfm PLDL2STRM, [x11] \n\t" +" b WRITE_MEM \n\t" +" \n\t" +" PREFETCH_ABNEXT: \n\t" +" ldr x1, %[a_next] \n\t" // Final Millikernel loop, x1 and x2 not needed. +" ldr x2, %[b_next] \n\t" +" prfm PLDL2KEEP, [x1] \n\t" +" prfm PLDL2KEEP, [x1, 256*1] \n\t" +" prfm PLDL2KEEP, [x1, 256*2] \n\t" +" prfm PLDL2KEEP, [x1, 256*3] \n\t" +" prfm PLDL2KEEP, [x1, 256*4] \n\t" +" prfm PLDL2KEEP, [x1, 256*5] \n\t" +" prfm PLDL2KEEP, [x1, 256*6] \n\t" +" prfm PLDL2KEEP, [x1, 256*7] \n\t" +" prfm PLDL2KEEP, [x1, 256*8] \n\t" +" prfm PLDL2KEEP, [x1, 256*9] \n\t" +" prfm PLDL2KEEP, [x1, 256*10] \n\t" +" prfm PLDL2KEEP, [x1, 256*11] \n\t" +" prfm PLDL2KEEP, [x1, 256*12] \n\t" +" prfm PLDL2KEEP, [x1, 256*13] \n\t" +" prfm PLDL2KEEP, [x1, 256*14] \n\t" +" prfm PLDL2KEEP, [x1, 256*15] \n\t" +" prfm PLDL2KEEP, [x2] \n\t" +" prfm PLDL2KEEP, [x2, 256*1] \n\t" +" prfm PLDL2KEEP, [x2, 256*2] \n\t" +" prfm PLDL2KEEP, [x2, 256*3] \n\t" +" prfm PLDL2KEEP, [x2, 256*4] \n\t" +" prfm PLDL2KEEP, [x2, 256*5] \n\t" +" prfm PLDL2KEEP, [x2, 256*6] \n\t" +" prfm PLDL2KEEP, [x2, 256*7] \n\t" +" prfm PLDL2KEEP, [x2, 256*8] \n\t" +" prfm PLDL2KEEP, [x2, 256*9] \n\t" +" \n\t" +" WRITE_MEM: \n\t" +" \n\t" +" fmov d28, #1.0 \n\t" +" fmov x16, d28 \n\t" +" cmp x16, x12 \n\t" +" b.eq UNIT_ALPHA \n\t" +" \n\t" +SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19,z30) +" \n\t" +" UNIT_ALPHA: \n\t" +" mov x9, x5 \n\t" // C address for loading. +" mov x10, x5 \n\t" // C address for storing. +" cmp x6, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-29]. +" mov x13, xzr \n\t" // C-column's physical 1-vector skip. +" incb x13 \n\t" +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,x9,x7) +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,x9,x7) +" \n\t" +GEMM_C_STORE_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,x10,x7) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,x10,x7) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-30] - Z30 as index. +" mov x12, xzr \n\t" +" incb x12 \n\t" +" madd x13, x12, x6, xzr \n\t" // C-column's logical 1-vector skip. +" index z30.d, xzr, x6 \n\t" // Skips passed to index is not multiplied by 8. +GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p1,p2,x9,x7,x13,x16) +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p1,p2,x9,x7,x13,x16) +" \n\t" +GEMM_C_STORE_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p1,p2,x10,x7,x13,x16) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p1,p2,x10,x7,x13,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" subs x8, x8, #1 \n\t" +" b.eq END_EXEC \n\t" +" \n\t" +" add x0, x0, x3 \n\t" // Forward A's base address to the next logic panel. +" add x5, x5, x13 \n\t" // Forward C's base address to the next logic panel. +" add x5, x5, x13 \n\t" +" b MILLIKER_MLOOP \n\t" +" \n\t" +" END_ERROR: \n\t" +" mov x0, #1 \n\t" // Return error. +" END_EXEC: \n\t" +" mov x0, #0 \n\t" // Return normal. +: +: [ai] "m" (ai), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a] "m" (ps_a), + [rs_b] "m" (rs_b), + [ci] "m" (ci), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [m_mker] "m" (m_mker), + [m_left] "m" (m_left), + [bi] "m" (bi), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x10","x11","x12","x13","x14","x15","x16",//"x17", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); + } +} + diff --git a/kernels/armsve/bli_kernels_armsve.h b/kernels/armsve/bli_kernels_armsve.h new file mode 100644 index 0000000000..00e1f04557 --- /dev/null +++ b/kernels/armsve/bli_kernels_armsve.h @@ -0,0 +1,53 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "./3/bli_armsve_utils.h" + +// GEMM_UKR_PROT( double, d, gemm_armsve256_asm_8x8 ) +GEMM_UKR_PROT( double, d, gemm_armsve_asm_2vx10_unindexed ) +GEMM_UKR_PROT( float, s, gemm_armsve_asm_2vx10_unindexed ) +GEMM_UKR_PROT( scomplex, c, gemm_armsve_asm_2vx10_unindexed ) +GEMM_UKR_PROT( dcomplex, z, gemm_armsve_asm_2vx10_unindexed ) +// GEMM_UKR_PROT( dcomplex, z, gemm_armsve_asm_2vx8_unindexed ) +// GEMM_UKR_PROT( dcomplex, z, gemm_armsve_asm_2vx7_unindexed ) +//GEMMSUP_KER_PROT( double, d, gemmsup_rv_armsve_2vx10_unindexed ) +//GEMMSUP_KER_PROT( double, d, gemmsup_cv_armsve_2vx10_unindexed ) +//GEMMSUP_KER_PROT( double, d, gemmsup_rv_armsve_10x2v_unindexed ) + +// Use SVE intrinsics only for referred cases. +#if !defined(BLIS_FAMILY_A64FX) +PACKM_KER_PROT( double, d, packm_armsve256_int_8xk ) +PACKM_KER_PROT( double, d, packm_armsve512_int_12xk ) +#endif +PACKM_KER_PROT( double, d, packm_armsve512_asm_16xk ) +PACKM_KER_PROT( double, d, packm_armsve512_asm_10xk ) diff --git a/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c b/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c index b526cd0951..c248285c38 100644 --- a/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c +++ b/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c @@ -48,23 +48,23 @@ void bli_sgemm_armv7a_ker_4x4 void bli_sgemm_armv7a_asm_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, float* restrict beta, - float* restrict c, inc_t rs_c0, inc_t cs_c0, + float* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint32_t k = k0; - uint32_t rs_c = rs_c0; - uint32_t cs_c = cs_c0; - + GEMM_UKR_SETUP_CT_ANY( s, 4, 4, false ); bli_sgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_FLUSH_CT( s ); } @@ -83,23 +83,23 @@ void bli_dgemm_armv7a_ker_4x4 void bli_dgemm_armv7a_asm_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint32_t k = k0; - uint32_t rs_c = rs_c0; - uint32_t cs_c = cs_c0; - + GEMM_UKR_SETUP_CT_ANY( d, 4, 4, false ); bli_dgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_FLUSH_CT( d ); } @@ -118,23 +118,23 @@ void bli_cgemm_armv7a_ker_2x2 void bli_cgemm_armv7a_asm_2x2 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, scomplex* restrict beta, - scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + scomplex* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint32_t k = k0; - uint32_t rs_c = rs_c0; - uint32_t cs_c = cs_c0; - + GEMM_UKR_SETUP_CT_ANY( c, 2, 2, false ); bli_cgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_FLUSH_CT( c ); } @@ -153,22 +153,22 @@ void bli_zgemm_armv7a_ker_2x2 void bli_zgemm_armv7a_asm_2x2 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, dcomplex* restrict beta, - dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + dcomplex* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint32_t k = k0; - uint32_t rs_c = rs_c0; - uint32_t cs_c = cs_c0; - + GEMM_UKR_SETUP_CT_ANY( z, 2, 2, false ); bli_zgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c b/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c index 528703f0b5..06f36a3463 100644 --- a/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c +++ b/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c @@ -37,7 +37,9 @@ void bli_sgemm_armv7a_int_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -49,12 +51,14 @@ void bli_sgemm_armv7a_int_4x4 { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint32_t k_iter = k0 / 4; - uint32_t k_left = k0 % 4; + uint32_t k_iter = k / 4; + uint32_t k_left = k % 4; uint32_t rs_c = rs_c0; uint32_t cs_c = cs_c0; uint32_t i; + GEMM_UKR_SETUP_CT( s, 4, 4, false ); + void* a_next = bli_auxinfo_next_a( data ); void* b_next = bli_auxinfo_next_b( data ); @@ -80,46 +84,26 @@ void bli_sgemm_armv7a_int_4x4 // Vector for column 3 float32x4_t cv3; - if( rs_c == 1 ) + if ( *beta != 0.0F ) { // Load column 0 - cv0 = vld1q_f32( c + 0*rs_c + 0*cs_c ); - + cv0 = vld1q_f32( c + 0*cs_c ); + // Load column 1 - cv1 = vld1q_f32( c + 0*rs_c + 1*cs_c ); - + cv1 = vld1q_f32( c + 1*cs_c ); + // Load column 2 - cv2 = vld1q_f32( c + 0*rs_c + 2*cs_c ); - + cv2 = vld1q_f32( c + 2*cs_c ); + // Load column 3 - cv3 = vld1q_f32( c + 0*rs_c + 3*cs_c ); - } + cv3 = vld1q_f32( c + 3*cs_c ); + } else { - // Load column 0 - cv0 = vld1q_lane_f32( c + 0*rs_c + 0*cs_c, cv0, 0); - cv0 = vld1q_lane_f32( c + 1*rs_c + 0*cs_c, cv0, 1); - cv0 = vld1q_lane_f32( c + 2*rs_c + 0*cs_c, cv0, 2); - cv0 = vld1q_lane_f32( c + 3*rs_c + 0*cs_c, cv0, 3); - - // Load column 1 - cv1 = vld1q_lane_f32( c + 0*rs_c + 1*cs_c, cv1, 0); - cv1 = vld1q_lane_f32( c + 1*rs_c + 1*cs_c, cv1, 1); - cv1 = vld1q_lane_f32( c + 2*rs_c + 1*cs_c, cv1, 2); - cv1 = vld1q_lane_f32( c + 3*rs_c + 1*cs_c, cv1, 3); - - // Load column 2 - cv2 = vld1q_lane_f32( c + 0*rs_c + 2*cs_c, cv2, 0); - cv2 = vld1q_lane_f32( c + 1*rs_c + 2*cs_c, cv2, 1); - cv2 = vld1q_lane_f32( c + 2*rs_c + 2*cs_c, cv2, 2); - cv2 = vld1q_lane_f32( c + 3*rs_c + 2*cs_c, cv2, 3); - - // Load column 3 - cv3 = vld1q_lane_f32( c + 0*rs_c + 3*cs_c, cv3, 0); - cv3 = vld1q_lane_f32( c + 1*rs_c + 3*cs_c, cv3, 1); - cv3 = vld1q_lane_f32( c + 2*rs_c + 3*cs_c, cv3, 2); - cv3 = vld1q_lane_f32( c + 3*rs_c + 3*cs_c, cv3, 3); - + cv0 = vmovq_n_f32( 0.0 ); + cv1 = vmovq_n_f32( 0.0 ); + cv2 = vmovq_n_f32( 0.0 ); + cv3 = vmovq_n_f32( 0.0 ); } // Vector for accummulating column 0 @@ -142,15 +126,15 @@ void bli_sgemm_armv7a_int_4x4 // Initialize vector to 0.0 abv3 = vmovq_n_f32( 0.0 ); - for ( i = 0; i < k_iter; ++i ) - { + for ( i = 0; i < k_iter; ++i ) + { // Begin iter 0 - av1 = vld1q_f32( a ); + av1 = vld1q_f32( a ); __builtin_prefetch( a + 224 ); __builtin_prefetch( b + 224 ); - - bv1 = vld1q_f32( b ); + + bv1 = vld1q_f32( b ); abv0 = vmlaq_lane_f32( abv0, av1, vget_low_f32(bv1), 0 ); abv1 = vmlaq_lane_f32( abv1, av1, vget_low_f32(bv1), 1 ); @@ -158,24 +142,24 @@ void bli_sgemm_armv7a_int_4x4 abv3 = vmlaq_lane_f32( abv3, av1, vget_high_f32(bv1), 1 ); - av2 = vld1q_f32( a+4 ); + av2 = vld1q_f32( a+4 ); //__builtin_prefetch( a + 116 ); //__builtin_prefetch( b + 116 ); - - bv2 = vld1q_f32( b+4 ); + + bv2 = vld1q_f32( b+4 ); abv0 = vmlaq_lane_f32( abv0, av2, vget_low_f32(bv2), 0 ); abv1 = vmlaq_lane_f32( abv1, av2, vget_low_f32(bv2), 1 ); abv2 = vmlaq_lane_f32( abv2, av2, vget_high_f32(bv2), 0 ); abv3 = vmlaq_lane_f32( abv3, av2, vget_high_f32(bv2), 1 ); - av3 = vld1q_f32( a+8 ); + av3 = vld1q_f32( a+8 ); //__builtin_prefetch( a + 120 ); //__builtin_prefetch( b + 120 ); - - bv3 = vld1q_f32( b+8 ); + + bv3 = vld1q_f32( b+8 ); abv0 = vmlaq_lane_f32( abv0, av3, vget_low_f32(bv3), 0 ); abv1 = vmlaq_lane_f32( abv1, av3, vget_low_f32(bv3), 1 ); @@ -183,12 +167,12 @@ void bli_sgemm_armv7a_int_4x4 abv3 = vmlaq_lane_f32( abv3, av3, vget_high_f32(bv3), 1 ); - av4 = vld1q_f32( a+12); + av4 = vld1q_f32( a+12); //__builtin_prefetch( a + 124 ); //__builtin_prefetch( b + 124 ); - - bv4 = vld1q_f32( b+12); + + bv4 = vld1q_f32( b+12); abv0 = vmlaq_lane_f32( abv0, av4, vget_low_f32(bv4), 0 ); abv1 = vmlaq_lane_f32( abv1, av4, vget_low_f32(bv4), 1 ); @@ -197,81 +181,70 @@ void bli_sgemm_armv7a_int_4x4 - a += 16; - b += 16; - } + a += 16; + b += 16; + } - for ( i = 0; i < k_left; ++i ) - { - av1 = vld1q_f32( a ); + for ( i = 0; i < k_left; ++i ) + { + av1 = vld1q_f32( a ); __builtin_prefetch( a + 112 ); __builtin_prefetch( b + 112 ); - - bv1 = vld1q_f32( b ); + + bv1 = vld1q_f32( b ); abv0 = vmlaq_lane_f32( abv0, av1, vget_low_f32(bv1), 0 ); abv1 = vmlaq_lane_f32( abv1, av1, vget_low_f32(bv1), 1 ); abv2 = vmlaq_lane_f32( abv2, av1, vget_high_f32(bv1), 0 ); abv3 = vmlaq_lane_f32( abv3, av1, vget_high_f32(bv1), 1 ); - a += 4; - b += 4; + a += 4; + b += 4; } __builtin_prefetch( a_next ); __builtin_prefetch( b_next ); - cv0 = vmulq_n_f32( cv0, *beta ); - cv1 = vmulq_n_f32( cv1, *beta ); - cv2 = vmulq_n_f32( cv2, *beta ); - cv3 = vmulq_n_f32( cv3, *beta ); - - cv0 = vmlaq_f32( cv0, abv0, alphav ); - cv1 = vmlaq_f32( cv1, abv1, alphav ); - cv2 = vmlaq_f32( cv2, abv2, alphav ); - cv3 = vmlaq_f32( cv3, abv3, alphav ); - - if( rs_c == 1 ) + if ( *beta != 0.0F ) { - // Store column 0 - vst1q_f32( c + 0*rs_c + 0*cs_c, cv0 ); - // Store column 1 - vst1q_f32( c + 0*rs_c + 1*cs_c, cv1 ); - // Store column 2 - vst1q_f32( c + 0*rs_c + 2*cs_c, cv2 ); - // Store column 3 - vst1q_f32( c + 0*rs_c + 3*cs_c, cv3 ); + // Multiply C by beta and then accumulate alpha * A * B. + cv0 = vmulq_n_f32( cv0, *beta ); + cv1 = vmulq_n_f32( cv1, *beta ); + cv2 = vmulq_n_f32( cv2, *beta ); + cv3 = vmulq_n_f32( cv3, *beta ); + + cv0 = vmlaq_f32( cv0, abv0, alphav ); + cv1 = vmlaq_f32( cv1, abv1, alphav ); + cv2 = vmlaq_f32( cv2, abv2, alphav ); + cv3 = vmlaq_f32( cv3, abv3, alphav ); } - else{ - // Store column 0 - vst1q_lane_f32( c + 0*rs_c + 0*cs_c, cv0, 0); - vst1q_lane_f32( c + 1*rs_c + 0*cs_c, cv0, 1); - vst1q_lane_f32( c + 2*rs_c + 0*cs_c, cv0, 2); - vst1q_lane_f32( c + 3*rs_c + 0*cs_c, cv0, 3); - - // Store column 1 - vst1q_lane_f32( c + 0*rs_c + 1*cs_c, cv1, 0); - vst1q_lane_f32( c + 1*rs_c + 1*cs_c, cv1, 1); - vst1q_lane_f32( c + 2*rs_c + 1*cs_c, cv1, 2); - vst1q_lane_f32( c + 3*rs_c + 1*cs_c, cv1, 3); - - // Store column 2 - vst1q_lane_f32( c + 0*rs_c + 2*cs_c, cv2, 0); - vst1q_lane_f32( c + 1*rs_c + 2*cs_c, cv2, 1); - vst1q_lane_f32( c + 2*rs_c + 2*cs_c, cv2, 2); - vst1q_lane_f32( c + 3*rs_c + 2*cs_c, cv2, 3); - - // Store column 3 - vst1q_lane_f32( c + 0*rs_c + 3*cs_c, cv3, 0); - vst1q_lane_f32( c + 1*rs_c + 3*cs_c, cv3, 1); - vst1q_lane_f32( c + 2*rs_c + 3*cs_c, cv3, 2); - vst1q_lane_f32( c + 3*rs_c + 3*cs_c, cv3, 3); + else + { + // Since beta = 0, skip straight to accumulating alpha * A * B. + // Note: C (cv?) was initialized to zero above. + cv0 = vmlaq_f32( cv0, abv0, alphav ); + cv1 = vmlaq_f32( cv1, abv1, alphav ); + cv2 = vmlaq_f32( cv2, abv2, alphav ); + cv3 = vmlaq_f32( cv3, abv3, alphav ); } + + // Store column 0 + vst1q_f32( c + 0*cs_c, cv0 ); + // Store column 1 + vst1q_f32( c + 1*cs_c, cv1 ); + // Store column 2 + vst1q_f32( c + 2*cs_c, cv2 ); + // Store column 3 + vst1q_f32( c + 3*cs_c, cv3 ); + + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_armv7a_int_4x4 ( + dim_t m, + dim_t n, dim_t k, double* restrict alpha, double* restrict a, @@ -290,6 +263,8 @@ void bli_dgemm_armv7a_int_4x4 uint32_t cs_c = cs_c0; uint32_t i; + GEMM_UKR_SETUP_CT_ANY( d, 4, 4, false ); + //void* a_next = bli_auxinfo_next_a( data ); //void* b_next = bli_auxinfo_next_b( data ); @@ -306,53 +281,53 @@ void bli_dgemm_armv7a_int_4x4 double b0, b1, b2, b3; double B0, B1, B2, B3; - double ab00, ab01, ab02, ab03; - double ab10, ab11, ab12, ab13; + double ab00, ab01, ab02, ab03; + double ab10, ab11, ab12, ab13; double ab20, ab21, ab22, ab23; - double ab30, ab31, ab32, ab33; + double ab30, ab31, ab32, ab33; - double* restrict c00, * restrict c01, * restrict c02, * restrict c03; + double* restrict c00, * restrict c01, * restrict c02, * restrict c03; double* restrict c10, * restrict c11, * restrict c12, * restrict c13; double* restrict c20, * restrict c21, * restrict c22, * restrict c23; - double* restrict c30, * restrict c31, * restrict c32, * restrict c33; + double* restrict c30, * restrict c31, * restrict c32, * restrict c33; double* restrict ap = a; - double* restrict bp = b; + double* restrict bp = b; double* restrict Ap = a + 4; - double* restrict Bp = b + 4; + double* restrict Bp = b + 4; - c00 = (c + 0*rs_c + 0*cs_c); - c10 = (c + 1*rs_c + 0*cs_c); - c20 = (c + 2*rs_c + 0*cs_c); - c30 = (c + 3*rs_c + 0*cs_c); + c00 = (c + 0*rs_c + 0*cs_c); + c10 = (c + 1*rs_c + 0*cs_c); + c20 = (c + 2*rs_c + 0*cs_c); + c30 = (c + 3*rs_c + 0*cs_c); - c01 = (c + 0*rs_c + 1*cs_c); - c11 = (c + 1*rs_c + 1*cs_c); - c21 = (c + 2*rs_c + 1*cs_c); - c31 = (c + 3*rs_c + 1*cs_c); + c01 = (c + 0*rs_c + 1*cs_c); + c11 = (c + 1*rs_c + 1*cs_c); + c21 = (c + 2*rs_c + 1*cs_c); + c31 = (c + 3*rs_c + 1*cs_c); - c02 = (c + 0*rs_c + 2*cs_c); - c12 = (c + 1*rs_c + 2*cs_c); - c22 = (c + 2*rs_c + 2*cs_c); - c32 = (c + 3*rs_c + 2*cs_c); + c02 = (c + 0*rs_c + 2*cs_c); + c12 = (c + 1*rs_c + 2*cs_c); + c22 = (c + 2*rs_c + 2*cs_c); + c32 = (c + 3*rs_c + 2*cs_c); - c03 = (c + 0*rs_c + 3*cs_c); - c13 = (c + 1*rs_c + 3*cs_c); - c23 = (c + 2*rs_c + 3*cs_c); - c33 = (c + 3*rs_c + 3*cs_c); + c03 = (c + 0*rs_c + 3*cs_c); + c13 = (c + 1*rs_c + 3*cs_c); + c23 = (c + 2*rs_c + 3*cs_c); + c33 = (c + 3*rs_c + 3*cs_c); ab00 = 0.0; ab10 = 0.0; ab20 = 0.0; ab30 = 0.0; ab01 = 0.0; ab11 = 0.0; ab21 = 0.0; ab31 = 0.0; ab02 = 0.0; ab12 = 0.0; ab22 = 0.0; ab32 = 0.0; ab03 = 0.0; ab13 = 0.0; ab23 = 0.0; ab33 = 0.0; - A0 = *(Ap + 0); - A1 = *(Ap + 1); - A2 = *(Ap + 2); - A3 = *(Ap + 3); + A0 = *(Ap + 0); + A1 = *(Ap + 1); + A2 = *(Ap + 2); + A3 = *(Ap + 3); - a0 = *(ap + 0); + a0 = *(ap + 0); a1 = *(ap + 1); a2 = *(ap + 2); @@ -365,11 +340,11 @@ void bli_dgemm_armv7a_int_4x4 b1 = *(bp + 1); b2 = *(bp + 2); - double *Aplast = (Ap + 4*(k-k_left)); + double *Aplast = (Ap + 4*(k-k_left)); //for ( i = 0; i < k_iter; ++i ) // Unroll by factor 4. for ( ; Ap != Aplast ; ) // Unroll by factor 4. - { + { /* Prefetch */ //__asm__ ("pld\t[%0],#100\n\t" : :"r"(Ap) : ); __builtin_prefetch( ap + 112 ); @@ -428,7 +403,7 @@ void bli_dgemm_armv7a_int_4x4 b2 = *(bp + 10); ab03 += a0 * b3; - a0 = *(ap + 8); + a0 = *(ap + 8); ab13 += a1 * b3; a1 = *(ap + 9); ab23 += a2 * b3; @@ -436,17 +411,17 @@ void bli_dgemm_armv7a_int_4x4 ab33 += a3 * b3; //a3 = *(ap + 11); - ap += 8; - Ap += 8; - bp += 8; - Bp += 8; + ap += 8; + Ap += 8; + bp += 8; + Bp += 8; - } + } - for ( i = 0; i < k_left; ++i ) - { - a0 = *(ap + 0); + for ( i = 0; i < k_left; ++i ) + { + a0 = *(ap + 0); a1 = *(ap + 1); a2 = *(ap + 2); a3 = *(ap + 3); @@ -476,48 +451,75 @@ void bli_dgemm_armv7a_int_4x4 ab23 += a2 * b3; ab33 += a3 * b3; - ap += 4; - bp += 4; - } - - *c00 = *c00 * *beta; - *c10 = *c10 * *beta; - *c20 = *c20 * *beta; - *c30 = *c30 * *beta; - - *c01 = *c01 * *beta; - *c11 = *c11 * *beta; - *c21 = *c21 * *beta; - *c31 = *c31 * *beta; - - *c02 = *c02 * *beta; - *c12 = *c12 * *beta; - *c22 = *c22 * *beta; - *c32 = *c32 * *beta; - - *c03 = *c03 * *beta; - *c13 = *c13 * *beta; - *c23 = *c23 * *beta; - *c33 = *c33 * *beta; - - *c00 += ab00 * *alpha; - *c10 += ab10 * *alpha; - *c20 += ab20 * *alpha; - *c30 += ab30 * *alpha; - - *c01 += ab01 * *alpha; - *c11 += ab11 * *alpha; - *c21 += ab21 * *alpha; - *c31 += ab31 * *alpha; - - *c02 += ab02 * *alpha; - *c12 += ab12 * *alpha; - *c22 += ab22 * *alpha; - *c32 += ab32 * *alpha; - - *c03 += ab03 * *alpha; - *c13 += ab13 * *alpha; - *c23 += ab23 * *alpha; - *c33 += ab33 * *alpha; + ap += 4; + bp += 4; + } + + if ( *beta == 0.0 ) + { + *c00 = ab00 * *alpha; + *c10 = ab10 * *alpha; + *c20 = ab20 * *alpha; + *c30 = ab30 * *alpha; + + *c01 = ab01 * *alpha; + *c11 = ab11 * *alpha; + *c21 = ab21 * *alpha; + *c31 = ab31 * *alpha; + + *c02 = ab02 * *alpha; + *c12 = ab12 * *alpha; + *c22 = ab22 * *alpha; + *c32 = ab32 * *alpha; + + *c03 = ab03 * *alpha; + *c13 = ab13 * *alpha; + *c23 = ab23 * *alpha; + *c33 = ab33 * *alpha; + } + else + { + *c00 = *c00 * *beta; + *c10 = *c10 * *beta; + *c20 = *c20 * *beta; + *c30 = *c30 * *beta; + + *c01 = *c01 * *beta; + *c11 = *c11 * *beta; + *c21 = *c21 * *beta; + *c31 = *c31 * *beta; + + *c02 = *c02 * *beta; + *c12 = *c12 * *beta; + *c22 = *c22 * *beta; + *c32 = *c32 * *beta; + + *c03 = *c03 * *beta; + *c13 = *c13 * *beta; + *c23 = *c23 * *beta; + *c33 = *c33 * *beta; + + *c00 += ab00 * *alpha; + *c10 += ab10 * *alpha; + *c20 += ab20 * *alpha; + *c30 += ab30 * *alpha; + + *c01 += ab01 * *alpha; + *c11 += ab11 * *alpha; + *c21 += ab21 * *alpha; + *c31 += ab31 * *alpha; + + *c02 += ab02 * *alpha; + *c12 += ab12 * *alpha; + *c22 += ab22 * *alpha; + *c32 += ab32 * *alpha; + + *c03 += ab03 * *alpha; + *c13 += ab13 * *alpha; + *c23 += ab23 * *alpha; + *c33 += ab33 * *alpha; + } + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/armv8a/1m/bli_packm_armv8a_int_d6xk.c b/kernels/armv8a/1m/bli_packm_armv8a_int_d6xk.c new file mode 100644 index 0000000000..301b8ad790 --- /dev/null +++ b/kernels/armv8a/1m/bli_packm_armv8a_int_d6xk.c @@ -0,0 +1,323 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Linaro Limited + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL_2 _Pragma("unroll 2") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL_2 _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL_2 +#endif + +void bli_dpackm_armv8a_int_6xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + double* restrict kappa, + double* restrict a, inc_t inca0, inc_t lda0, + double* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 6; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 2; + uint64_t k_left = k0 % 2; + double* a_loc = a; + double* p_loc = p; + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_deq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs ) + { + if ( unitk ) + { + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_2 + for ( dim_t ik = k_iter * 2 + k_left; ik > 0; --ik ) + { + float64x2_t v0 = vld1q_f64( a_loc + 0 ); + float64x2_t v1 = vld1q_f64( a_loc + 2 ); + float64x2_t v2 = vld1q_f64( a_loc + 4 ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float64x2_t v0 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v1 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v2 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v3 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v4 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v5 = (float64x2_t)vdupq_n_u64( 0 ); + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f64( a_loc + inca * 0 ); + v1 = vld1q_f64( a_loc + inca * 1 ); + v2 = vld1q_f64( a_loc + inca * 2 ); + v3 = vld1q_f64( a_loc + inca * 3 ); + v4 = vld1q_f64( a_loc + inca * 4 ); + v5 = vld1q_f64( a_loc + inca * 5 ); + + // In-register transpose. + float64x2_t vd0_1 = vtrn1q_f64( v0, v1 ); + float64x2_t vd1_1 = vtrn1q_f64( v2, v3 ); + float64x2_t vd2_1 = vtrn1q_f64( v4, v5 ); + float64x2_t vd0_2 = vtrn2q_f64( v0, v1 ); + float64x2_t vd1_2 = vtrn2q_f64( v2, v3 ); + float64x2_t vd2_2 = vtrn2q_f64( v4, v5 ); + + vst1q_f64( p_loc + 0, vd0_1 ); + vst1q_f64( p_loc + 2, vd1_1 ); + vst1q_f64( p_loc + 4, vd2_1 ); + p_loc += ldp; + + vst1q_f64( p_loc + 0, vd0_2 ); + vst1q_f64( p_loc + 2, vd1_2 ); + vst1q_f64( p_loc + 4, vd2_2 ); + p_loc += ldp; + a_loc += 2 * lda; // 2; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f64( a_loc + inca * 0, v0, 0 ); + v0 = vld1q_lane_f64( a_loc + inca * 1, v0, 1 ); + v1 = vld1q_lane_f64( a_loc + inca * 2, v1, 0 ); + v1 = vld1q_lane_f64( a_loc + inca * 3, v1, 1 ); + v2 = vld1q_lane_f64( a_loc + inca * 4, v2, 0 ); + v2 = vld1q_lane_f64( a_loc + inca * 5, v2, 1 ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + else // if ( !unitk ) + { + float64x2_t vkappa = vld1q_dup_f64( kappa ); + + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_2 + for ( dim_t ik = k_iter * 2 + k_left; ik > 0; --ik ) + { + float64x2_t v0 = vld1q_f64( a_loc + 0 ); + float64x2_t v1 = vld1q_f64( a_loc + 2 ); + float64x2_t v2 = vld1q_f64( a_loc + 4 ); + + // Scale by kappa. + v0 = vmulq_f64( v0, vkappa ); + v1 = vmulq_f64( v1, vkappa ); + v2 = vmulq_f64( v2, vkappa ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float64x2_t v0 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v1 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v2 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v3 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v4 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v5 = (float64x2_t)vdupq_n_u64( 0 ); + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f64( a_loc + inca * 0 ); + v1 = vld1q_f64( a_loc + inca * 1 ); + v2 = vld1q_f64( a_loc + inca * 2 ); + v3 = vld1q_f64( a_loc + inca * 3 ); + v4 = vld1q_f64( a_loc + inca * 4 ); + v5 = vld1q_f64( a_loc + inca * 5 ); + + // Scale by kappa. + v0 = vmulq_f64( v0, vkappa ); + v1 = vmulq_f64( v1, vkappa ); + v2 = vmulq_f64( v2, vkappa ); + v3 = vmulq_f64( v3, vkappa ); + v4 = vmulq_f64( v4, vkappa ); + v5 = vmulq_f64( v5, vkappa ); + + // In-register transpose. + float64x2_t vd0_1 = vtrn1q_f64( v0, v1 ); + float64x2_t vd1_1 = vtrn1q_f64( v2, v3 ); + float64x2_t vd2_1 = vtrn1q_f64( v4, v5 ); + float64x2_t vd0_2 = vtrn2q_f64( v0, v1 ); + float64x2_t vd1_2 = vtrn2q_f64( v2, v3 ); + float64x2_t vd2_2 = vtrn2q_f64( v4, v5 ); + + vst1q_f64( p_loc + 0, vd0_1 ); + vst1q_f64( p_loc + 2, vd1_1 ); + vst1q_f64( p_loc + 4, vd2_1 ); + p_loc += ldp; + + vst1q_f64( p_loc + 0, vd0_2 ); + vst1q_f64( p_loc + 2, vd1_2 ); + vst1q_f64( p_loc + 4, vd2_2 ); + p_loc += ldp; + a_loc += 2 * lda; // 2; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f64( a_loc + inca * 0, v0, 0 ); + v0 = vld1q_lane_f64( a_loc + inca * 1, v0, 1 ); + v1 = vld1q_lane_f64( a_loc + inca * 2, v1, 0 ); + v1 = vld1q_lane_f64( a_loc + inca * 3, v1, 1 ); + v2 = vld1q_lane_f64( a_loc + inca * 4, v2, 0 ); + v2 = vld1q_lane_f64( a_loc + inca * 5, v2, 1 ); + + // Scale by kappa. + v0 = vmulq_f64( v0, vkappa ); + v1 = vmulq_f64( v1, vkappa ); + v2 = vmulq_f64( v2, vkappa ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + } + else // if ( cdim0 < mnr || gs ) + { + PASTEMAC(dscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + double* restrict p_edge = p + (i )*1; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + +//bli_dfprintm( stdout, "packm 6xk ker: a_packed", cdim0, k0_max, p, 1, ldp0, "%5.2f", "" ); + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + double* restrict p_edge = p + (j )*ldp; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/armv8a/1m/bli_packm_armv8a_int_d8xk.c b/kernels/armv8a/1m/bli_packm_armv8a_int_d8xk.c new file mode 100644 index 0000000000..321fa5403b --- /dev/null +++ b/kernels/armv8a/1m/bli_packm_armv8a_int_d8xk.c @@ -0,0 +1,353 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Linaro Limited + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL_2 _Pragma("unroll 2") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL_2 _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL_2 +#endif + +void bli_dpackm_armv8a_int_8xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + double* restrict kappa, + double* restrict a, inc_t inca0, inc_t lda0, + double* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 8; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 2; + uint64_t k_left = k0 % 2; + double* a_loc = a; + double* p_loc = p; + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_deq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs ) + { + if ( unitk ) + { + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_2 + for ( dim_t ik = k_iter * 2 + k_left; ik > 0; --ik ) + { + float64x2_t v0 = vld1q_f64( a_loc + 0 ); + float64x2_t v1 = vld1q_f64( a_loc + 2 ); + float64x2_t v2 = vld1q_f64( a_loc + 4 ); + float64x2_t v3 = vld1q_f64( a_loc + 6 ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + vst1q_f64( p_loc + 6, v3 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float64x2_t v0 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v1 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v2 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v3 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v4 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v5 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v6 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v7 = (float64x2_t)vdupq_n_u64( 0 ); + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f64( a_loc + inca * 0 ); + v1 = vld1q_f64( a_loc + inca * 1 ); + v2 = vld1q_f64( a_loc + inca * 2 ); + v3 = vld1q_f64( a_loc + inca * 3 ); + v4 = vld1q_f64( a_loc + inca * 4 ); + v5 = vld1q_f64( a_loc + inca * 5 ); + v6 = vld1q_f64( a_loc + inca * 6 ); + v7 = vld1q_f64( a_loc + inca * 7 ); + + // In-register transpose. + float64x2_t vd0_1 = vtrn1q_f64( v0, v1 ); + float64x2_t vd1_1 = vtrn1q_f64( v2, v3 ); + float64x2_t vd2_1 = vtrn1q_f64( v4, v5 ); + float64x2_t vd3_1 = vtrn1q_f64( v6, v7 ); + float64x2_t vd0_2 = vtrn2q_f64( v0, v1 ); + float64x2_t vd1_2 = vtrn2q_f64( v2, v3 ); + float64x2_t vd2_2 = vtrn2q_f64( v4, v5 ); + float64x2_t vd3_2 = vtrn2q_f64( v6, v7 ); + + vst1q_f64( p_loc + 0, vd0_1 ); + vst1q_f64( p_loc + 2, vd1_1 ); + vst1q_f64( p_loc + 4, vd2_1 ); + vst1q_f64( p_loc + 6, vd3_1 ); + p_loc += ldp; + + vst1q_f64( p_loc + 0, vd0_2 ); + vst1q_f64( p_loc + 2, vd1_2 ); + vst1q_f64( p_loc + 4, vd2_2 ); + vst1q_f64( p_loc + 6, vd3_2 ); + p_loc += ldp; + a_loc += 2 * lda; // 2; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f64( a_loc + inca * 0, v0, 0 ); + v0 = vld1q_lane_f64( a_loc + inca * 1, v0, 1 ); + v1 = vld1q_lane_f64( a_loc + inca * 2, v1, 0 ); + v1 = vld1q_lane_f64( a_loc + inca * 3, v1, 1 ); + v2 = vld1q_lane_f64( a_loc + inca * 4, v2, 0 ); + v2 = vld1q_lane_f64( a_loc + inca * 5, v2, 1 ); + v3 = vld1q_lane_f64( a_loc + inca * 6, v3, 0 ); + v3 = vld1q_lane_f64( a_loc + inca * 7, v3, 1 ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + vst1q_f64( p_loc + 6, v3 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + else // if ( !unitk ) + { + float64x2_t vkappa = vld1q_dup_f64( kappa ); + + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_2 + for ( dim_t ik = k_iter * 2 + k_left; ik > 0; --ik ) + { + float64x2_t v0 = vld1q_f64( a_loc + 0 ); + float64x2_t v1 = vld1q_f64( a_loc + 2 ); + float64x2_t v2 = vld1q_f64( a_loc + 4 ); + float64x2_t v3 = vld1q_f64( a_loc + 6 ); + + // Scale by kappa. + v0 = vmulq_f64( v0, vkappa ); + v1 = vmulq_f64( v1, vkappa ); + v2 = vmulq_f64( v2, vkappa ); + v3 = vmulq_f64( v3, vkappa ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + vst1q_f64( p_loc + 6, v3 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float64x2_t v0 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v1 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v2 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v3 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v4 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v5 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v6 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v7 = (float64x2_t)vdupq_n_u64( 0 ); + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f64( a_loc + inca * 0 ); + v1 = vld1q_f64( a_loc + inca * 1 ); + v2 = vld1q_f64( a_loc + inca * 2 ); + v3 = vld1q_f64( a_loc + inca * 3 ); + v4 = vld1q_f64( a_loc + inca * 4 ); + v5 = vld1q_f64( a_loc + inca * 5 ); + v6 = vld1q_f64( a_loc + inca * 6 ); + v7 = vld1q_f64( a_loc + inca * 7 ); + + // Scale by kappa. + v0 = vmulq_f64( v0, vkappa ); + v1 = vmulq_f64( v1, vkappa ); + v2 = vmulq_f64( v2, vkappa ); + v3 = vmulq_f64( v3, vkappa ); + v4 = vmulq_f64( v4, vkappa ); + v5 = vmulq_f64( v5, vkappa ); + v6 = vmulq_f64( v6, vkappa ); + v7 = vmulq_f64( v7, vkappa ); + + // In-register transpose. + float64x2_t vd0_1 = vtrn1q_f64( v0, v1 ); + float64x2_t vd1_1 = vtrn1q_f64( v2, v3 ); + float64x2_t vd2_1 = vtrn1q_f64( v4, v5 ); + float64x2_t vd3_1 = vtrn1q_f64( v6, v7 ); + float64x2_t vd0_2 = vtrn2q_f64( v0, v1 ); + float64x2_t vd1_2 = vtrn2q_f64( v2, v3 ); + float64x2_t vd2_2 = vtrn2q_f64( v4, v5 ); + float64x2_t vd3_2 = vtrn2q_f64( v6, v7 ); + + vst1q_f64( p_loc + 0, vd0_1 ); + vst1q_f64( p_loc + 2, vd1_1 ); + vst1q_f64( p_loc + 4, vd2_1 ); + vst1q_f64( p_loc + 6, vd3_1 ); + p_loc += ldp; + + vst1q_f64( p_loc + 0, vd0_2 ); + vst1q_f64( p_loc + 2, vd1_2 ); + vst1q_f64( p_loc + 4, vd2_2 ); + vst1q_f64( p_loc + 6, vd3_2 ); + p_loc += ldp; + a_loc += 2 * lda; // 2; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f64( a_loc + inca * 0, v0, 0 ); + v0 = vld1q_lane_f64( a_loc + inca * 1, v0, 1 ); + v1 = vld1q_lane_f64( a_loc + inca * 2, v1, 0 ); + v1 = vld1q_lane_f64( a_loc + inca * 3, v1, 1 ); + v2 = vld1q_lane_f64( a_loc + inca * 4, v2, 0 ); + v2 = vld1q_lane_f64( a_loc + inca * 5, v2, 1 ); + v3 = vld1q_lane_f64( a_loc + inca * 6, v3, 0 ); + v3 = vld1q_lane_f64( a_loc + inca * 7, v3, 1 ); + + // Scale by kappa. + v0 = vmulq_f64( v0, vkappa ); + v1 = vmulq_f64( v1, vkappa ); + v2 = vmulq_f64( v2, vkappa ); + v3 = vmulq_f64( v3, vkappa ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + vst1q_f64( p_loc + 6, v3 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + } + else // if ( cdim0 < mnr || gs ) + { + PASTEMAC(dscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + double* restrict p_edge = p + (i )*1; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + +//bli_dfprintm( stdout, "packm 8xk ker: a_packed", cdim0, k0_max, p, 1, ldp0, "%5.2f", "" ); + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + double* restrict p_edge = p + (j )*ldp; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/armv8a/1m/bli_packm_armv8a_int_s12xk.c b/kernels/armv8a/1m/bli_packm_armv8a_int_s12xk.c new file mode 100644 index 0000000000..3718772473 --- /dev/null +++ b/kernels/armv8a/1m/bli_packm_armv8a_int_s12xk.c @@ -0,0 +1,435 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Linaro Limited + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL_2 _Pragma("unroll 2") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL_2 _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL_2 +#endif + +void bli_spackm_armv8a_int_12xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + float* restrict kappa, + float* restrict a, inc_t inca0, inc_t lda0, + float* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 12; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + float* a_loc = a; + float* p_loc = p; + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_seq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs ) + { + if ( unitk ) + { + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_2 + for ( dim_t ik = k_iter * 4 + k_left; ik > 0; --ik ) + { + float32x4_t v0 = vld1q_f32( a_loc + 0 ); + float32x4_t v1 = vld1q_f32( a_loc + 4 ); + float32x4_t v2 = vld1q_f32( a_loc + 8 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + vst1q_f32( p_loc + 8, v2 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float32x4_t v0 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v1 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v2 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v3 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v4 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v5 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v6 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v7 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v8 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v9 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v10 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v11 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t vt0; + float32x4_t vt1; + float32x4_t vt2; + float32x4_t vt3; + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f32( a_loc + inca * 0 ); + v1 = vld1q_f32( a_loc + inca * 1 ); + v2 = vld1q_f32( a_loc + inca * 2 ); + v3 = vld1q_f32( a_loc + inca * 3 ); + v4 = vld1q_f32( a_loc + inca * 4 ); + v5 = vld1q_f32( a_loc + inca * 5 ); + v6 = vld1q_f32( a_loc + inca * 6 ); + v7 = vld1q_f32( a_loc + inca * 7 ); + v8 = vld1q_f32( a_loc + inca * 8 ); + v9 = vld1q_f32( a_loc + inca * 9 ); + v10 = vld1q_f32( a_loc + inca * 10 ); + v11 = vld1q_f32( a_loc + inca * 11 ); + + // In-register transpose. + // + // Column 0-3 + vt0 = vtrn1q_f32( v0, v1 ); + vt1 = vtrn2q_f32( v0, v1 ); + vt2 = vtrn1q_f32( v2, v3 ); + vt3 = vtrn2q_f32( v2, v3 ); + v0 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v1 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v2 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v3 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + // Column 4-7 + vt0 = vtrn1q_f32( v4, v5 ); + vt1 = vtrn2q_f32( v4, v5 ); + vt2 = vtrn1q_f32( v6, v7 ); + vt3 = vtrn2q_f32( v6, v7 ); + v4 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v5 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v6 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v7 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + // Column 8-11 + vt0 = vtrn1q_f32( v8, v9 ); + vt1 = vtrn2q_f32( v8, v9 ); + vt2 = vtrn1q_f32( v10, v11 ); + vt3 = vtrn2q_f32( v10, v11 ); + v8 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v9 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v10 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v11 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v4 ); + vst1q_f32( p_loc + 8, v8 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v1 ); + vst1q_f32( p_loc + 4, v5 ); + vst1q_f32( p_loc + 8, v9 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v2 ); + vst1q_f32( p_loc + 4, v6 ); + vst1q_f32( p_loc + 8, v10 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v3 ); + vst1q_f32( p_loc + 4, v7 ); + vst1q_f32( p_loc + 8, v11 ); + p_loc += ldp; + a_loc += 4 * lda; // 4; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f32( a_loc + inca * 0 , v0, 0 ); + v0 = vld1q_lane_f32( a_loc + inca * 1 , v0, 1 ); + v0 = vld1q_lane_f32( a_loc + inca * 2 , v0, 2 ); + v0 = vld1q_lane_f32( a_loc + inca * 3 , v0, 3 ); + v1 = vld1q_lane_f32( a_loc + inca * 4 , v1, 0 ); + v1 = vld1q_lane_f32( a_loc + inca * 5 , v1, 1 ); + v1 = vld1q_lane_f32( a_loc + inca * 6 , v1, 2 ); + v1 = vld1q_lane_f32( a_loc + inca * 7 , v1, 3 ); + v2 = vld1q_lane_f32( a_loc + inca * 8 , v2, 0 ); + v2 = vld1q_lane_f32( a_loc + inca * 9 , v2, 1 ); + v2 = vld1q_lane_f32( a_loc + inca * 10, v2, 2 ); + v2 = vld1q_lane_f32( a_loc + inca * 11, v2, 3 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + vst1q_f32( p_loc + 8, v2 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + else // if ( !unitk ) + { + float32x4_t vkappa = vld1q_dup_f32( kappa ); + + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_2 + for ( dim_t ik = k_iter * 4 + k_left; ik > 0; --ik ) + { + float32x4_t v0 = vld1q_f32( a_loc + 0 ); + float32x4_t v1 = vld1q_f32( a_loc + 4 ); + float32x4_t v2 = vld1q_f32( a_loc + 8 ); + + // Scale by kappa. + v0 = vmulq_f32( v0, vkappa ); + v1 = vmulq_f32( v1, vkappa ); + v2 = vmulq_f32( v2, vkappa ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + vst1q_f32( p_loc + 8, v2 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float32x4_t v0 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v1 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v2 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v3 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v4 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v5 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v6 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v7 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v8 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v9 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v10 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v11 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t vt0; + float32x4_t vt1; + float32x4_t vt2; + float32x4_t vt3; + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f32( a_loc + inca * 0 ); + v1 = vld1q_f32( a_loc + inca * 1 ); + v2 = vld1q_f32( a_loc + inca * 2 ); + v3 = vld1q_f32( a_loc + inca * 3 ); + v4 = vld1q_f32( a_loc + inca * 4 ); + v5 = vld1q_f32( a_loc + inca * 5 ); + v6 = vld1q_f32( a_loc + inca * 6 ); + v7 = vld1q_f32( a_loc + inca * 7 ); + v8 = vld1q_f32( a_loc + inca * 8 ); + v9 = vld1q_f32( a_loc + inca * 9 ); + v10 = vld1q_f32( a_loc + inca * 10 ); + v11 = vld1q_f32( a_loc + inca * 11 ); + + // Scale by kappa. + v0 = vmulq_f32( v0, vkappa ); + v1 = vmulq_f32( v1, vkappa ); + v2 = vmulq_f32( v2, vkappa ); + v3 = vmulq_f32( v3, vkappa ); + v4 = vmulq_f32( v4, vkappa ); + v5 = vmulq_f32( v5, vkappa ); + v6 = vmulq_f32( v6, vkappa ); + v7 = vmulq_f32( v7, vkappa ); + v8 = vmulq_f32( v8, vkappa ); + v9 = vmulq_f32( v9, vkappa ); + v10 = vmulq_f32( v10, vkappa ); + v11 = vmulq_f32( v11, vkappa ); + + // In-register transpose. + // + // Column 0-3 + vt0 = vtrn1q_f32( v0, v1 ); + vt1 = vtrn2q_f32( v0, v1 ); + vt2 = vtrn1q_f32( v2, v3 ); + vt3 = vtrn2q_f32( v2, v3 ); + v0 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v1 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v2 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v3 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + // Column 4-7 + vt0 = vtrn1q_f32( v4, v5 ); + vt1 = vtrn2q_f32( v4, v5 ); + vt2 = vtrn1q_f32( v6, v7 ); + vt3 = vtrn2q_f32( v6, v7 ); + v4 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v5 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v6 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v7 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + // Column 8-11 + vt0 = vtrn1q_f32( v8, v9 ); + vt1 = vtrn2q_f32( v8, v9 ); + vt2 = vtrn1q_f32( v10, v11 ); + vt3 = vtrn2q_f32( v10, v11 ); + v8 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v9 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v10 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v11 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v4 ); + vst1q_f32( p_loc + 8, v8 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v1 ); + vst1q_f32( p_loc + 4, v5 ); + vst1q_f32( p_loc + 8, v9 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v2 ); + vst1q_f32( p_loc + 4, v6 ); + vst1q_f32( p_loc + 8, v10 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v3 ); + vst1q_f32( p_loc + 4, v7 ); + vst1q_f32( p_loc + 8, v11 ); + p_loc += ldp; + a_loc += 4 * lda; // 4; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f32( a_loc + inca * 0 , v0, 0 ); + v0 = vld1q_lane_f32( a_loc + inca * 1 , v0, 1 ); + v0 = vld1q_lane_f32( a_loc + inca * 2 , v0, 2 ); + v0 = vld1q_lane_f32( a_loc + inca * 3 , v0, 3 ); + v1 = vld1q_lane_f32( a_loc + inca * 4 , v1, 0 ); + v1 = vld1q_lane_f32( a_loc + inca * 5 , v1, 1 ); + v1 = vld1q_lane_f32( a_loc + inca * 6 , v1, 2 ); + v1 = vld1q_lane_f32( a_loc + inca * 7 , v1, 3 ); + v2 = vld1q_lane_f32( a_loc + inca * 8 , v2, 0 ); + v2 = vld1q_lane_f32( a_loc + inca * 9 , v2, 1 ); + v2 = vld1q_lane_f32( a_loc + inca * 10, v2, 2 ); + v2 = vld1q_lane_f32( a_loc + inca * 11, v2, 3 ); + + // Scale by kappa. + v0 = vmulq_f32( v0, vkappa ); + v1 = vmulq_f32( v1, vkappa ); + v2 = vmulq_f32( v2, vkappa ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + vst1q_f32( p_loc + 8, v2 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + } + else // if ( cdim0 < mnr || gs ) + { + PASTEMAC(sscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + float* restrict p_edge = p + (i )*1; + + bli_sset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + float* restrict p_edge = p + (j )*ldp; + + bli_sset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/armv8a/1m/bli_packm_armv8a_int_s8xk.c b/kernels/armv8a/1m/bli_packm_armv8a_int_s8xk.c new file mode 100644 index 0000000000..3d363c2d8d --- /dev/null +++ b/kernels/armv8a/1m/bli_packm_armv8a_int_s8xk.c @@ -0,0 +1,373 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Linaro Limited + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL_4 _Pragma("unroll 4") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL_4 _Pragma("GCC unroll 4") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL_4 +#endif + +void bli_spackm_armv8a_int_8xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + float* restrict kappa, + float* restrict a, inc_t inca0, inc_t lda0, + float* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 8; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + float* a_loc = a; + float* p_loc = p; + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_seq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs ) + { + if ( unitk ) + { + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_4 + for ( dim_t ik = k_iter * 4 + k_left; ik > 0; --ik ) + { + float32x4_t v0 = vld1q_f32( a_loc + 0 ); + float32x4_t v1 = vld1q_f32( a_loc + 4 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float32x4_t v0 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v1 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v2 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v3 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v4 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v5 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v6 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v7 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t vt0; + float32x4_t vt1; + float32x4_t vt2; + float32x4_t vt3; + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f32( a_loc + inca * 0 ); + v1 = vld1q_f32( a_loc + inca * 1 ); + v2 = vld1q_f32( a_loc + inca * 2 ); + v3 = vld1q_f32( a_loc + inca * 3 ); + v4 = vld1q_f32( a_loc + inca * 4 ); + v5 = vld1q_f32( a_loc + inca * 5 ); + v6 = vld1q_f32( a_loc + inca * 6 ); + v7 = vld1q_f32( a_loc + inca * 7 ); + + // In-register transpose. + // + // Column 0-3 + vt0 = vtrn1q_f32( v0, v1 ); + vt1 = vtrn2q_f32( v0, v1 ); + vt2 = vtrn1q_f32( v2, v3 ); + vt3 = vtrn2q_f32( v2, v3 ); + v0 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v1 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v2 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v3 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + // Column 4-7 + vt0 = vtrn1q_f32( v4, v5 ); + vt1 = vtrn2q_f32( v4, v5 ); + vt2 = vtrn1q_f32( v6, v7 ); + vt3 = vtrn2q_f32( v6, v7 ); + v4 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v5 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v6 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v7 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v4 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v1 ); + vst1q_f32( p_loc + 4, v5 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v2 ); + vst1q_f32( p_loc + 4, v6 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v3 ); + vst1q_f32( p_loc + 4, v7 ); + p_loc += ldp; + a_loc += 4 * lda; // 4; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f32( a_loc + inca * 0 , v0, 0 ); + v0 = vld1q_lane_f32( a_loc + inca * 1 , v0, 1 ); + v0 = vld1q_lane_f32( a_loc + inca * 2 , v0, 2 ); + v0 = vld1q_lane_f32( a_loc + inca * 3 , v0, 3 ); + v1 = vld1q_lane_f32( a_loc + inca * 4 , v1, 0 ); + v1 = vld1q_lane_f32( a_loc + inca * 5 , v1, 1 ); + v1 = vld1q_lane_f32( a_loc + inca * 6 , v1, 2 ); + v1 = vld1q_lane_f32( a_loc + inca * 7 , v1, 3 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + else // if ( !unitk ) + { + float32x4_t vkappa = vld1q_dup_f32( kappa ); + + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_4 + for ( dim_t ik = k_iter * 4 + k_left; ik > 0; --ik ) + { + float32x4_t v0 = vld1q_f32( a_loc + 0 ); + float32x4_t v1 = vld1q_f32( a_loc + 4 ); + + // Scale by kappa. + v0 = vmulq_f32( v0, vkappa ); + v1 = vmulq_f32( v1, vkappa ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float32x4_t v0 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v1 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v2 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v3 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v4 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v5 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v6 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v7 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t vt0; + float32x4_t vt1; + float32x4_t vt2; + float32x4_t vt3; + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f32( a_loc + inca * 0 ); + v1 = vld1q_f32( a_loc + inca * 1 ); + v2 = vld1q_f32( a_loc + inca * 2 ); + v3 = vld1q_f32( a_loc + inca * 3 ); + v4 = vld1q_f32( a_loc + inca * 4 ); + v5 = vld1q_f32( a_loc + inca * 5 ); + v6 = vld1q_f32( a_loc + inca * 6 ); + v7 = vld1q_f32( a_loc + inca * 7 ); + + // Scale by kappa. + v0 = vmulq_f32( v0, vkappa ); + v1 = vmulq_f32( v1, vkappa ); + v2 = vmulq_f32( v2, vkappa ); + v3 = vmulq_f32( v3, vkappa ); + v4 = vmulq_f32( v4, vkappa ); + v5 = vmulq_f32( v5, vkappa ); + v6 = vmulq_f32( v6, vkappa ); + v7 = vmulq_f32( v7, vkappa ); + + // In-register transpose. + // + // Column 0-3 + vt0 = vtrn1q_f32( v0, v1 ); + vt1 = vtrn2q_f32( v0, v1 ); + vt2 = vtrn1q_f32( v2, v3 ); + vt3 = vtrn2q_f32( v2, v3 ); + v0 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v1 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v2 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v3 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + // Column 4-7 + vt0 = vtrn1q_f32( v4, v5 ); + vt1 = vtrn2q_f32( v4, v5 ); + vt2 = vtrn1q_f32( v6, v7 ); + vt3 = vtrn2q_f32( v6, v7 ); + v4 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v5 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v6 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v7 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v4 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v1 ); + vst1q_f32( p_loc + 4, v5 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v2 ); + vst1q_f32( p_loc + 4, v6 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v3 ); + vst1q_f32( p_loc + 4, v7 ); + p_loc += ldp; + a_loc += 4 * lda; // 4; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f32( a_loc + inca * 0 , v0, 0 ); + v0 = vld1q_lane_f32( a_loc + inca * 1 , v0, 1 ); + v0 = vld1q_lane_f32( a_loc + inca * 2 , v0, 2 ); + v0 = vld1q_lane_f32( a_loc + inca * 3 , v0, 3 ); + v1 = vld1q_lane_f32( a_loc + inca * 4 , v1, 0 ); + v1 = vld1q_lane_f32( a_loc + inca * 5 , v1, 1 ); + v1 = vld1q_lane_f32( a_loc + inca * 6 , v1, 2 ); + v1 = vld1q_lane_f32( a_loc + inca * 7 , v1, 3 ); + + // Scale by kappa. + v0 = vmulq_f32( v0, vkappa ); + v1 = vmulq_f32( v1, vkappa ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + } + else // if ( cdim0 < mnr || gs ) + { + PASTEMAC(sscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + float* restrict p_edge = p + (i )*1; + + bli_sset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + float* restrict p_edge = p + (j )*ldp; + + bli_sset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/frame/1m/unpackm/bli_unpackm_unb_var1.h b/kernels/armv8a/3/armv8a_asm_d2x2.h similarity index 76% rename from frame/1m/unpackm/bli_unpackm_unb_var1.h rename to kernels/armv8a/3/armv8a_asm_d2x2.h index 5119aaa7ff..5bb0bb4d39 100644 --- a/frame/1m/unpackm/bli_unpackm_unb_var1.h +++ b/kernels/armv8a/3/armv8a_asm_d2x2.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -30,31 +31,25 @@ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ -void bli_unpackm_unb_var1 - ( - obj_t* p, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl, - thrinfo_t* thread - ); - -#undef GENTPROT -#define GENTPROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - doff_t diagoffp, \ - uplo_t uplop, \ - trans_t transp, \ - dim_t m, \ - dim_t n, \ - void* p, inc_t rs_p, inc_t cs_p, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROT_BASIC0( unpackm_unb_var1 ) +/* C A B + * || <- | * -- + * || | + * + * or: + * C B * A + * -- <- | -- + * -- | + */ +#define DGEMM_2X2_NANOKERNEL(C0,C1,A,B) \ +" fmla v"#C0".2d, v"#A".2d, v"#B".d[0] \n\t" \ +" fmla v"#C1".2d, v"#A".2d, v"#B".d[1] \n\t" + +#define SGEMM_4X4_NANOKERNEL(C0,C1,C2,C3,A,B) \ +" fmla v"#C0".4s, v"#A".4s, v"#B".s[0] \n\t" \ +" fmla v"#C1".4s, v"#A".4s, v"#B".s[1] \n\t" \ +" fmla v"#C2".4s, v"#A".4s, v"#B".s[2] \n\t" \ +" fmla v"#C3".4s, v"#A".4s, v"#B".s[3] \n\t" diff --git a/kernels/armv8a/3/armv8a_asm_utils.h b/kernels/armv8a/3/armv8a_asm_utils.h new file mode 100644 index 0000000000..0c405dfd26 --- /dev/null +++ b/kernels/armv8a/3/armv8a_asm_utils.h @@ -0,0 +1,119 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Apple/Clang's local label requirements. +#if defined(__APPLE__) || defined(__clang__) +#define LABEL(str) " L" #str"%=: \n\t" +#define BEQ(str) "b.eq L" #str"%= \n\t" +#define BNE(str) "b.ne L" #str"%= \n\t" +#define BRANCH(str) "b L" #str"%= \n\t" +#else +#define LABEL(str) " ." #str": \n\t" +#define BEQ(str) "b.eq ." #str" \n\t" +#define BNE(str) "b.ne ." #str" \n\t" +#define BRANCH(str) "b ." #str" \n\t" +#endif + +// Clear vectors. +#define CLEAR1V(V) \ +" dup v"#V".2d, xzr \n\t" +#define CLEAR2V(V0,V1) \ + CLEAR1V(V0) \ + CLEAR1V(V1) +#define CLEAR4V(V0,V1,V2,V3) \ + CLEAR2V(V0,V1) \ + CLEAR2V(V2,V3) +#define CLEAR8V(V0,V1,V2,V3,V4,V5,V6,V7) \ + CLEAR4V(V0,V1,V2,V3) \ + CLEAR4V(V4,V5,V6,V7) + +// Scale vectors. +#define DSCALE1V(V,A,IDX) \ +" fmul v"#V".2d, v"#V".2d, v"#A".d["#IDX"] \n\t" +#define DSCALE2V(V0,V1,A,IDX) \ + DSCALE1V(V0,A,IDX) \ + DSCALE1V(V1,A,IDX) +#define DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE2V(V0,V1,A,IDX) \ + DSCALE2V(V2,V3,A,IDX) +#define DSCALE8V(V0,V1,V2,V3,V4,V5,V6,V7,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE4V(V4,V5,V6,V7,A,IDX) + +// Scale-accumulate. +#define DSCALEA1V(D,S,A,IDX) \ +" fmla v"#D".2d, v"#S".2d, v"#A".d["#IDX"] \n\t" +#define DSCALEA2V(D0,D1,S0,S1,A,IDX) \ + DSCALEA1V(D0,S0,A,IDX) \ + DSCALEA1V(D1,S1,A,IDX) +#define DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA2V(D0,D1,S0,S1,A,IDX) \ + DSCALEA2V(D2,D3,S2,S3,A,IDX) +#define DSCALEA8V(D0,D1,D2,D3,D4,D5,D6,D7,S0,S1,S2,S3,S4,S5,S6,S7,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA4V(D4,D5,D6,D7,S4,S5,S6,S7,A,IDX) + +// Load one line. +#define DLOAD1V(V,ADDR,SHIFT) \ +" ldr q"#V", ["#ADDR", #"#SHIFT"] \n\t" +#define DLOAD2V(V0,V1,ADDR,SHIFT) \ + DLOAD1V(V0,ADDR,SHIFT) \ + DLOAD1V(V1,ADDR,SHIFT+16) +#define DLOAD4V(V0,V1,V2,V3,ADDR,SHIFT) \ + DLOAD2V(V0,V1,ADDR,SHIFT) \ + DLOAD2V(V2,V3,ADDR,SHIFT+32) + +// Generic: load one line. +#define DLOAD1V_GATHER_ELMFWD(V,ADDR,INC) \ +" ld1 {v"#V".d}[0], ["#ADDR"], "#INC" \n\t" \ +" ld1 {v"#V".d}[1], ["#ADDR"], "#INC" \n\t" + +// Store one line. +#define DSTORE1V(V,ADDR,SHIFT) \ +" str q"#V", ["#ADDR", #"#SHIFT"] \n\t" +#define DSTORE2V(V0,V1,ADDR,SHIFT) \ + DSTORE1V(V0,ADDR,SHIFT) \ + DSTORE1V(V1,ADDR,SHIFT+16) +#define DSTORE4V(V0,V1,V2,V3,ADDR,SHIFT) \ + DSTORE2V(V0,V1,ADDR,SHIFT) \ + DSTORE2V(V2,V3,ADDR,SHIFT+32) + +// Generic: store one line. +#define DSTORE1V_SCATTER_ELMFWD(V,ADDR,INC) \ +" st1 {v"#V".d}[0], ["#ADDR"], "#INC" \n\t" \ +" st1 {v"#V".d}[1], ["#ADDR"], "#INC" \n\t" + + diff --git a/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c b/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c new file mode 100644 index 0000000000..4d9a888178 --- /dev/null +++ b/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c @@ -0,0 +1,1488 @@ + /* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" +#include "armv8a_asm_utils.h" + +/* + o 4x4 Single precision micro-kernel fully functional. + o Runnable on ARMv8, compiled with aarch64 GCC. + o Use it together with the armv8 BLIS configuration. + o Tested on Juno board. Around 7.3 GFLOPS @ 1.1 GHz. + + December 2014. + + * UPDATE NOVEMBER 2015 + * Micro-kernel changed to 8x12 + * Tested on Juno Board. Around 8.1 GFLOPS, 1 x A57 core @ 1.1 GHz. + * Tested on Juno Board. Around 15.9 GFLOPS, 2 x A57 cores @ 1.1 GHz. + * Tested on Juno board. Around 3.1 GFLOPS, 1 x A53 core @ 850 MHz. + * Tested on Juno board. Around 12 GFLOPS, 4 x A53 cores @ 850 MHz. +*/ +void bli_sgemm_armv8a_asm_8x12 + ( + dim_t m, + dim_t n, + dim_t k, + float* restrict alpha, + float* restrict a, + float* restrict b, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + GEMM_UKR_SETUP_CT( s, 8, 12, false ); + + + __asm__ volatile + ( + " \n\t" + " \n\t" + " ldr x0,%[aaddr] \n\t" // Load address of A. + " ldr x1,%[baddr] \n\t" // Load address of B. + " ldr x2,%[caddr] \n\t" // Load address of C. + " \n\t" + " ldr x5,%[k_iter] \n\t" // Number of unrolled iterations (k_iter). + " ldr x6,%[k_left] \n\t" // Number of remaining iterations (k_left). + " \n\t" + " ldr x10,%[cs_c] \n\t" // Load cs_c. + " lsl x10,x10,#2 \n\t" // cs_c * sizeof(float) -- AUX. + " \n\t" + // " ldr x14,%[rs_c] \n\t" // Load rs_c. + // " lsl x14,x14,#2 \n\t" // rs_c * sizeof(float). + " \n\t" + " add x16,x2,x10 \n\t" //Load address Column 1 of C + " add x17,x16,x10 \n\t" //Load address Column 2 of C + " add x19,x17,x10 \n\t" //Load address Column 3 of C + " add x20,x19,x10 \n\t" //Load address Column 4 of C + " add x21,x20,x10 \n\t" //Load address Column 5 of C + " add x22,x21,x10 \n\t" //Load address Column 6 of C + " add x23,x22,x10 \n\t" //Load address Column 7 of C + " add x24,x23,x10 \n\t" //Load address Column 8 of C + " add x25,x24,x10 \n\t" //Load address Column 9 of C + " add x26,x25,x10 \n\t" //Load address Column 10 of C + " add x27,x26,x10 \n\t" //Load address Column 11 of C + " \n\t" + " prfm pldl1keep,[x2] \n\t" // Prefetch c. + " prfm pldl1keep,[x16] \n\t" // Prefetch c. + " prfm pldl1keep,[x17] \n\t" // Prefetch c. + " prfm pldl1keep,[x19] \n\t" // Prefetch c. + " prfm pldl1keep,[x20] \n\t" // Prefetch c. + " prfm pldl1keep,[x21] \n\t" // Prefetch c. + " prfm pldl1keep,[x22] \n\t" // Prefetch c. + " prfm pldl1keep,[x23] \n\t" // Prefetch c. + " prfm pldl1keep,[x24] \n\t" // Prefetch c. + " prfm pldl1keep,[x25] \n\t" // Prefetch c. + " prfm pldl1keep,[x26] \n\t" // Prefetch c. + " prfm pldl1keep,[x27] \n\t" // Prefetch c. + " \n\t" + " dup v8.4s, wzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #192] \n\t" + " dup v9.4s, wzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #256] \n\t" + " dup v10.4s, wzr \n\t" // Vector for accummulating column 1 + " prfm PLDL1KEEP, [x1, #320] \n\t" + " dup v11.4s, wzr \n\t" // Vector for accummulating column 1 + " dup v12.4s, wzr \n\t" // Vector for accummulating column 2 + " dup v13.4s, wzr \n\t" // Vector for accummulating column 2 + " \n\t" + " dup v14.4s, wzr \n\t" // Vector for accummulating column 3 + " prfm PLDL1KEEP, [x0, #128] \n\t" + " dup v15.4s, wzr \n\t" // Vector for accummulating column 3 + " prfm PLDL1KEEP, [x0, #192] \n\t" + " dup v16.4s, wzr \n\t" // Vector for accummulating column 4 + " dup v17.4s, wzr \n\t" // Vector for accummulating column 4 + " dup v18.4s, wzr \n\t" // Vector for accummulating column 5 + " dup v19.4s, wzr \n\t" // Vector for accummulating column 5 + " \n\t" + " dup v20.4s, wzr \n\t" // Vector for accummulating column 6 + " dup v21.4s, wzr \n\t" // Vector for accummulating column 6 + " dup v22.4s, wzr \n\t" // Vector for accummulating column 7 + " dup v23.4s, wzr \n\t" // Vector for accummulating column 7 + " dup v24.4s, wzr \n\t" // Vector for accummulating column 8 + " dup v25.4s, wzr \n\t" // Vector for accummulating column 8 + " \n\t" + " dup v26.4s, wzr \n\t" // Vector for accummulating column 9 + " dup v27.4s, wzr \n\t" // Vector for accummulating column 9 + " dup v28.4s, wzr \n\t" // Vector for accummulating column 10 + " dup v29.4s, wzr \n\t" // Vector for accummulating column 10 + " dup v30.4s, wzr \n\t" // Vector for accummulating column 11 + " dup v31.4s, wzr \n\t" // Vector for accummulating column 11 + " \n\t" + " cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. + BEQ(SCONSIDERKLEFT) + " \n\t" + " ldr q0, [x0] \n\t" + " ldr q1, [x0, #16] \n\t" // Load a + " \n\t" + " ldr q2, [x1] \n\t" // Load b + " ldr q3, [x1, #16] \n\t" + " ldr q4, [x1, #32] \n\t" + " \n\t" + " add x0, x0, #32 \n\t" //update address of A + " add x1, x1, #48 \n\t" //update address of B + " \n\t" + " cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. + BEQ(SLASTITER) // (as loop is do-while-like). + " \n\t" + LABEL(SLOOPKITER) // Body of the k_iter loop. + " \n\t" + " ldr q5, [x0] \n\t" + " fmla v8.4s, v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s, v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #16] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x1, #336] \n\t" + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x1, #400] \n\t" + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x1, #464] \n\t" + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #16] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #32] \n\t" + " \n\t" //End It 1 + " \n\t" + " ldr q0, [x0, #32] \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " ldr q1, [x0, #48] \n\t" + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #48] \n\t" + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x0, #224] \n\t" + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x0, #288] \n\t" + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #80] \n\t" + " \n\t" //End It 2 + " \n\t" + " ldr q5, [x0, #64] \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #80] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #96] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #112] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #128] \n\t" + " \n\t" //End It 3 + " \n\t" + " ldr q0, [x0, #96] \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " ldr q1, [x0, #112] \n\t" + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #144] \n\t" + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #160] \n\t" + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #176] \n\t" + " add x1, x1, #192 \n\t" + " add x0, x0, #128 \n\t" + " \n\t" //End It 4 + " sub x5,x5,1 \n\t" // i-=1. + " cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. + BNE(SLOOPKITER) + " \n\t" + LABEL(SLASTITER) // Last iteration of k_iter loop. + " \n\t" + " \n\t" + " ldr q5, [x0] \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #16] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #16] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #32] \n\t" + " \n\t" //End It 1 + " \n\t" + " ldr q0, [x0, #32] \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " ldr q1, [x0, #48] \n\t" + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #48] \n\t" + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #80] \n\t" + " \n\t" //End It 2 + " \n\t" + " ldr q5, [x0, #64] \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #80] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #96] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #112] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #128] \n\t" + " \n\t" //End It 3 + " \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. + " add x1, x1, #144 \n\t" + " add x0, x0, #96 \n\t" + " \n\t" //End It 4 + " \n\t" + LABEL(SCONSIDERKLEFT) + " cmp x6,0 \n\t" // If k_left == 0, we are done. + BEQ(SPOSTACCUM) // else, we enter the k_left loop. + " \n\t" + LABEL(SLOOPKLEFT) // Body of the left iterations + " \n\t" + " ldr q0, [x0],#16 \n\t" + " ldr q1, [x0],#16 \n\t" // Load a + " \n\t" + " ldr q2, [x1],#16 \n\t" // Load b + " ldr q3, [x1],#16 \n\t" + " ldr q4, [x1],#16 \n\t" + " \n\t" + " sub x6,x6,1 \n\t" // i = i-1. + " \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " \n\t" + " cmp x6,0 \n\t" // Iterate again. + BNE(SLOOPKLEFT) // if i!=0. + " \n\t" + LABEL(SPOSTACCUM) + " \n\t" + " ldr x0,%[alpha] \n\t" // Alpha address. + " ldr x1,%[beta] \n\t" // Beta address. + " \n\t" + " ld1r {v6.4s},[x0] \n\t" // Load alpha. + " ld1r {v7.4s},[x1] \n\t" // Load beta + " \n\t" + " ldr x0,%[a_next] \n\t" // Pointer to next block of A. + " ldr x1,%[b_next] \n\t" // Pointer to next pointer of B. + " \n\t" + LABEL(SCOLSTORED) // C is column-major. + " \n\t" + " dup v0.4s, wzr \n\t" + " dup v1.4s, wzr \n\t" + " dup v2.4s, wzr \n\t" + " dup v3.4s, wzr \n\t" + " dup v4.4s, wzr \n\t" + " dup v5.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. + " \n\t" + " ldr q0, [x2] \n\t" //Load column 0 of C + " ldr q1, [x2, #16] \n\t" + " ldr q2, [x16] \n\t" //Load column 1 of C + " ldr q3, [x16, #16] \n\t" + " ldr q4, [x17] \n\t" //Load column 2 of C + " ldr q5, [x17, #16] \n\t" + " \n\t" + " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta + " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta + " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta + " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta + " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta + " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS1) + " \n\t" + " fmla v0.4s,v8.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v1.4s,v9.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v2.4s,v10.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v3.4s,v11.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x2] \n\t" //Store column 0 of C + " str q1, [x2, #16] \n\t" + " str q2, [x16] \n\t" //Store column 1 of C + " str q3, [x16, #16] \n\t" + " str q4, [x17] \n\t" //Store column 2 of C + " str q5, [x17, #16] \n\t" + " \n\t" + " dup v8.4s, wzr \n\t" + " dup v9.4s, wzr \n\t" + " dup v10.4s, wzr \n\t" + " dup v11.4s, wzr \n\t" + " dup v12.4s, wzr \n\t" + " dup v13.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x19] \n\t" //Load column 3 of C + " ldr q9, [x19, #16] \n\t" + " ldr q10, [x20] \n\t" //Load column 4 of C + " ldr q11, [x20, #16] \n\t" + " ldr q12, [x21] \n\t" //Load column 5 of C + " ldr q13, [x21, #16] \n\t" + " \n\t" + " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta + " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta + " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta + " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta + " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta + " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS2) + " \n\t" + " fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v10.4s,v16.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v11.4s,v17.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x19] \n\t" //Store column 3 of C + " str q9, [x19, #16] \n\t" + " str q10, [x20] \n\t" //Store column 4 of C + " str q11, [x20, #16] \n\t" + " str q12, [x21] \n\t" //Store column 5 of C + " str q13, [x21, #16] \n\t" + " \n\t" + " dup v0.4s, wzr \n\t" + " dup v1.4s, wzr \n\t" + " dup v2.4s, wzr \n\t" + " dup v3.4s, wzr \n\t" + " dup v4.4s, wzr \n\t" + " dup v5.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. + " \n\t" + " ldr q0, [x22] \n\t" //Load column 6 of C + " ldr q1, [x22, #16] \n\t" + " ldr q2, [x23] \n\t" //Load column 7 of C + " ldr q3, [x23, #16] \n\t" + " ldr q4, [x24] \n\t" //Load column 8 of C + " ldr q5, [x24, #16] \n\t" + " \n\t" + " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta + " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta + " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta + " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta + " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta + " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS3) + " \n\t" + " fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v2.4s,v22.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v3.4s,v23.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x22] \n\t" //Store column 6 of C + " str q1, [x22, #16] \n\t" + " str q2, [x23] \n\t" //Store column 7 of C + " str q3, [x23, #16] \n\t" + " str q4, [x24] \n\t" //Store column 8 of C + " str q5, [x24, #16] \n\t" + " \n\t" + " dup v8.4s, wzr \n\t" + " dup v9.4s, wzr \n\t" + " dup v10.4s, wzr \n\t" + " dup v11.4s, wzr \n\t" + " dup v12.4s, wzr \n\t" + " dup v13.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x25] \n\t" //Load column 9 of C + " ldr q9, [x25, #16] \n\t" + " ldr q10, [x26] \n\t" //Load column 10 of C + " ldr q11, [x26, #16] \n\t" + " ldr q12, [x27] \n\t" //Load column 11 of C + " ldr q13, [x27, #16] \n\t" + " \n\t" + " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta + " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta + " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta + " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta + " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta + " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS4) + " \n\t" + " prfm pldl2keep,[x0] \n\t" + " prfm pldl2keep,[x1] \n\t" + " \n\t" + " fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v10.4s,v28.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v11.4s,v29.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x25] \n\t" //Store column 9 of C + " str q9, [x25, #16] \n\t" + " str q10, [x26] \n\t" //Store column 10 of C + " str q11, [x26, #16] \n\t" + " str q12, [x27] \n\t" //Store column 11 of C + " str q13, [x27, #16] \n\t" + " \n\t" + " \n\t" + // BRANCH(SEND) // Done. + // LABEL(SEND) // Done! + " \n\t" + :// output operands (none) + :// input operands + [aaddr] "m" (a), // 0 + [baddr] "m" (b), // 1 + [caddr] "m" (c), // 2 + [k_iter] "m" (k_iter), // 3 + [k_left] "m" (k_left), // 4 + [alpha] "m" (alpha), // 5 + [beta] "m" (beta), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [a_next] "m" (a_next), // 9 + [b_next] "m" (b_next) // 10 + :// Register clobber list + "x0", "x1", "x2", + "x5", "x6", "x10", + "x16","x17","x19","x20", + "x21","x22","x23","x24", + "x25","x26","x27", + "v0", "v1", "v2", "v3", + "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11", + "v12","v13","v14","v15", + "v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); + + GEMM_UKR_FLUSH_CT( s ); +} + + +/* + o 4x4 Double precision micro-kernel NOT fully functional yet. + o Runnable on ARMv8, compiled with aarch64 GCC. + o Use it together with the armv8 BLIS configuration. + o Tested on Juno board. Around 3 GFLOPS @ 1.1 GHz. + + December 2014. + + * UPDATE OCTOBER 2015: Now is fully functional. + * Tested on Juno board. Around 5.6 GFLOPS, 2 A57 cores @ 1.1 GHz. + * Tested on Juno board. Around 4 GFLOPS, 4 A53 cores @ 850 MHz. + + * UPDATE NOVEMBER 2015 + * Micro-kernel changed to 6x8 + * Tested on Juno Board. Around 4 GFLOPS, 1 x A57 core @ 1.1 GHz. + * Tested on Juno Board. Around 7.6 GFLOPS, 2 x A57 cores @ 1.1 GHz. + * Tested on Juno board. Around 1.5 GFLOPS, 1 x A53 core @ 850 MHz. + * Tested on Juno board. Around 5.5 GFLOPS, 4 x A53 cores @ 850 MHz. +*/ +void bli_dgemm_armv8a_asm_6x8 + ( + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + GEMM_UKR_SETUP_CT( d, 6, 8, false ); + + __asm__ volatile + ( + " \n\t" + " ldr x0,%[aaddr] \n\t" // Load address of A + " ldr x1,%[baddr] \n\t" // Load address of B + " ldr x2,%[caddr] \n\t" // Load address of C + " \n\t" + " ldr x5,%[k_iter] \n\t" // Init guard (k_iter) + " ldr x6,%[k_left] \n\t" // Init guard (k_iter) + " \n\t" + " ldr x10,%[cs_c] \n\t" // Load cs_c + " lsl x10,x10,#3 \n\t" // cs_c * sizeof(double) + " \n\t" + // " ldr x14,%[rs_c] \n\t" // Load rs_c. + // " lsl x14,x14,#3 \n\t" // rs_c * sizeof(double). + " \n\t" + " add x20,x2,x10 \n\t" //Load address Column 1 of C + " add x21,x20,x10 \n\t" //Load address Column 2 of C + " add x22,x21,x10 \n\t" //Load address Column 3 of C + " add x23,x22,x10 \n\t" //Load address Column 4 of C + " add x24,x23,x10 \n\t" //Load address Column 5 of C + " add x25,x24,x10 \n\t" //Load address Column 6 of C + " add x26,x25,x10 \n\t" //Load address Column 7 of C + " \n\t" + " prfm pldl1keep,[x2] \n\t" // Prefetch c. + " prfm pldl1keep,[x20] \n\t" // Prefetch c. + " prfm pldl1keep,[x21] \n\t" // Prefetch c. + " prfm pldl1keep,[x22] \n\t" // Prefetch c. + " prfm pldl1keep,[x23] \n\t" // Prefetch c. + " prfm pldl1keep,[x24] \n\t" // Prefetch c. + " prfm pldl1keep,[x25] \n\t" // Prefetch c. + " prfm pldl1keep,[x26] \n\t" // Prefetch c. + " \n\t" + " dup v8.2d, xzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #256] \n\t" + " dup v9.2d, xzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #320] \n\t" + " dup v10.2d, xzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #384] \n\t" + " dup v11.2d, xzr \n\t" // Vector for accummulating column 1 + " prfm PLDL1KEEP, [x1, #448] \n\t" + " dup v12.2d, xzr \n\t" // Vector for accummulating column 1 + " dup v13.2d, xzr \n\t" // Vector for accummulating column 1 + " \n\t" + " dup v14.2d, xzr \n\t" // Vector for accummulating column 2 + " prfm PLDL1KEEP, [x0, #192] \n\t" + " dup v15.2d, xzr \n\t" // Vector for accummulating column 2 + " prfm PLDL1KEEP, [x0, #256] \n\t" + " dup v16.2d, xzr \n\t" // Vector for accummulating column 2 + " prfm PLDL1KEEP, [x0, #320] \n\t" + " dup v17.2d, xzr \n\t" // Vector for accummulating column 3 + " dup v18.2d, xzr \n\t" // Vector for accummulating column 3 + " dup v19.2d, xzr \n\t" // Vector for accummulating column 3 + " \n\t" + " dup v20.2d, xzr \n\t" // Vector for accummulating column 4 + " dup v21.2d, xzr \n\t" // Vector for accummulating column 4 + " dup v22.2d, xzr \n\t" // Vector for accummulating column 4 + " dup v23.2d, xzr \n\t" // Vector for accummulating column 5 + " dup v24.2d, xzr \n\t" // Vector for accummulating column 5 + " dup v25.2d, xzr \n\t" // Vector for accummulating column 5 + " \n\t" + " dup v26.2d, xzr \n\t" // Vector for accummulating column 6 + " dup v27.2d, xzr \n\t" // Vector for accummulating column 6 + " dup v28.2d, xzr \n\t" // Vector for accummulating column 6 + " dup v29.2d, xzr \n\t" // Vector for accummulating column 7 + " dup v30.2d, xzr \n\t" // Vector for accummulating column 7 + " dup v31.2d, xzr \n\t" // Vector for accummulating column 7 + " \n\t" + " \n\t" + " cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. + BEQ(DCONSIDERKLEFT) + " \n\t" + " ldr q0, [x0] \n\t" // Load a + " ldr q1, [x0, #16] \n\t" + " ldr q2, [x0, #32] \n\t" + " \n\t" + " ldr q3, [x1] \n\t" // Load b + " ldr q4, [x1, #16] \n\t" + " ldr q5, [x1, #32] \n\t" + " ldr q6, [x1, #48] \n\t" + " \n\t" + " add x0, x0, #48 \n\t" //update address of A + " add x1, x1, #64 \n\t" //update address of B + " \n\t" + " cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. + BEQ(DLASTITER) // (as loop is do-while-like). + " \n\t" + LABEL(DLOOP) // Body + " \n\t" + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #448] \n\t" //512-64=448 + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #512] \n\t" + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #576] \n\t" + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q3, [x1] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q7, [x0, #32] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " ldr q4, [x1, #16] \n\t" + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #32] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #16] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #48] \n\t" + " \n\t" // End it 1 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #640] \n\t" + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x0, #336] \n\t" + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x0, #400] \n\t" + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " ldr q2, [x0, #80] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " ldr q4, [x1, #80] \n\t" + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #96] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #48] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #64] \n\t" + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #112] \n\t" + " \n\t" //End it 2 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x0, #464] \n\t" + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q3, [x1, #128] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q7, [x0, #128] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " ldr q4, [x1, #144] \n\t" + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #160] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #96] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #112] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #176] \n\t" + " \n\t" // End it 3 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1, #192] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " ldr q2, [x0, #176] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #208] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #224] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #144] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #160] \n\t" + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #240] \n\t" + " \n\t" //End it 4 + " add x0, x0, #192 \n\t" + " add x1, x1, #256 \n\t" + " \n\t" + " sub x5,x5,1 \n\t" // i-=1 + " cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. + BNE(DLOOP) + " \n\t" + LABEL(DLASTITER) + " \n\t" + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q7, [x0, #32] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #16] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #32] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #16] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #48] \n\t" + " \n\t" // End it 1 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " ldr q2, [x0, #80] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #80] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #96] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #48] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #64] \n\t" + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #112] \n\t" + " \n\t" //End it 2 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1, #128] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q7, [x0, #128] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #144] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #160] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #96] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #112] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #176] \n\t" + " \n\t" // End it 3 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " add x1, x1, #192 \n\t" + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " \n\t" //End it 4 + " add x0, x0, #144 \n\t" + " \n\t" + LABEL(DCONSIDERKLEFT) + " cmp x6,0 \n\t" // If k_left == 0, we are done. + BEQ(DPOSTACCUM) // else, we enter the k_left loop. + " \n\t" + LABEL(DLOOPKLEFT) + " \n\t" + " ldr q0, [x0],#16 \n\t" + " ldr q1, [x0],#16 \n\t" // Load a + " ldr q2, [x0],#16 \n\t" + " \n\t" + " ldr q3, [x1],#16 \n\t" // Load b + " ldr q4, [x1],#16 \n\t" + " ldr q5, [x1],#16 \n\t" + " ldr q6, [x1],#16 \n\t" + " \n\t" + " sub x6,x6,1 \n\t" + " \n\t" + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " cmp x6,0 \n\t" // Iterate again. + BNE(DLOOPKLEFT) // if i!=0. + " \n\t" + LABEL(DPOSTACCUM) + " \n\t" + " ldr x0,%[alpha] \n\t" // Alpha address + " ldr x1,%[beta] \n\t" // Beta address + " \n\t" + " ld1r {v6.2d},[x0] \n\t" // Load alpha. + " ld1r {v7.2d},[x1] \n\t" // Load beta + " \n\t" + " ldr x0,%[a_next] \n\t" // Next A address for later use. + " ldr x1,%[b_next] \n\t" // Next B address for later use. + " \n\t" + LABEL(DCOLSTORED) // C is column-major. + " \n\t" + " dup v0.2d, xzr \n\t" + " dup v1.2d, xzr \n\t" + " dup v2.2d, xzr \n\t" + " dup v3.2d, xzr \n\t" + " dup v4.2d, xzr \n\t" + " dup v5.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. + " \n\t" + " ldr q0, [x2] \n\t" //Load column 0 of C + " ldr q1, [x2, #16] \n\t" + " ldr q2, [x2, #32] \n\t" + " \n\t" + " ldr q3, [x20] \n\t" //Load column 1 of C + " ldr q4, [x20, #16] \n\t" + " ldr q5, [x20, #32] \n\t" + " \n\t" + " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta + " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta + " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta + " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta + " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta + " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS1) + " \n\t" + " fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v2.2d,v10.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v3.2d,v11.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v4.2d,v12.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v5.2d,v13.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x2] \n\t" //Store column 0 of C + " str q1, [x2, #16] \n\t" + " str q2, [x2, #32] \n\t" + " \n\t" + " str q3, [x20] \n\t" //Store column 1 of C + " str q4, [x20, #16] \n\t" + " str q5, [x20, #32] \n\t" + " \n\t" + " dup v8.2d, xzr \n\t" + " dup v9.2d, xzr \n\t" + " dup v10.2d, xzr \n\t" + " dup v11.2d, xzr \n\t" + " dup v12.2d, xzr \n\t" + " dup v13.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x21] \n\t" //Load column 2 of C + " ldr q9, [x21, #16] \n\t" + " ldr q10, [x21, #32] \n\t" + " \n\t" + " ldr q11, [x22] \n\t" //Load column 3 of C + " ldr q12, [x22, #16] \n\t" + " ldr q13, [x22, #32] \n\t" + " \n\t" + " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta + " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta + " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta + " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta + " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta + " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS2) + " \n\t" + " fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v10.2d,v16.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v11.2d,v17.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v12.2d,v18.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v13.2d,v19.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x21] \n\t" //Store column 2 of C + " str q9, [x21, #16] \n\t" + " str q10, [x21, #32] \n\t" + " \n\t" + " str q11, [x22] \n\t" //Store column 3 of C + " str q12, [x22, #16] \n\t" + " str q13, [x22, #32] \n\t" + " \n\t" + " dup v0.2d, xzr \n\t" + " dup v1.2d, xzr \n\t" + " dup v2.2d, xzr \n\t" + " dup v3.2d, xzr \n\t" + " dup v4.2d, xzr \n\t" + " dup v5.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. + " \n\t" + " ldr q0, [x23] \n\t" //Load column 4 of C + " ldr q1, [x23, #16] \n\t" + " ldr q2, [x23, #32] \n\t" + " \n\t" + " ldr q3, [x24] \n\t" //Load column 5 of C + " ldr q4, [x24, #16] \n\t" + " ldr q5, [x24, #32] \n\t" + " \n\t" + " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta + " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta + " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta + " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta + " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta + " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS3) + " \n\t" + " fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v2.2d,v22.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v3.2d,v23.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v4.2d,v24.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v5.2d,v25.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x23] \n\t" //Store column 4 of C + " str q1, [x23, #16] \n\t" + " str q2, [x23, #32] \n\t" + " \n\t" + " str q3, [x24] \n\t" //Store column 5 of C + " str q4, [x24, #16] \n\t" + " str q5, [x24, #32] \n\t" + " \n\t" + " dup v8.2d, xzr \n\t" + " dup v9.2d, xzr \n\t" + " dup v10.2d, xzr \n\t" + " dup v11.2d, xzr \n\t" + " dup v12.2d, xzr \n\t" + " dup v13.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x25] \n\t" //Load column 6 of C + " ldr q9, [x25, #16] \n\t" + " ldr q10, [x25, #32] \n\t" + " \n\t" + " ldr q11, [x26] \n\t" //Load column 7 of C + " ldr q12, [x26, #16] \n\t" + " ldr q13, [x26, #32] \n\t" + " \n\t" + " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta + " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta + " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta + " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta + " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta + " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS4) + " \n\t" + " prfm pldl2keep,[x0] \n\t" + " prfm pldl2keep,[x1] \n\t" + " \n\t" + " fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v10.2d,v28.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v11.2d,v29.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v12.2d,v30.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v13.2d,v31.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x25] \n\t" //Store column 6 of C + " str q9, [x25, #16] \n\t" + " str q10, [x25, #32] \n\t" + " \n\t" + " str q11, [x26] \n\t" //Store column 7 of C + " str q12, [x26, #16] \n\t" + " str q13, [x26, #32] \n\t" + " \n\t" + // BRANCH(DEND) + // LABEL(DEND) // Done! + " \n\t" + :// output operands (none) + :// input operands + [aaddr] "m" (a), // 0 + [baddr] "m" (b), // 1 + [caddr] "m" (c), // 2 + [k_iter] "m" (k_iter), // 3 + [k_left] "m" (k_left), // 4 + [alpha] "m" (alpha), // 5 + [beta] "m" (beta), // 6 + [rs_c] "m" (rs_c), // 6 + [cs_c] "m" (cs_c), // 7 + [a_next] "m" (a_next), // 8 + [b_next] "m" (b_next) // 9 + :// Register clobber list + "x0","x1","x2", + "x5","x6","x10", + "x16","x17","x20", + "x21","x22","x23", + "x24","x25","x26","x27", + "v0","v1","v2", + "v3","v4","v5", + "v6","v7","v8", + "v9","v10","v11", + "v12","v13","v14", + "v15","v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); + + GEMM_UKR_FLUSH_CT( d ); +} + + +#if 0 +void bli_cgemm_armv8a_opt_4x4 + ( + dim_t m, + dim_t n, + dim_t k, + scomplex* restrict alpha, + scomplex* restrict a, + scomplex* restrict b, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +} + +void bli_zgemm_armv8a_opt_4x4 + ( + dim_t m, + dim_t n, + dim_t k, + dcomplex* restrict alpha, + dcomplex* restrict a, + dcomplex* restrict b, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +} + +#endif + diff --git a/kernels/armv8a/3/bli_gemm_armv8a_opt_4x4.c b/kernels/armv8a/3/bli_gemm_armv8a_opt_4x4.c deleted file mode 100644 index 8f5ec76f60..0000000000 --- a/kernels/armv8a/3/bli_gemm_armv8a_opt_4x4.c +++ /dev/null @@ -1,2114 +0,0 @@ - /* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -*/ - -#include "blis.h" - -/* - o 4x4 Single precision micro-kernel fully functional. - o Runnable on ARMv8, compiled with aarch64 GCC. - o Use it together with the armv8 BLIS configuration. - o Tested on Juno board. Around 7.3 GFLOPS @ 1.1 GHz. - - December 2014. - - * UPDATE NOVEMBER 2015 - * Micro-kernel changed to 8x12 - * Tested on Juno Board. Around 8.1 GFLOPS, 1 x A57 core @ 1.1 GHz. - * Tested on Juno Board. Around 15.9 GFLOPS, 2 x A57 cores @ 1.1 GHz. - * Tested on Juno board. Around 3.1 GFLOPS, 1 x A53 core @ 850 MHz. - * Tested on Juno board. Around 12 GFLOPS, 4 x A53 cores @ 850 MHz. -*/ -void bli_sgemm_armv8a_asm_8x12 - ( - dim_t k0, - float* restrict alpha, - float* restrict a, - float* restrict b, - float* restrict beta, - float* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ) -{ - void* a_next = bli_auxinfo_next_a( data ); - void* b_next = bli_auxinfo_next_b( data ); - - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - - -__asm__ volatile -( -" \n\t" -" \n\t" -" ldr x0,%[aaddr] \n\t" // Load address of A. -" ldr x1,%[baddr] \n\t" // Load address of B. -" ldr x2,%[caddr] \n\t" // Load address of C. -" \n\t" -" ldr x3,%[a_next] \n\t" // Pointer to next block of A. -" ldr x4,%[b_next] \n\t" // Pointer to next pointer of B. -" \n\t" -" ldr x5,%[k_iter] \n\t" // Number of unrolled iterations (k_iter). -" ldr x6,%[k_left] \n\t" // Number of remaining iterations (k_left). -" \n\t" -" ldr x7,%[alpha] \n\t" // Alpha address. -" ldr x8,%[beta] \n\t" // Beta address. -" \n\t" -" ldr x9,%[cs_c] \n\t" // Load cs_c. -" lsl x10,x9,#2 \n\t" // cs_c * sizeof(float) -- AUX. -" \n\t" -" ldr x13,%[rs_c] \n\t" // Load rs_c. -" lsl x14,x13,#2 \n\t" // rs_c * sizeof(float). -" \n\t" -" add x16,x2,x10 \n\t" //Load address Column 1 of C -" add x17,x16,x10 \n\t" //Load address Column 2 of C -" add x18,x17,x10 \n\t" //Load address Column 3 of C -" add x19,x18,x10 \n\t" //Load address Column 4 of C -" add x20,x19,x10 \n\t" //Load address Column 5 of C -" add x21,x20,x10 \n\t" //Load address Column 6 of C -" add x22,x21,x10 \n\t" //Load address Column 7 of C -" add x23,x22,x10 \n\t" //Load address Column 8 of C -" add x24,x23,x10 \n\t" //Load address Column 9 of C -" add x25,x24,x10 \n\t" //Load address Column 10 of C -" add x26,x25,x10 \n\t" //Load address Column 11 of C -" \n\t" -" ldr q0, [x0] \n\t" -" ldr q1, [x0, #16] \n\t" // Load a -" \n\t" -" ldr q2, [x1] \n\t" // Load b -" ldr q3, [x1, #16] \n\t" -" ldr q4, [x1, #32] \n\t" -" \n\t" -" prfm pldl1keep,[x2] \n\t" // Prefetch c. -" prfm pldl1keep,[x16] \n\t" // Prefetch c. -" prfm pldl1keep,[x17] \n\t" // Prefetch c. -" prfm pldl1keep,[x18] \n\t" // Prefetch c. -" prfm pldl1keep,[x19] \n\t" // Prefetch c. -" prfm pldl1keep,[x20] \n\t" // Prefetch c. -" prfm pldl1keep,[x21] \n\t" // Prefetch c. -" prfm pldl1keep,[x22] \n\t" // Prefetch c. -" prfm pldl1keep,[x23] \n\t" // Prefetch c. -" prfm pldl1keep,[x24] \n\t" // Prefetch c. -" prfm pldl1keep,[x25] \n\t" // Prefetch c. -" prfm pldl1keep,[x26] \n\t" // Prefetch c. -" \n\t" -" dup v8.4s, wzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #192] \n\t" -" dup v9.4s, wzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #256] \n\t" -" dup v10.4s, wzr \n\t" // Vector for accummulating column 1 -" prfm PLDL1KEEP, [x1, #320] \n\t" -" dup v11.4s, wzr \n\t" // Vector for accummulating column 1 -" dup v12.4s, wzr \n\t" // Vector for accummulating column 2 -" dup v13.4s, wzr \n\t" // Vector for accummulating column 2 -" \n\t" -" dup v14.4s, wzr \n\t" // Vector for accummulating column 3 -" prfm PLDL1KEEP, [x0, #128] \n\t" -" dup v15.4s, wzr \n\t" // Vector for accummulating column 3 -" prfm PLDL1KEEP, [x0, #192] \n\t" -" dup v16.4s, wzr \n\t" // Vector for accummulating column 4 -" dup v17.4s, wzr \n\t" // Vector for accummulating column 4 -" dup v18.4s, wzr \n\t" // Vector for accummulating column 5 -" dup v19.4s, wzr \n\t" // Vector for accummulating column 5 -" \n\t" -" dup v20.4s, wzr \n\t" // Vector for accummulating column 6 -" dup v21.4s, wzr \n\t" // Vector for accummulating column 6 -" dup v22.4s, wzr \n\t" // Vector for accummulating column 7 -" dup v23.4s, wzr \n\t" // Vector for accummulating column 7 -" dup v24.4s, wzr \n\t" // Vector for accummulating column 8 -" dup v25.4s, wzr \n\t" // Vector for accummulating column 8 -" \n\t" -" dup v26.4s, wzr \n\t" // Vector for accummulating column 9 -" dup v27.4s, wzr \n\t" // Vector for accummulating column 9 -" dup v28.4s, wzr \n\t" // Vector for accummulating column 10 -" dup v29.4s, wzr \n\t" // Vector for accummulating column 10 -" dup v30.4s, wzr \n\t" // Vector for accummulating column 11 -" dup v31.4s, wzr \n\t" // Vector for accummulating column 11 -" \n\t" -" cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. -" beq .SCONSIDERKLEFT \n\t" -" \n\t" -"add x0, x0, #32 \n\t" //update address of A -"add x1, x1, #48 \n\t" //update address of B -" \n\t" -" cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. -" beq .SLASTITER \n\t" // (as loop is do-while-like). -" \n\t" -" .SLOOPKITER: \n\t" // Body of the k_iter loop. -" \n\t" -" ldr q5, [x0] \n\t" -" fmla v8.4s, v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s, v1.4s,v2.s[0] \n\t" // Accummulate. -" ldr q6, [x0, #16] \n\t" -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1] \n\t" -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x1, #336] \n\t" -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x1, #400] \n\t" -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x1, #464] \n\t" -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #16] \n\t" -" \n\t" -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #32] \n\t" -" \n\t" //End It 1 -" \n\t" -" ldr q0, [x0, #32] \n\t" -" fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. -" ldr q1, [x0, #48] \n\t" -" fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #48] \n\t" -" \n\t" -" fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x0, #224] \n\t" -" fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x0, #288] \n\t" -" fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #64] \n\t" -" \n\t" -" fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #80] \n\t" -" \n\t" //End It 2 -" \n\t" -" ldr q5, [x0, #64] \n\t" -" fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. -" ldr q6, [x0, #80] \n\t" -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #96] \n\t" -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #112] \n\t" -" \n\t" -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #128] \n\t" -" \n\t" //End It 3 -" \n\t" -" ldr q0, [x0, #96] \n\t" -" fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. -" ldr q1, [x0, #112] \n\t" -" fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #144] \n\t" -" \n\t" -" fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #160] \n\t" -" \n\t" -" fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #176] \n\t" -" add x1, x1, #192 \n\t" -" add x0, x0, #128 \n\t" -" \n\t" //End It 4 -" sub x5,x5,1 \n\t" // i-=1. -" cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. -" bne .SLOOPKITER \n\t" -" \n\t" -" .SLASTITER: \n\t" // Last iteration of k_iter loop. -" \n\t" -" \n\t" -" ldr q5, [x0] \n\t" -" fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. -" ldr q6, [x0, #16] \n\t" -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1] \n\t" -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #16] \n\t" -" \n\t" -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #32] \n\t" -" \n\t" //End It 1 -" \n\t" -" ldr q0, [x0, #32] \n\t" -" fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. -" ldr q1, [x0, #48] \n\t" -" fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #48] \n\t" -" \n\t" -" fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #64] \n\t" -" \n\t" -" fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #80] \n\t" -" \n\t" //End It 2 -" \n\t" -" ldr q5, [x0, #64] \n\t" -" fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. -" ldr q6, [x0, #80] \n\t" -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #96] \n\t" -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #112] \n\t" -" \n\t" -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #128] \n\t" -" \n\t" //End It 3 -" \n\t" -" fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. -" fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. -" add x1, x1, #144 \n\t" -" add x0, x0, #96 \n\t" -" \n\t" //End It 4 -" \n\t" -" .SCONSIDERKLEFT: \n\t" -" cmp x6,0 \n\t" // If k_left == 0, we are done. -" beq .SPOSTACCUM \n\t" // else, we enter the k_left loop. -" \n\t" -" .SLOOPKLEFT: \n\t" // Body of the left iterations -" \n\t" -" ldr q0, [x0],#16 \n\t" -" ldr q1, [x0],#16 \n\t" // Load a -" \n\t" -" ldr q2, [x1],#16 \n\t" // Load b -" ldr q3, [x1],#16 \n\t" -" ldr q4, [x1],#16 \n\t" -" \n\t" -" sub x6,x6,1 \n\t" // i = i-1. -" \n\t" -" fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" \n\t" -" cmp x6,0 \n\t" // Iterate again. -" bne .SLOOPKLEFT \n\t" // if i!=0. -" \n\t" -" .SPOSTACCUM: \n\t" -" \n\t" -" ld1r {v6.4s},[x7] \n\t" // Load alpha. -" ld1r {v7.4s},[x8] \n\t" // Load beta -" \n\t" -" cmp x13,#1 \n\t" // If rs_c != 1 (column-major) -" bne .SGENSTORED \n\t" -" \n\t" -" .SCOLSTORED: \n\t" // C is column-major. -" \n\t" -" dup v0.4s, wzr \n\t" -" dup v1.4s, wzr \n\t" -" dup v2.4s, wzr \n\t" -" dup v3.4s, wzr \n\t" -" dup v4.4s, wzr \n\t" -" dup v5.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -" beq .SBETAZEROCOLSTOREDS1 \n\t" // Taking care of the beta==0 case. -" \n\t" -" ldr q0, [x2] \n\t" //Load column 0 of C -" ldr q1, [x2, #16] \n\t" -" ldr q2, [x16] \n\t" //Load column 1 of C -" ldr q3, [x16, #16] \n\t" -" ldr q4, [x17] \n\t" //Load column 2 of C -" ldr q5, [x17, #16] \n\t" -" \n\t" -" fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta -" fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta -" fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta -" fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta -" fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta -" fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -" .SBETAZEROCOLSTOREDS1: \n\t" -" \n\t" -" fmla v0.4s,v8.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v1.4s,v9.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v2.4s,v10.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v3.4s,v11.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" str q0, [x2] \n\t" //Store column 0 of C -" str q1, [x2, #16] \n\t" -" str q2, [x16] \n\t" //Store column 1 of C -" str q3, [x16, #16] \n\t" -" str q4, [x17] \n\t" //Store column 2 of C -" str q5, [x17, #16] \n\t" -" \n\t" -" dup v8.4s, wzr \n\t" -" dup v9.4s, wzr \n\t" -" dup v10.4s, wzr \n\t" -" dup v11.4s, wzr \n\t" -" dup v12.4s, wzr \n\t" -" dup v13.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -" beq .SBETAZEROCOLSTOREDS2 \n\t" // Taking care of the beta==0 case. -" \n\t" -" ldr q8, [x18] \n\t" //Load column 3 of C -" ldr q9, [x18, #16] \n\t" -" ldr q10, [x19] \n\t" //Load column 4 of C -" ldr q11, [x19, #16] \n\t" -" ldr q12, [x20] \n\t" //Load column 5 of C -" ldr q13, [x20, #16] \n\t" -" \n\t" -" fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta -" fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta -" fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta -" fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta -" fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta -" fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -" .SBETAZEROCOLSTOREDS2: \n\t" -" \n\t" -" fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v10.4s,v16.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v11.4s,v17.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" str q8, [x18] \n\t" //Store column 3 of C -" str q9, [x18, #16] \n\t" -" str q10, [x19] \n\t" //Store column 4 of C -" str q11, [x19, #16] \n\t" -" str q12, [x20] \n\t" //Store column 5 of C -" str q13, [x20, #16] \n\t" -" \n\t" -" dup v0.4s, wzr \n\t" -" dup v1.4s, wzr \n\t" -" dup v2.4s, wzr \n\t" -" dup v3.4s, wzr \n\t" -" dup v4.4s, wzr \n\t" -" dup v5.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -" beq .SBETAZEROCOLSTOREDS3 \n\t" // Taking care of the beta==0 case. -" \n\t" -" ldr q0, [x21] \n\t" //Load column 6 of C -" ldr q1, [x21, #16] \n\t" -" ldr q2, [x22] \n\t" //Load column 7 of C -" ldr q3, [x22, #16] \n\t" -" ldr q4, [x23] \n\t" //Load column 8 of C -" ldr q5, [x23, #16] \n\t" -" \n\t" -" fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta -" fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta -" fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta -" fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta -" fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta -" fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -" .SBETAZEROCOLSTOREDS3: \n\t" -" \n\t" -" fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v2.4s,v22.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v3.4s,v23.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" str q0, [x21] \n\t" //Store column 6 of C -" str q1, [x21, #16] \n\t" -" str q2, [x22] \n\t" //Store column 7 of C -" str q3, [x22, #16] \n\t" -" str q4, [x23] \n\t" //Store column 8 of C -" str q5, [x23, #16] \n\t" -" \n\t" -" dup v8.4s, wzr \n\t" -" dup v9.4s, wzr \n\t" -" dup v10.4s, wzr \n\t" -" dup v11.4s, wzr \n\t" -" dup v12.4s, wzr \n\t" -" dup v13.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -" beq .SBETAZEROCOLSTOREDS4 \n\t" // Taking care of the beta==0 case. -" \n\t" -" ldr q8, [x24] \n\t" //Load column 9 of C -" ldr q9, [x24, #16] \n\t" -" ldr q10, [x25] \n\t" //Load column 10 of C -" ldr q11, [x25, #16] \n\t" -" ldr q12, [x26] \n\t" //Load column 11 of C -" ldr q13, [x26, #16] \n\t" -" \n\t" -" fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta -" fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta -" fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta -" fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta -" fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta -" fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -" .SBETAZEROCOLSTOREDS4: \n\t" -" \n\t" -" prfm pldl2keep,[x3] \n\t" -" prfm pldl2keep,[x4] \n\t" -" \n\t" -" fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v10.4s,v28.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v11.4s,v29.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" str q8, [x24] \n\t" //Store column 9 of C -" str q9, [x24, #16] \n\t" -" str q10, [x25] \n\t" //Store column 10 of C -" str q11, [x25, #16] \n\t" -" str q12, [x26] \n\t" //Store column 11 of C -" str q13, [x26, #16] \n\t" -" \n\t" -" \n\t" -" b .SEND \n\t" // Done (TODO: this obviously needs to be moved down to remove jump). -" \n\t" -" \n\t" -" .SGENSTORED: \n\t" // C is general-stride stored. -" \n\t" -" \n\t" -" dup v0.4s, wzr \n\t" -" dup v1.4s, wzr \n\t" -" dup v2.4s, wzr \n\t" -" dup v3.4s, wzr \n\t" -" dup v4.4s, wzr \n\t" -" dup v5.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -" beq .SBETAZEROGENSTOREDS1 \n\t" // Taking care of the beta==0 case. -" \n\t" -" mov x27, x2 \n\t" -" \n\t" -" ld1 {v0.s}[0],[x27],x14 \n\t" // Load c00 into quad and increment by rs_c. -" ld1 {v0.s}[1],[x27],x14 \n\t" // Load c01 into quad and increment by rs_c. -" ld1 {v0.s}[2],[x27],x14 \n\t" // Load c02 into quad and increment by rs_c. -" ld1 {v0.s}[3],[x27],x14 \n\t" // Load c03 into quad and increment by rs_c. -" ld1 {v1.s}[0],[x27],x14 \n\t" // Load c04 into quad and increment by rs_c. -" ld1 {v1.s}[1],[x27],x14 \n\t" // Load c05 into quad and increment by rs_c. -" ld1 {v1.s}[2],[x27],x14 \n\t" // Load c06 into quad and increment by rs_c. -" ld1 {v1.s}[3],[x27],x14 \n\t" // Load c07 into quad and increment by rs_c. -" \n\t" -" mov x27, x16 \n\t" -" \n\t" -" ld1 {v2.s}[0],[x27],x14 \n\t" // Load c10 into quad and increment by rs_c. -" ld1 {v2.s}[1],[x27],x14 \n\t" // Load c11 into quad and increment by rs_c. -" ld1 {v2.s}[2],[x27],x14 \n\t" // Load c12 into quad and increment by rs_c. -" ld1 {v2.s}[3],[x27],x14 \n\t" // Load c13 into quad and increment by rs_c. -" ld1 {v3.s}[0],[x27],x14 \n\t" // Load c14 into quad and increment by rs_c. -" ld1 {v3.s}[1],[x27],x14 \n\t" // Load c15 into quad and increment by rs_c. -" ld1 {v3.s}[2],[x27],x14 \n\t" // Load c16 into quad and increment by rs_c. -" ld1 {v3.s}[3],[x27],x14 \n\t" // Load c17 into quad and increment by rs_c. -" \n\t" -" mov x27, x17 \n\t" -" \n\t" -" ld1 {v4.s}[0],[x27],x14 \n\t" // Load c20 into quad and increment by rs_c. -" ld1 {v4.s}[1],[x27],x14 \n\t" // Load c21 into quad and increment by rs_c. -" ld1 {v4.s}[2],[x27],x14 \n\t" // Load c22 into quad and increment by rs_c. -" ld1 {v4.s}[3],[x27],x14 \n\t" // Load c23 into quad and increment by rs_c. -" ld1 {v5.s}[0],[x27],x14 \n\t" // Load c24 into quad and increment by rs_c. -" ld1 {v5.s}[1],[x27],x14 \n\t" // Load c25 into quad and increment by rs_c. -" ld1 {v5.s}[2],[x27],x14 \n\t" // Load c26 into quad and increment by rs_c. -" ld1 {v5.s}[3],[x27],x14 \n\t" // Load c27 into quad and increment by rs_c. -" \n\t" -" fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta -" fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta -" fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta -" fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta -" fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta -" fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -" .SBETAZEROGENSTOREDS1: \n\t" -" \n\t" -" fmla v0.4s, v8.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v1.4s, v9.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v2.4s,v10.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v3.4s,v11.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x2 \n\t" -" \n\t" -" st1 {v0.s}[0],[x27],x14 \n\t" // Store c00 into quad and increment by rs_c. -" st1 {v0.s}[1],[x27],x14 \n\t" // Store c01 into quad and increment by rs_c. -" st1 {v0.s}[2],[x27],x14 \n\t" // Store c02 into quad and increment by rs_c. -" st1 {v0.s}[3],[x27],x14 \n\t" // Store c03 into quad and increment by rs_c. -" st1 {v1.s}[0],[x27],x14 \n\t" // Store c04 into quad and increment by rs_c. -" st1 {v1.s}[1],[x27],x14 \n\t" // Store c05 into quad and increment by rs_c. -" st1 {v1.s}[2],[x27],x14 \n\t" // Store c06 into quad and increment by rs_c. -" st1 {v1.s}[3],[x27],x14 \n\t" // Store c07 into quad and increment by rs_c. -" \n\t" -" mov x27, x16 \n\t" -" \n\t" -" st1 {v2.s}[0],[x27],x14 \n\t" // Store c10 into quad and increment by rs_c. -" st1 {v2.s}[1],[x27],x14 \n\t" // Store c11 into quad and increment by rs_c. -" st1 {v2.s}[2],[x27],x14 \n\t" // Store c12 into quad and increment by rs_c. -" st1 {v2.s}[3],[x27],x14 \n\t" // Store c13 into quad and increment by rs_c. -" st1 {v3.s}[0],[x27],x14 \n\t" // Store c14 into quad and increment by rs_c. -" st1 {v3.s}[1],[x27],x14 \n\t" // Store c15 into quad and increment by rs_c. -" st1 {v3.s}[2],[x27],x14 \n\t" // Store c16 into quad and increment by rs_c. -" st1 {v3.s}[3],[x27],x14 \n\t" // Store c17 into quad and increment by rs_c. -" \n\t" -" mov x27, x17 \n\t" -" \n\t" -" st1 {v4.s}[0],[x27],x14 \n\t" // Store c20 into quad and increment by rs_c. -" st1 {v4.s}[1],[x27],x14 \n\t" // Store c21 into quad and increment by rs_c. -" st1 {v4.s}[2],[x27],x14 \n\t" // Store c22 into quad and increment by rs_c. -" st1 {v4.s}[3],[x27],x14 \n\t" // Store c23 into quad and increment by rs_c. -" st1 {v5.s}[0],[x27],x14 \n\t" // Store c24 into quad and increment by rs_c. -" st1 {v5.s}[1],[x27],x14 \n\t" // Store c25 into quad and increment by rs_c. -" st1 {v5.s}[2],[x27],x14 \n\t" // Store c26 into quad and increment by rs_c. -" st1 {v5.s}[3],[x27],x14 \n\t" // Store c27 into quad and increment by rs_c. -" \n\t" -" dup v8.4s, wzr \n\t" -" dup v9.4s, wzr \n\t" -" dup v10.4s, wzr \n\t" -" dup v11.4s, wzr \n\t" -" dup v12.4s, wzr \n\t" -" dup v13.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -" beq .SBETAZEROGENSTOREDS2 \n\t" // Taking care of the beta==0 case. -" \n\t" -" mov x27, x18 \n\t" -" \n\t" -" ld1 {v8.s}[0],[x27],x14 \n\t" // Load c30 into quad and increment by rs_c. -" ld1 {v8.s}[1],[x27],x14 \n\t" // Load c31 into quad and increment by rs_c. -" ld1 {v8.s}[2],[x27],x14 \n\t" // Load c32 into quad and increment by rs_c. -" ld1 {v8.s}[3],[x27],x14 \n\t" // Load c33 into quad and increment by rs_c. -" ld1 {v9.s}[0],[x27],x14 \n\t" // Load c34 into quad and increment by rs_c. -" ld1 {v9.s}[1],[x27],x14 \n\t" // Load c35 into quad and increment by rs_c. -" ld1 {v9.s}[2],[x27],x14 \n\t" // Load c36 into quad and increment by rs_c. -" ld1 {v9.s}[3],[x27],x14 \n\t" // Load c37 into quad and increment by rs_c. -" \n\t" -" mov x27, x19 \n\t" -" \n\t" -" ld1 {v10.s}[0],[x27],x14 \n\t" // Load c40 into quad and increment by rs_c. -" ld1 {v10.s}[1],[x27],x14 \n\t" // Load c41 into quad and increment by rs_c. -" ld1 {v10.s}[2],[x27],x14 \n\t" // Load c42 into quad and increment by rs_c. -" ld1 {v10.s}[3],[x27],x14 \n\t" // Load c43 into quad and increment by rs_c. -" ld1 {v11.s}[0],[x27],x14 \n\t" // Load c44 into quad and increment by rs_c. -" ld1 {v11.s}[1],[x27],x14 \n\t" // Load c45 into quad and increment by rs_c. -" ld1 {v11.s}[2],[x27],x14 \n\t" // Load c46 into quad and increment by rs_c. -" ld1 {v11.s}[3],[x27],x14 \n\t" // Load c47 into quad and increment by rs_c. -" \n\t" -" mov x27, x20 \n\t" -" \n\t" -" ld1 {v12.s}[0],[x27],x14 \n\t" // Load c50 into quad and increment by rs_c. -" ld1 {v12.s}[1],[x27],x14 \n\t" // Load c51 into quad and increment by rs_c. -" ld1 {v12.s}[2],[x27],x14 \n\t" // Load c52 into quad and increment by rs_c. -" ld1 {v12.s}[3],[x27],x14 \n\t" // Load c53 into quad and increment by rs_c. -" ld1 {v13.s}[0],[x27],x14 \n\t" // Load c54 into quad and increment by rs_c. -" ld1 {v13.s}[1],[x27],x14 \n\t" // Load c55 into quad and increment by rs_c. -" ld1 {v13.s}[2],[x27],x14 \n\t" // Load c56 into quad and increment by rs_c. -" ld1 {v13.s}[3],[x27],x14 \n\t" // Load c57 into quad and increment by rs_c. -" \n\t" -" fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta -" fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta -" fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta -" fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta -" fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta -" fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -" .SBETAZEROGENSTOREDS2: \n\t" -" \n\t" -" fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v10.4s,v16.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v11.4s,v17.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x18 \n\t" -" \n\t" -" st1 {v8.s}[0],[x27],x14 \n\t" // Store c30 into quad and increment by rs_c. -" st1 {v8.s}[1],[x27],x14 \n\t" // Store c31 into quad and increment by rs_c. -" st1 {v8.s}[2],[x27],x14 \n\t" // Store c32 into quad and increment by rs_c. -" st1 {v8.s}[3],[x27],x14 \n\t" // Store c33 into quad and increment by rs_c. -" st1 {v9.s}[0],[x27],x14 \n\t" // Store c34 into quad and increment by rs_c. -" st1 {v9.s}[1],[x27],x14 \n\t" // Store c35 into quad and increment by rs_c. -" st1 {v9.s}[2],[x27],x14 \n\t" // Store c36 into quad and increment by rs_c. -" st1 {v9.s}[3],[x27],x14 \n\t" // Store c37 into quad and increment by rs_c. -" \n\t" -" mov x27, x19 \n\t" -" \n\t" -" st1 {v10.s}[0],[x27],x14 \n\t" // Store c40 into quad and increment by rs_c. -" st1 {v10.s}[1],[x27],x14 \n\t" // Store c41 into quad and increment by rs_c. -" st1 {v10.s}[2],[x27],x14 \n\t" // Store c42 into quad and increment by rs_c. -" st1 {v10.s}[3],[x27],x14 \n\t" // Store c43 into quad and increment by rs_c. -" st1 {v11.s}[0],[x27],x14 \n\t" // Store c44 into quad and increment by rs_c. -" st1 {v11.s}[1],[x27],x14 \n\t" // Store c45 into quad and increment by rs_c. -" st1 {v11.s}[2],[x27],x14 \n\t" // Store c46 into quad and increment by rs_c. -" st1 {v11.s}[3],[x27],x14 \n\t" // Store c47 into quad and increment by rs_c. -" \n\t" -" mov x27, x20 \n\t" -" \n\t" -" st1 {v12.s}[0],[x27],x14 \n\t" // Store c50 into quad and increment by rs_c. -" st1 {v12.s}[1],[x27],x14 \n\t" // Store c51 into quad and increment by rs_c. -" st1 {v12.s}[2],[x27],x14 \n\t" // Store c52 into quad and increment by rs_c. -" st1 {v12.s}[3],[x27],x14 \n\t" // Store c53 into quad and increment by rs_c. -" st1 {v13.s}[0],[x27],x14 \n\t" // Store c54 into quad and increment by rs_c. -" st1 {v13.s}[1],[x27],x14 \n\t" // Store c55 into quad and increment by rs_c. -" st1 {v13.s}[2],[x27],x14 \n\t" // Store c56 into quad and increment by rs_c. -" st1 {v13.s}[3],[x27],x14 \n\t" // Store c57 into quad and increment by rs_c. -" \n\t" -" dup v0.4s, wzr \n\t" -" dup v1.4s, wzr \n\t" -" dup v2.4s, wzr \n\t" -" dup v3.4s, wzr \n\t" -" dup v4.4s, wzr \n\t" -" dup v5.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -" beq .SBETAZEROGENSTOREDS3 \n\t" // Taking care of the beta==0 case. -" \n\t" -" mov x27, x21 \n\t" -" \n\t" -" ld1 {v0.s}[0],[x27],x14 \n\t" // Load c60 into quad and increment by rs_c. -" ld1 {v0.s}[1],[x27],x14 \n\t" // Load c61 into quad and increment by rs_c. -" ld1 {v0.s}[2],[x27],x14 \n\t" // Load c62 into quad and increment by rs_c. -" ld1 {v0.s}[3],[x27],x14 \n\t" // Load c63 into quad and increment by rs_c. -" ld1 {v1.s}[0],[x27],x14 \n\t" // Load c64 into quad and increment by rs_c. -" ld1 {v1.s}[1],[x27],x14 \n\t" // Load c65 into quad and increment by rs_c. -" ld1 {v1.s}[2],[x27],x14 \n\t" // Load c66 into quad and increment by rs_c. -" ld1 {v1.s}[3],[x27],x14 \n\t" // Load c67 into quad and increment by rs_c. -" \n\t" -" mov x27, x22 \n\t" -" \n\t" -" ld1 {v2.s}[0],[x27],x14 \n\t" // Load c70 into quad and increment by rs_c. -" ld1 {v2.s}[1],[x27],x14 \n\t" // Load c71 into quad and increment by rs_c. -" ld1 {v2.s}[2],[x27],x14 \n\t" // Load c72 into quad and increment by rs_c. -" ld1 {v2.s}[3],[x27],x14 \n\t" // Load c73 into quad and increment by rs_c. -" ld1 {v3.s}[0],[x27],x14 \n\t" // Load c74 into quad and increment by rs_c. -" ld1 {v3.s}[1],[x27],x14 \n\t" // Load c75 into quad and increment by rs_c. -" ld1 {v3.s}[2],[x27],x14 \n\t" // Load c76 into quad and increment by rs_c. -" ld1 {v3.s}[3],[x27],x14 \n\t" // Load c77 into quad and increment by rs_c. -" \n\t" -" mov x27, x23 \n\t" -" \n\t" -" ld1 {v4.s}[0],[x27],x14 \n\t" // Load c80 into quad and increment by rs_c. -" ld1 {v4.s}[1],[x27],x14 \n\t" // Load c81 into quad and increment by rs_c. -" ld1 {v4.s}[2],[x27],x14 \n\t" // Load c82 into quad and increment by rs_c. -" ld1 {v4.s}[3],[x27],x14 \n\t" // Load c83 into quad and increment by rs_c. -" ld1 {v5.s}[0],[x27],x14 \n\t" // Load c84 into quad and increment by rs_c. -" ld1 {v5.s}[1],[x27],x14 \n\t" // Load c85 into quad and increment by rs_c. -" ld1 {v5.s}[2],[x27],x14 \n\t" // Load c86 into quad and increment by rs_c. -" ld1 {v5.s}[3],[x27],x14 \n\t" // Load c87 into quad and increment by rs_c. -" \n\t" -" fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta -" fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta -" fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta -" fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta -" fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta -" fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -" .SBETAZEROGENSTOREDS3: \n\t" -" \n\t" -" fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v2.4s,v22.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v3.4s,v23.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x21 \n\t" -" \n\t" -" st1 {v0.s}[0],[x27],x14 \n\t" // Store c60 into quad and increment by rs_c. -" st1 {v0.s}[1],[x27],x14 \n\t" // Store c61 into quad and increment by rs_c. -" st1 {v0.s}[2],[x27],x14 \n\t" // Store c62 into quad and increment by rs_c. -" st1 {v0.s}[3],[x27],x14 \n\t" // Store c63 into quad and increment by rs_c. -" st1 {v1.s}[0],[x27],x14 \n\t" // Store c64 into quad and increment by rs_c. -" st1 {v1.s}[1],[x27],x14 \n\t" // Store c65 into quad and increment by rs_c. -" st1 {v1.s}[2],[x27],x14 \n\t" // Store c66 into quad and increment by rs_c. -" st1 {v1.s}[3],[x27],x14 \n\t" // Store c67 into quad and increment by rs_c. -" \n\t" -" mov x27, x22 \n\t" -" \n\t" -" st1 {v2.s}[0],[x27],x14 \n\t" // Store c70 into quad and increment by rs_c. -" st1 {v2.s}[1],[x27],x14 \n\t" // Store c71 into quad and increment by rs_c. -" st1 {v2.s}[2],[x27],x14 \n\t" // Store c72 into quad and increment by rs_c. -" st1 {v2.s}[3],[x27],x14 \n\t" // Store c73 into quad and increment by rs_c. -" st1 {v3.s}[0],[x27],x14 \n\t" // Store c74 into quad and increment by rs_c. -" st1 {v3.s}[1],[x27],x14 \n\t" // Store c75 into quad and increment by rs_c. -" st1 {v3.s}[2],[x27],x14 \n\t" // Store c76 into quad and increment by rs_c. -" st1 {v3.s}[3],[x27],x14 \n\t" // Store c77 into quad and increment by rs_c. -" \n\t" -" mov x27, x23 \n\t" -" \n\t" -" st1 {v4.s}[0],[x27],x14 \n\t" // Store c80 into quad and increment by rs_c. -" st1 {v4.s}[1],[x27],x14 \n\t" // Store c81 into quad and increment by rs_c. -" st1 {v4.s}[2],[x27],x14 \n\t" // Store c82 into quad and increment by rs_c. -" st1 {v4.s}[3],[x27],x14 \n\t" // Store c83 into quad and increment by rs_c. -" st1 {v5.s}[0],[x27],x14 \n\t" // Store c84 into quad and increment by rs_c. -" st1 {v5.s}[1],[x27],x14 \n\t" // Store c85 into quad and increment by rs_c. -" st1 {v5.s}[2],[x27],x14 \n\t" // Store c86 into quad and increment by rs_c. -" st1 {v5.s}[3],[x27],x14 \n\t" // Store c87 into quad and increment by rs_c. -" \n\t" -" dup v8.4s, wzr \n\t" -" dup v9.4s, wzr \n\t" -" dup v10.4s, wzr \n\t" -" dup v11.4s, wzr \n\t" -" dup v12.4s, wzr \n\t" -" dup v13.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -" beq .SBETAZEROGENSTOREDS4 \n\t" // Taking care of the beta==0 case. -" \n\t" -" mov x27, x24 \n\t" -" \n\t" -" ld1 {v8.s}[0],[x27],x14 \n\t" // Load c90 into quad and increment by rs_c. -" ld1 {v8.s}[1],[x27],x14 \n\t" // Load c91 into quad and increment by rs_c. -" ld1 {v8.s}[2],[x27],x14 \n\t" // Load c92 into quad and increment by rs_c. -" ld1 {v8.s}[3],[x27],x14 \n\t" // Load c93 into quad and increment by rs_c. -" ld1 {v9.s}[0],[x27],x14 \n\t" // Load c94 into quad and increment by rs_c. -" ld1 {v9.s}[1],[x27],x14 \n\t" // Load c95 into quad and increment by rs_c. -" ld1 {v9.s}[2],[x27],x14 \n\t" // Load c96 into quad and increment by rs_c. -" ld1 {v9.s}[3],[x27],x14 \n\t" // Load c97 into quad and increment by rs_c. -" \n\t" -" mov x27, x25 \n\t" -" \n\t" -" ld1 {v10.s}[0],[x27],x14 \n\t" // Load c100 into quad and increment by rs_c. -" ld1 {v10.s}[1],[x27],x14 \n\t" // Load c101 into quad and increment by rs_c. -" ld1 {v10.s}[2],[x27],x14 \n\t" // Load c102 into quad and increment by rs_c. -" ld1 {v10.s}[3],[x27],x14 \n\t" // Load c103 into quad and increment by rs_c. -" ld1 {v11.s}[0],[x27],x14 \n\t" // Load c104 into quad and increment by rs_c. -" ld1 {v11.s}[1],[x27],x14 \n\t" // Load c105 into quad and increment by rs_c. -" ld1 {v11.s}[2],[x27],x14 \n\t" // Load c106 into quad and increment by rs_c. -" ld1 {v11.s}[3],[x27],x14 \n\t" // Load c107 into quad and increment by rs_c. -" \n\t" -" mov x27, x26 \n\t" -" \n\t" -" ld1 {v12.s}[0],[x27],x14 \n\t" // Load c110 into quad and increment by rs_c. -" ld1 {v12.s}[1],[x27],x14 \n\t" // Load c111 into quad and increment by rs_c. -" ld1 {v12.s}[2],[x27],x14 \n\t" // Load c112 into quad and increment by rs_c. -" ld1 {v12.s}[3],[x27],x14 \n\t" // Load c113 into quad and increment by rs_c. -" ld1 {v13.s}[0],[x27],x14 \n\t" // Load c114 into quad and increment by rs_c. -" ld1 {v13.s}[1],[x27],x14 \n\t" // Load c115 into quad and increment by rs_c. -" ld1 {v13.s}[2],[x27],x14 \n\t" // Load c116 into quad and increment by rs_c. -" ld1 {v13.s}[3],[x27],x14 \n\t" // Load c117 into quad and increment by rs_c. -" \n\t" -" fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta -" fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta -" fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta -" fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta -" fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta -" fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -" .SBETAZEROGENSTOREDS4: \n\t" -" \n\t" -" prfm pldl2keep,[x3] \n\t" -" prfm pldl2keep,[x4] \n\t" -" \n\t" -" fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v10.4s,v28.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v11.4s,v29.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x24 \n\t" -" \n\t" -" st1 {v8.s}[0],[x27],x14 \n\t" // Store c90 into quad and increment by rs_c. -" st1 {v8.s}[1],[x27],x14 \n\t" // Store c91 into quad and increment by rs_c. -" st1 {v8.s}[2],[x27],x14 \n\t" // Store c92 into quad and increment by rs_c. -" st1 {v8.s}[3],[x27],x14 \n\t" // Store c93 into quad and increment by rs_c. -" st1 {v9.s}[0],[x27],x14 \n\t" // Store c94 into quad and increment by rs_c. -" st1 {v9.s}[1],[x27],x14 \n\t" // Store c95 into quad and increment by rs_c. -" st1 {v9.s}[2],[x27],x14 \n\t" // Store c96 into quad and increment by rs_c. -" st1 {v9.s}[3],[x27],x14 \n\t" // Store c97 into quad and increment by rs_c. -" \n\t" -" mov x27, x25 \n\t" -" \n\t" -" st1 {v10.s}[0],[x27],x14 \n\t" // Store c100 into quad and increment by rs_c. -" st1 {v10.s}[1],[x27],x14 \n\t" // Store c101 into quad and increment by rs_c. -" st1 {v10.s}[2],[x27],x14 \n\t" // Store c102 into quad and increment by rs_c. -" st1 {v10.s}[3],[x27],x14 \n\t" // Store c103 into quad and increment by rs_c. -" st1 {v11.s}[0],[x27],x14 \n\t" // Store c104 into quad and increment by rs_c. -" st1 {v11.s}[1],[x27],x14 \n\t" // Store c105 into quad and increment by rs_c. -" st1 {v11.s}[2],[x27],x14 \n\t" // Store c106 into quad and increment by rs_c. -" st1 {v11.s}[3],[x27],x14 \n\t" // Store c107 into quad and increment by rs_c. -" \n\t" -" mov x27, x26 \n\t" -" \n\t" -" st1 {v12.s}[0],[x27],x14 \n\t" // Store c110 into quad and increment by rs_c. -" st1 {v12.s}[1],[x27],x14 \n\t" // Store c111 into quad and increment by rs_c. -" st1 {v12.s}[2],[x27],x14 \n\t" // Store c112 into quad and increment by rs_c. -" st1 {v12.s}[3],[x27],x14 \n\t" // Store c113 into quad and increment by rs_c. -" st1 {v13.s}[0],[x27],x14 \n\t" // Store c114 into quad and increment by rs_c. -" st1 {v13.s}[1],[x27],x14 \n\t" // Store c115 into quad and increment by rs_c. -" st1 {v13.s}[2],[x27],x14 \n\t" // Store c116 into quad and increment by rs_c. -" st1 {v13.s}[3],[x27],x14 \n\t" // Store c147 into quad and increment by rs_c. -" \n\t" -" .SEND: \n\t" // Done! -" \n\t" -:// output operands (none) -:// input operands - [aaddr] "m" (a), // 0 - [baddr] "m" (b), // 1 - [caddr] "m" (c), // 2 - [k_iter] "m" (k_iter), // 3 - [k_left] "m" (k_left), // 4 - [alpha] "m" (alpha), // 5 - [beta] "m" (beta), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [a_next] "m" (a_next), // 9 - [b_next] "m" (b_next) // 10 -:// Register clobber list - "x0", "x1", "x2","x3","x4", - "x5", "x6", "x7", "x8", - "x9", "x10","x11","x12", - "x13","x14","x15", - "x16","x17","x18","x19", - "x20","x21","x22","x23", - "x24","x25","x26","x27", - "v0", "v1", "v2", "v3", - "v4", "v5", "v6", "v7", - "v8", "v9", "v10","v11", - "v12","v13","v14","v15", - "v16","v17","v18","v19", - "v20","v21","v22","v23", - "v24","v25","v26","v27", - "v28","v29","v30","v31" -); - -} - - -/* - o 4x4 Double precision micro-kernel NOT fully functional yet. - o Runnable on ARMv8, compiled with aarch64 GCC. - o Use it together with the armv8 BLIS configuration. - o Tested on Juno board. Around 3 GFLOPS @ 1.1 GHz. - - December 2014. - - * UPDATE OCTOBER 2015: Now is fully functional. - * Tested on Juno board. Around 5.6 GFLOPS, 2 A57 cores @ 1.1 GHz. - * Tested on Juno board. Around 4 GFLOPS, 4 A53 cores @ 850 MHz. - - * UPDATE NOVEMBER 2015 - * Micro-kernel changed to 6x8 - * Tested on Juno Board. Around 4 GFLOPS, 1 x A57 core @ 1.1 GHz. - * Tested on Juno Board. Around 7.6 GFLOPS, 2 x A57 cores @ 1.1 GHz. - * Tested on Juno board. Around 1.5 GFLOPS, 1 x A53 core @ 850 MHz. - * Tested on Juno board. Around 5.5 GFLOPS, 4 x A53 cores @ 850 MHz. -*/ -void bli_dgemm_armv8a_asm_6x8 - ( - dim_t k0, - double* restrict alpha, - double* restrict a, - double* restrict b, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ) -{ - void* a_next = bli_auxinfo_next_a( data ); - void* b_next = bli_auxinfo_next_b( data ); - - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - -__asm__ volatile -( -" \n\t" -" ldr x0,%[aaddr] \n\t" // Load address of A -" ldr x1,%[baddr] \n\t" // Load address of B -" ldr x2,%[caddr] \n\t" // Load address of C -" \n\t" -" ldr x3,%[a_next] \n\t" // Move pointer -" ldr x4,%[b_next] \n\t" // Move pointer -" \n\t" -" ldr x5,%[k_iter] \n\t" // Init guard (k_iter) -" ldr x6,%[k_left] \n\t" // Init guard (k_iter) -" \n\t" -" ldr x7,%[alpha] \n\t" // Alpha address -" ldr x8,%[beta] \n\t" // Beta address -" \n\t" -" ldr x9,%[cs_c] \n\t" // Load cs_c -" lsl x10,x9,#3 \n\t" // cs_c * sizeof(double) -" \n\t" -" ldr x13,%[rs_c] \n\t" // Load rs_c. -" lsl x14,x13,#3 \n\t" // rs_c * sizeof(double). -" \n\t" -" add x20,x2,x10 \n\t" //Load address Column 1 of C -" add x21,x20,x10 \n\t" //Load address Column 2 of C -" add x22,x21,x10 \n\t" //Load address Column 3 of C -" add x23,x22,x10 \n\t" //Load address Column 4 of C -" add x24,x23,x10 \n\t" //Load address Column 5 of C -" add x25,x24,x10 \n\t" //Load address Column 6 of C -" add x26,x25,x10 \n\t" //Load address Column 7 of C -" \n\t" -" prfm pldl1keep,[x2] \n\t" // Prefetch c. -" prfm pldl1keep,[x20] \n\t" // Prefetch c. -" prfm pldl1keep,[x21] \n\t" // Prefetch c. -" prfm pldl1keep,[x22] \n\t" // Prefetch c. -" prfm pldl1keep,[x23] \n\t" // Prefetch c. -" prfm pldl1keep,[x24] \n\t" // Prefetch c. -" prfm pldl1keep,[x25] \n\t" // Prefetch c. -" prfm pldl1keep,[x26] \n\t" // Prefetch c. -" \n\t" -" ldr q0, [x0] \n\t" -" ldr q1, [x0, #16] \n\t" // Load a -" ldr q2, [x0, #32] \n\t" -" \n\t" -" ldr q3, [x1] \n\t" // Load b -" ldr q4, [x1, #16] \n\t" -" ldr q5, [x1, #32] \n\t" -" ldr q6, [x1, #48] \n\t" -" \n\t" -" dup v8.2d, xzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #256] \n\t" -" dup v9.2d, xzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #320] \n\t" -" dup v10.2d, xzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #384] \n\t" -" dup v11.2d, xzr \n\t" // Vector for accummulating column 1 -" prfm PLDL1KEEP, [x1, #448] \n\t" -" dup v12.2d, xzr \n\t" // Vector for accummulating column 1 -" dup v13.2d, xzr \n\t" // Vector for accummulating column 1 -" \n\t" -" dup v14.2d, xzr \n\t" // Vector for accummulating column 2 -" prfm PLDL1KEEP, [x0, #192] \n\t" -" dup v15.2d, xzr \n\t" // Vector for accummulating column 2 -" prfm PLDL1KEEP, [x0, #256] \n\t" -" dup v16.2d, xzr \n\t" // Vector for accummulating column 2 -" prfm PLDL1KEEP, [x0, #320] \n\t" -" dup v17.2d, xzr \n\t" // Vector for accummulating column 3 -" dup v18.2d, xzr \n\t" // Vector for accummulating column 3 -" dup v19.2d, xzr \n\t" // Vector for accummulating column 3 -" \n\t" -" dup v20.2d, xzr \n\t" // Vector for accummulating column 4 -" dup v21.2d, xzr \n\t" // Vector for accummulating column 4 -" dup v22.2d, xzr \n\t" // Vector for accummulating column 4 -" dup v23.2d, xzr \n\t" // Vector for accummulating column 5 -" dup v24.2d, xzr \n\t" // Vector for accummulating column 5 -" dup v25.2d, xzr \n\t" // Vector for accummulating column 5 -" \n\t" -" dup v26.2d, xzr \n\t" // Vector for accummulating column 6 -" dup v27.2d, xzr \n\t" // Vector for accummulating column 6 -" dup v28.2d, xzr \n\t" // Vector for accummulating column 6 -" dup v29.2d, xzr \n\t" // Vector for accummulating column 7 -" dup v30.2d, xzr \n\t" // Vector for accummulating column 7 -" dup v31.2d, xzr \n\t" // Vector for accummulating column 7 -" \n\t" -" \n\t" -" cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. -" beq .DCONSIDERKLEFT \n\t" -" \n\t" -"add x0, x0, #48 \n\t" //update address of A -"add x1, x1, #64 \n\t" //update address of B -" \n\t" -" cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. -" beq .DLASTITER \n\t" // (as loop is do-while-like). -" \n\t" -" DLOOP: \n\t" // Body -" \n\t" -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x1, #448] \n\t" //512-64=448 -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x1, #512] \n\t" -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x1, #576] \n\t" -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" ldr q3, [x1] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" ldr q7, [x0, #32] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" ldr q4, [x1, #16] \n\t" -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #32] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #16] \n\t" -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #48] \n\t" -" \n\t" // End it 1 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x1, #640] \n\t" -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x0, #336] \n\t" -" fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x0, #400] \n\t" -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate -" ldr q3, [x1, #64] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate -" ldr q2, [x0, #80] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate -" ldr q4, [x1, #80] \n\t" -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #96] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #48] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #64] \n\t" -" \n\t" -" fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #112] \n\t" -" \n\t" //End it 2 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x0, #464] \n\t" -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" ldr q3, [x1, #128] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" ldr q7, [x0, #128] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" ldr q4, [x1, #144] \n\t" -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #160] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #96] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #112] \n\t" -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #176] \n\t" -" \n\t" // End it 3 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate -" ldr q3, [x1, #192] \n\t" -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate -" ldr q2, [x0, #176] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate -" ldr q4, [x1, #208] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #224] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #144] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #160] \n\t" -" \n\t" -" fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #240] \n\t" -" \n\t" //End it 4 -" add x0, x0, #192 \n\t" -" add x1, x1, #256 \n\t" -" \n\t" -" sub x5,x5,1 \n\t" // i-=1 -" cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. -" bne DLOOP \n\t" -" \n\t" -".DLASTITER: \n\t" -" \n\t" -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate -" ldr q3, [x1] \n\t" -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" ldr q7, [x0, #32] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" ldr q4, [x1, #16] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #32] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #16] \n\t" -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #48] \n\t" -" \n\t" // End it 1 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate -" ldr q3, [x1, #64] \n\t" -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate -" ldr q2, [x0, #80] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate -" ldr q4, [x1, #80] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #96] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #48] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #64] \n\t" -" \n\t" -" fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #112] \n\t" -" \n\t" //End it 2 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate -" ldr q3, [x1, #128] \n\t" -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" ldr q7, [x0, #128] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" ldr q4, [x1, #144] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #160] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #96] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #112] \n\t" -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #176] \n\t" -" \n\t" // End it 3 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" add x1, x1, #192 \n\t" -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate -" \n\t" //End it 4 -" add x0, x0, #144 \n\t" -" \n\t" -" .DCONSIDERKLEFT: \n\t" -" cmp x6,0 \n\t" // If k_left == 0, we are done. -" beq .DPOSTACCUM \n\t" // else, we enter the k_left loop. -" \n\t" -".DLOOPKLEFT: \n\t" -" \n\t" -" ldr q0, [x0],#16 \n\t" -" ldr q1, [x0],#16 \n\t" // Load a -" ldr q2, [x0],#16 \n\t" -" \n\t" -" ldr q3, [x1],#16 \n\t" // Load b -" ldr q4, [x1],#16 \n\t" -" ldr q5, [x1],#16 \n\t" -" ldr q6, [x1],#16 \n\t" -" \n\t" -" sub x6,x6,1 \n\t" -" \n\t" -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" cmp x6,0 \n\t" // Iterate again. -" bne .DLOOPKLEFT \n\t" // if i!=0. -" \n\t" -" .DPOSTACCUM: \n\t" -" \n\t" -" ld1r {v6.2d},[x7] \n\t" // Load alpha. -" ld1r {v7.2d},[x8] \n\t" // Load beta -" \n\t" -" cmp x13,#1 \n\t" // If rs_c != 1 (column-major) -" bne .DGENSTORED \n\t" -" \n\t" -" .DCOLSTORED: \n\t" // C is column-major. -" \n\t" -" dup v0.2d, xzr \n\t" -" dup v1.2d, xzr \n\t" -" dup v2.2d, xzr \n\t" -" dup v3.2d, xzr \n\t" -" dup v4.2d, xzr \n\t" -" dup v5.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -" beq .DBETAZEROCOLSTOREDS1 \n\t" // Taking care of the beta==0 case. -" \n\t" -" ldr q0, [x2] \n\t" //Load column 0 of C -" ldr q1, [x2, #16] \n\t" -" ldr q2, [x2, #32] \n\t" -" \n\t" -" ldr q3, [x20] \n\t" //Load column 1 of C -" ldr q4, [x20, #16] \n\t" -" ldr q5, [x20, #32] \n\t" -" \n\t" -" fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta -" fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta -" fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta -" fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta -" fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta -" fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -" .DBETAZEROCOLSTOREDS1: \n\t" -" \n\t" -" fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v2.2d,v10.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v3.2d,v11.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v4.2d,v12.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v5.2d,v13.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" str q0, [x2] \n\t" //Store column 0 of C -" str q1, [x2, #16] \n\t" -" str q2, [x2, #32] \n\t" -" \n\t" -" str q3, [x20] \n\t" //Store column 1 of C -" str q4, [x20, #16] \n\t" -" str q5, [x20, #32] \n\t" -" \n\t" -" dup v8.2d, xzr \n\t" -" dup v9.2d, xzr \n\t" -" dup v10.2d, xzr \n\t" -" dup v11.2d, xzr \n\t" -" dup v12.2d, xzr \n\t" -" dup v13.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -" beq .DBETAZEROCOLSTOREDS2 \n\t" // Taking care of the beta==0 case. -" \n\t" -" ldr q8, [x21] \n\t" //Load column 2 of C -" ldr q9, [x21, #16] \n\t" -" ldr q10, [x21, #32] \n\t" -" \n\t" -" ldr q11, [x22] \n\t" //Load column 3 of C -" ldr q12, [x22, #16] \n\t" -" ldr q13, [x22, #32] \n\t" -" \n\t" -" fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta -" fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta -" fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta -" fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta -" fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta -" fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -" .DBETAZEROCOLSTOREDS2: \n\t" -" \n\t" -" fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v10.2d,v16.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v11.2d,v17.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v12.2d,v18.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v13.2d,v19.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" str q8, [x21] \n\t" //Store column 2 of C -" str q9, [x21, #16] \n\t" -" str q10, [x21, #32] \n\t" -" \n\t" -" str q11, [x22] \n\t" //Store column 3 of C -" str q12, [x22, #16] \n\t" -" str q13, [x22, #32] \n\t" -" \n\t" -" dup v0.2d, xzr \n\t" -" dup v1.2d, xzr \n\t" -" dup v2.2d, xzr \n\t" -" dup v3.2d, xzr \n\t" -" dup v4.2d, xzr \n\t" -" dup v5.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -" beq .DBETAZEROCOLSTOREDS3 \n\t" // Taking care of the beta==0 case. -" \n\t" -" ldr q0, [x23] \n\t" //Load column 4 of C -" ldr q1, [x23, #16] \n\t" -" ldr q2, [x23, #32] \n\t" -" \n\t" -" ldr q3, [x24] \n\t" //Load column 5 of C -" ldr q4, [x24, #16] \n\t" -" ldr q5, [x24, #32] \n\t" -" \n\t" -" fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta -" fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta -" fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta -" fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta -" fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta -" fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -" .DBETAZEROCOLSTOREDS3: \n\t" -" \n\t" -" fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v2.2d,v22.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v3.2d,v23.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v4.2d,v24.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v5.2d,v25.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" str q0, [x23] \n\t" //Store column 4 of C -" str q1, [x23, #16] \n\t" -" str q2, [x23, #32] \n\t" -" \n\t" -" str q3, [x24] \n\t" //Store column 5 of C -" str q4, [x24, #16] \n\t" -" str q5, [x24, #32] \n\t" -" \n\t" -" dup v8.2d, xzr \n\t" -" dup v9.2d, xzr \n\t" -" dup v10.2d, xzr \n\t" -" dup v11.2d, xzr \n\t" -" dup v12.2d, xzr \n\t" -" dup v13.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -" beq .DBETAZEROCOLSTOREDS4 \n\t" // Taking care of the beta==0 case. -" \n\t" -" ldr q8, [x25] \n\t" //Load column 6 of C -" ldr q9, [x25, #16] \n\t" -" ldr q10, [x25, #32] \n\t" -" \n\t" -" ldr q11, [x26] \n\t" //Load column 7 of C -" ldr q12, [x26, #16] \n\t" -" ldr q13, [x26, #32] \n\t" -" \n\t" -" fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta -" fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta -" fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta -" fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta -" fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta -" fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -" .DBETAZEROCOLSTOREDS4: \n\t" -" \n\t" -" prfm pldl2keep,[x3] \n\t" -" prfm pldl2keep,[x4] \n\t" -" \n\t" -" fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v10.2d,v28.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v11.2d,v29.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v12.2d,v30.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v13.2d,v31.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" str q8, [x25] \n\t" //Store column 6 of C -" str q9, [x25, #16] \n\t" -" str q10, [x25, #32] \n\t" -" \n\t" -" str q11, [x26] \n\t" //Store column 7 of C -" str q12, [x26, #16] \n\t" -" str q13, [x26, #32] \n\t" -" \n\t" -" b .DEND \n\t" -" \n\t" -" .DGENSTORED: \n\t" // C is general-stride stored. -" \n\t" -" dup v0.2d, xzr \n\t" -" dup v1.2d, xzr \n\t" -" dup v2.2d, xzr \n\t" -" dup v3.2d, xzr \n\t" -" dup v4.2d, xzr \n\t" -" dup v5.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -" beq .DBETAZEROGENSTOREDS1 \n\t" // Taking care of the beta==0 case. -" \n\t" -" mov x27, x2 \n\t" -" \n\t" // Load address of C. -" ld1 {v0.d}[0],[x27],x14 \n\t" // Load c00 into quad and increment by rs_c. -" ld1 {v0.d}[1],[x27],x14 \n\t" // Load c01 into quad and increment by rs_c. -" ld1 {v1.d}[0],[x27],x14 \n\t" // Load c02 into quad and increment by rs_c. -" ld1 {v1.d}[1],[x27],x14 \n\t" // Load c03 into quad and increment by rs_c. -" ld1 {v2.d}[0],[x27],x14 \n\t" // Load c04 into quad and increment by rs_c. -" ld1 {v2.d}[1],[x27],x14 \n\t" // Load c05 into quad and increment by rs_c. -" \n\t" -" mov x27, x20 \n\t" // Load address of C. -" \n\t" -" ld1 {v3.d}[0],[x27],x14 \n\t" // Load c10 into quad and increment by rs_c. -" ld1 {v3.d}[1],[x27],x14 \n\t" // Load c11 into quad and increment by rs_c. -" ld1 {v4.d}[0],[x27],x14 \n\t" // Load c12 into quad and increment by rs_c. -" ld1 {v4.d}[1],[x27],x14 \n\t" // Load c13 into quad and increment by rs_c. -" ld1 {v5.d}[0],[x27],x14 \n\t" // Load c14 into quad and increment by rs_c. -" ld1 {v5.d}[1],[x27],x14 \n\t" // Load c15 into quad and increment by rs_c. -" \n\t" -" fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta -" fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta -" fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta -" fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta -" fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta -" fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -" .DBETAZEROGENSTOREDS1: \n\t" -" \n\t" -" fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v2.2d,v10.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v3.2d,v11.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v4.2d,v12.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v5.2d,v13.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x2 \n\t" // Load address of C. -" \n\t" -" st1 {v0.d}[0],[x27],x14 \n\t" // Store c00 into quad and increment by rs_c. -" st1 {v0.d}[1],[x27],x14 \n\t" // Store c01 into quad and increment by rs_c. -" st1 {v1.d}[0],[x27],x14 \n\t" // Store c02 into quad and increment by rs_c. -" st1 {v1.d}[1],[x27],x14 \n\t" // Store c03 into quad and increment by rs_c. -" st1 {v2.d}[0],[x27],x14 \n\t" // Store c04 into quad and increment by rs_c. -" st1 {v2.d}[1],[x27],x14 \n\t" // Store c05 into quad and increment by rs_c. -" \n\t" -" mov x27, x20 \n\t" // Load address of C. -" \n\t" -" st1 {v3.d}[0],[x27],x14 \n\t" // Store c10 into quad and increment by rs_c. -" st1 {v3.d}[1],[x27],x14 \n\t" // Store c11 into quad and increment by rs_c. -" st1 {v4.d}[0],[x27],x14 \n\t" // Store c12 into quad and increment by rs_c. -" st1 {v4.d}[1],[x27],x14 \n\t" // Store c13 into quad and increment by rs_c. -" st1 {v5.d}[0],[x27],x14 \n\t" // Store c14 into quad and increment by rs_c. -" st1 {v5.d}[1],[x27],x14 \n\t" // Store c15 into quad and increment by rs_c. -" \n\t" -" dup v8.2d, xzr \n\t" -" dup v9.2d, xzr \n\t" -" dup v10.2d, xzr \n\t" -" dup v11.2d, xzr \n\t" -" dup v12.2d, xzr \n\t" -" dup v13.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -" beq .DBETAZEROGENSTOREDS2 \n\t" // Taking care of the beta==0 case. -" \n\t" -" mov x27, x21 \n\t" // Load address of C. -" \n\t" -" ld1 {v8.d}[0], [x27],x14 \n\t" // Load c20 into quad and increment by rs_c. -" ld1 {v8.d}[1], [x27],x14 \n\t" // Load c21 into quad and increment by rs_c. -" ld1 {v9.d}[0], [x27],x14 \n\t" // Load c22 into quad and increment by rs_c. -" ld1 {v9.d}[1], [x27],x14 \n\t" // Load c23 into quad and increment by rs_c. -" ld1 {v10.d}[0],[x27],x14 \n\t" // Load c24 into quad and increment by rs_c. -" ld1 {v10.d}[1],[x27],x14 \n\t" // Load c25 into quad and increment by rs_c. -" \n\t" -" mov x27, x22 \n\t" // Load address of C. -" \n\t" -" ld1 {v11.d}[0],[x27],x14 \n\t" // Load c30 into quad and increment by rs_c. -" ld1 {v11.d}[1],[x27],x14 \n\t" // Load c31 into quad and increment by rs_c. -" ld1 {v12.d}[0],[x27],x14 \n\t" // Load c32 into quad and increment by rs_c. -" ld1 {v12.d}[1],[x27],x14 \n\t" // Load c33 into quad and increment by rs_c. -" ld1 {v13.d}[0],[x27],x14 \n\t" // Load c34 into quad and increment by rs_c. -" ld1 {v13.d}[1],[x27],x14 \n\t" // Load c35 into quad and increment by rs_c. -" \n\t" -" fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta -" fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta -" fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta -" fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta -" fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta -" fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -" .DBETAZEROGENSTOREDS2: \n\t" -" \n\t" -" fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v10.2d,v16.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v11.2d,v17.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v12.2d,v18.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v13.2d,v19.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x21 \n\t" // Load address of C. -" \n\t" -" st1 {v8.d}[0], [x27],x14 \n\t" // Store c20 into quad and increment by rs_c. -" st1 {v8.d}[1], [x27],x14 \n\t" // Store c21 into quad and increment by rs_c. -" st1 {v9.d}[0], [x27],x14 \n\t" // Store c22 into quad and increment by rs_c. -" st1 {v9.d}[1], [x27],x14 \n\t" // Store c23 into quad and increment by rs_c. -" st1 {v10.d}[0],[x27],x14 \n\t" // Store c24 into quad and increment by rs_c. -" st1 {v10.d}[1],[x27],x14 \n\t" // Store c25 into quad and increment by rs_c. -" \n\t" -" mov x27, x22 \n\t" // Load address of C. -" \n\t" -" st1 {v11.d}[0],[x27],x14 \n\t" // Store c30 into quad and increment by rs_c. -" st1 {v11.d}[1],[x27],x14 \n\t" // Store c31 into quad and increment by rs_c. -" st1 {v12.d}[0],[x27],x14 \n\t" // Store c32 into quad and increment by rs_c. -" st1 {v12.d}[1],[x27],x14 \n\t" // Store c33 into quad and increment by rs_c. -" st1 {v13.d}[0],[x27],x14 \n\t" // Store c34 into quad and increment by rs_c. -" st1 {v13.d}[1],[x27],x14 \n\t" // Store c35 into quad and increment by rs_c. -" \n\t" -" dup v0.2d, xzr \n\t" -" dup v1.2d, xzr \n\t" -" dup v2.2d, xzr \n\t" -" dup v3.2d, xzr \n\t" -" dup v4.2d, xzr \n\t" -" dup v5.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -" beq .DBETAZEROGENSTOREDS3 \n\t" // Taking care of the beta==0 case. -" \n\t" -" mov x27, x23 \n\t" // Load address of C. -" \n\t" -" ld1 {v0.d}[0],[x27],x14 \n\t" // Load c40 into quad and increment by rs_c. -" ld1 {v0.d}[1],[x27],x14 \n\t" // Load c41 into quad and increment by rs_c. -" ld1 {v1.d}[0],[x27],x14 \n\t" // Load c42 into quad and increment by rs_c. -" ld1 {v1.d}[1],[x27],x14 \n\t" // Load c43 into quad and increment by rs_c. -" ld1 {v2.d}[0],[x27],x14 \n\t" // Load c44 into quad and increment by rs_c. -" ld1 {v2.d}[1],[x27],x14 \n\t" // Load c45 into quad and increment by rs_c. -" \n\t" -" mov x27, x24 \n\t" // Load address of C. -" \n\t" -" ld1 {v3.d}[0],[x27],x14 \n\t" // Load c50 into quad and increment by rs_c. -" ld1 {v3.d}[1],[x27],x14 \n\t" // Load c51 into quad and increment by rs_c. -" ld1 {v4.d}[0],[x27],x14 \n\t" // Load c52 into quad and increment by rs_c. -" ld1 {v4.d}[1],[x27],x14 \n\t" // Load c53 into quad and increment by rs_c. -" ld1 {v5.d}[0],[x27],x14 \n\t" // Load c54 into quad and increment by rs_c. -" ld1 {v5.d}[1],[x27],x14 \n\t" // Load c55 into quad and increment by rs_c. -" \n\t" -" fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta -" fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta -" fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta -" fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta -" fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta -" fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -" .DBETAZEROGENSTOREDS3: \n\t" -" \n\t" -" fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v2.2d,v22.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v3.2d,v23.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v4.2d,v24.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v5.2d,v25.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x23 \n\t" // Load address of C. -" \n\t" -" st1 {v0.d}[0],[x27],x14 \n\t" // Store c40 into quad and increment by rs_c. -" st1 {v0.d}[1],[x27],x14 \n\t" // Store c41 into quad and increment by rs_c. -" st1 {v1.d}[0],[x27],x14 \n\t" // Store c42 into quad and increment by rs_c. -" st1 {v1.d}[1],[x27],x14 \n\t" // Store c43 into quad and increment by rs_c. -" st1 {v2.d}[0],[x27],x14 \n\t" // Store c44 into quad and increment by rs_c. -" st1 {v2.d}[1],[x27],x14 \n\t" // Store c45 into quad and increment by rs_c. -" \n\t" -" mov x27, x24 \n\t" // Load address of C. -" \n\t" -" st1 {v3.d}[0],[x27],x14 \n\t" // Store c50 into quad and increment by rs_c. -" st1 {v3.d}[1],[x27],x14 \n\t" // Store c51 into quad and increment by rs_c. -" st1 {v4.d}[0],[x27],x14 \n\t" // Store c52 into quad and increment by rs_c. -" st1 {v4.d}[1],[x27],x14 \n\t" // Store c53 into quad and increment by rs_c. -" st1 {v5.d}[0],[x27],x14 \n\t" // Store c54 into quad and increment by rs_c. -" st1 {v5.d}[1],[x27],x14 \n\t" // Store c55 into quad and increment by rs_c. -" \n\t" -" dup v8.2d, xzr \n\t" -" dup v9.2d, xzr \n\t" -" dup v10.2d, xzr \n\t" -" dup v11.2d, xzr \n\t" -" dup v12.2d, xzr \n\t" -" dup v13.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -" beq .DBETAZEROGENSTOREDS4 \n\t" // Taking care of the beta==0 case. -" \n\t" -" mov x27, x25 \n\t" -" \n\t" -" ld1 {v8.d}[0], [x27],x14 \n\t" // Load c60 into quad and increment by rs_c. -" ld1 {v8.d}[1], [x27],x14 \n\t" // Load c61 into quad and increment by rs_c. -" ld1 {v9.d}[0], [x27],x14 \n\t" // Load c62 into quad and increment by rs_c. -" ld1 {v9.d}[1], [x27],x14 \n\t" // Load c63 into quad and increment by rs_c. -" ld1 {v10.d}[0],[x27],x14 \n\t" // Load c64 into quad and increment by rs_c. -" ld1 {v10.d}[1],[x27],x14 \n\t" // Load c65 into quad and increment by rs_c. -" \n\t" -" mov x27, x26 \n\t" // Load address of C. -" \n\t" -" ld1 {v11.d}[0],[x27],x14 \n\t" // Load c70 into quad and increment by rs_c. -" ld1 {v11.d}[1],[x27],x14 \n\t" // Load c71 into quad and increment by rs_c. -" ld1 {v12.d}[0],[x27],x14 \n\t" // Load c72 into quad and increment by rs_c. -" ld1 {v12.d}[1],[x27],x14 \n\t" // Load c73 into quad and increment by rs_c. -" ld1 {v13.d}[0],[x27],x14 \n\t" // Load c74 into quad and increment by rs_c. -" ld1 {v13.d}[1],[x27],x14 \n\t" // Load c75 into quad and increment by rs_c. -" \n\t" -" fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta -" fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta -" fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta -" fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta -" fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta -" fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -" .DBETAZEROGENSTOREDS4: \n\t" -" \n\t" -" prfm pldl2keep,[x3] \n\t" -" prfm pldl2keep,[x4] \n\t" -" \n\t" -" fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v10.2d,v28.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v11.2d,v29.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v12.2d,v30.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v13.2d,v31.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x25 \n\t" // Load address of C. -" \n\t" -" st1 {v8.d}[0], [x27],x14 \n\t" // Store c60 into quad and increment by rs_c. -" st1 {v8.d}[1], [x27],x14 \n\t" // Store c61 into quad and increment by rs_c. -" st1 {v9.d}[0], [x27],x14 \n\t" // Store c62 into quad and increment by rs_c. -" st1 {v9.d}[1], [x27],x14 \n\t" // Store c63 into quad and increment by rs_c. -" st1 {v10.d}[0],[x27],x14 \n\t" // Store c64 into quad and increment by rs_c. -" st1 {v10.d}[1],[x27],x14 \n\t" // Store c65 into quad and increment by rs_c. -" \n\t" -" mov x27, x26 \n\t" // Load address of C. -" \n\t" -" st1 {v11.d}[0],[x27],x14 \n\t" // Store c70 into quad and increment by rs_c. -" st1 {v11.d}[1],[x27],x14 \n\t" // Store c71 into quad and increment by rs_c. -" st1 {v12.d}[0],[x27],x14 \n\t" // Store c72 into quad and increment by rs_c. -" st1 {v12.d}[1],[x27],x14 \n\t" // Store c73 into quad and increment by rs_c. -" st1 {v13.d}[0],[x27],x14 \n\t" // Store c74 into quad and increment by rs_c. -" st1 {v13.d}[1],[x27],x14 \n\t" // Store c75 into quad and increment by rs_c. -" \n\t" -" .DEND: \n\t" // Done! -" \n\t" -:// output operands (none) -:// input operands - [aaddr] "m" (a), // 0 - [baddr] "m" (b), // 1 - [caddr] "m" (c), // 2 - [k_iter] "m" (k_iter), // 3 - [k_left] "m" (k_left), // 4 - [alpha] "m" (alpha), // 5 - [beta] "m" (beta), // 6 - [rs_c] "m" (rs_c), // 6 - [cs_c] "m" (cs_c), // 7 - [a_next] "m" (a_next), // 8 - [b_next] "m" (b_next) // 9 -:// Register clobber list - "x0","x1","x2","x3", - "x4","x5","x6", - "x7","x8","x9", - "x10","x11","x12","x13","x14","x16","x17", - "x20","x21","x22","x23","x24","x25","x26", - "x27", - "v0","v1","v2", - "v3","v4","v5", - "v6","v7","v8", - "v9","v10","v11", - "v12","v13","v14", - "v15","v16","v17","v18","v19", - "v20","v21","v22","v23", - "v24","v25","v26","v27", - "v28","v29","v30","v31" -); - - - -} - - -#if 0 -void bli_cgemm_armv8a_opt_4x4 - ( - dim_t k, - scomplex* restrict alpha, - scomplex* restrict a, - scomplex* restrict b, - scomplex* restrict beta, - scomplex* restrict c, inc_t rs_c, inc_t cs_c, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ) -{ -} - -void bli_zgemm_armv8a_opt_4x4 - ( - dim_t k, - dcomplex* restrict alpha, - dcomplex* restrict a, - dcomplex* restrict b, - dcomplex* restrict beta, - dcomplex* restrict c, inc_t rs_c, inc_t cs_c, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ) -{ -} - -#endif - diff --git a/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d4x4.c b/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d4x4.c new file mode 100644 index 0000000000..0dbfbcf6b1 --- /dev/null +++ b/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d4x4.c @@ -0,0 +1,265 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" +#include "assert.h" + +// Label locality & misc. +#include "armv8a_asm_utils.h" + +// Nanokernel operations. +#include "armv8a_asm_d2x2.h" + +#define DGEMM_4X4_MKER_LOOP_PLAIN(C00,C10,C01,C11,C02,C12,C03,C13,A0,A1,B0,B1) \ + DGEMM_2X2_NANOKERNEL(C00,C01,A0,B0) \ + DGEMM_2X2_NANOKERNEL(C10,C11,A1,B0) \ + DGEMM_2X2_NANOKERNEL(C02,C03,A0,B1) \ + DGEMM_2X2_NANOKERNEL(C12,C13,A1,B1) + +// For contiguous storage of C. +#define DLOADC_2V_C_FWD(C0,C1,CADDR,CSHIFT,LDC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#LDC" \n\t" +#define DSTOREC_2V_C_FWD(C0,C1,CADDR,CSHIFT,LDC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#LDC" \n\t" + +void bli_dgemm_armv8a_asm_4x4 + ( + dim_t k0, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 6; + uint64_t k_left = k0 % 6; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + __asm__ volatile + ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" mov x2, #4 \n\t" // Column-skip of A. +" mov x3, #4 \n\t" // Row-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" mov x8, #8 \n\t" // Multiply some address skips by sizeof(double). +" madd x2, x8, x2, xzr \n\t" // cs_a +" madd x3, x8, x3, xzr \n\t" // rs_b +" madd x7, x8, x7, xzr \n\t" // cs_c +" \n\t" +" ldr x4, %[k_mker] \n\t" // Number of loops. +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:7 ] <- C +// V[ 8:19] <- B +// V[20:31] <- A +// Under this scheme, the following is defined: +#define DGEMM_4X4_MKER_LOOP_PLAIN_LOC(A0,A1,B0,B1) \ + DGEMM_4X4_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,A0,A1,B0,B1) +// TODO: Prefetch C. +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" ldr q20, [x0, #16*0] \n\t" +" ldr q21, [x0, #16*1] \n\t" +" ldr q22, [x0, #16*2] \n\t" +" ldr q23, [x0, #16*3] \n\t" +" ldr q24, [x0, #16*4] \n\t" +" ldr q25, [x0, #16*5] \n\t" +" add x0, x0, x2 \n\t" +" add x0, x0, x2 \n\t" +" add x0, x0, x2 \n\t" +" ldr q26, [x0, #16*0] \n\t" +" ldr q27, [x0, #16*1] \n\t" +" ldr q28, [x0, #16*2] \n\t" +" ldr q29, [x0, #16*3] \n\t" +" ldr q30, [x0, #16*4] \n\t" +" ldr q31, [x0, #16*5] \n\t" +" add x0, x0, x2 \n\t" +" add x0, x0, x2 \n\t" +" add x0, x0, x2 \n\t" +" \n\t" +" ldr q8, [x1, #16*0] \n\t" +" ldr q9, [x1, #16*1] \n\t" +" ldr q10, [x1, #16*2] \n\t" +" ldr q11, [x1, #16*3] \n\t" +" ldr q12, [x1, #16*4] \n\t" +" ldr q13, [x1, #16*5] \n\t" +" add x1, x1, x3 \n\t" +" add x1, x1, x3 \n\t" +" add x1, x1, x3 \n\t" +" ldr q14, [x1, #16*0] \n\t" +" ldr q15, [x1, #16*1] \n\t" +" ldr q16, [x1, #16*2] \n\t" +" ldr q17, [x1, #16*3] \n\t" +" ldr q18, [x1, #16*4] \n\t" +" ldr q19, [x1, #16*5] \n\t" +" add x1, x1, x3 \n\t" +" add x1, x1, x3 \n\t" +" add x1, x1, x3 \n\t" +" \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,B0,B1) \ + DGEMM_4X4_MKER_LOOP_PLAIN_LOC(A0,A1,B0,B1) \ + "ldr q"#A0", [x0, #16*0] \n\t" \ + "ldr q"#A1", [x0, #16*1] \n\t" \ + "add x0, x0, x2 \n\t" \ + "ldr q"#B0", [x1, #16*0] \n\t" \ + "ldr q"#B1", [x1, #16*1] \n\t" \ + "add x1, x1, x3 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(20,21,8,9) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(22,23,10,11) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(24,25,12,13) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(26,27,14,15) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(28,29,16,17) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(30,31,18,19) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(20,21,8,9) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(22,23,10,11) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(24,25,12,13) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(26,27,14,15) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(28,29,16,17) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(30,31,18,19) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" ldr q20, [x0, #16*0] \n\t" +" ldr q21, [x0, #16*1] \n\t" +" add x0, x0, x2 \n\t" +" ldr q8, [x1, #16*0] \n\t" +" ldr q9, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(20,21,8,9) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ldr d8, [x4] \n\t" // Load alpha & beta (value). +" ldr d9, [x8] \n\t" +" \n\t" +LABEL(PREFETCH_ABNEXT) +" ldr x0, %[a_next] \n\t" +" ldr x1, %[b_next] \n\t" +" prfm PLDL1STRM, [x0, 64*0] \n\t" // Do not know cache line size, +" prfm PLDL1STRM, [x0, 64*1] \n\t" // issue some number of prfm instructions +" prfm PLDL1STRM, [x0, 64*2] \n\t" // to try to activate hardware prefetcher. +" prfm PLDL1STRM, [x1, 64*0] \n\t" +" prfm PLDL1STRM, [x1, 64*1] \n\t" +" prfm PLDL1STRM, [x1, 64*3] \n\t" +" \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x6, #1 \n\t" // Check for generic storage. +BNE(WRITE_MEM_G) +// +// Contiguous C-storage. +LABEL(WRITE_MEM_C) +DLOADC_2V_C_FWD(10,11,x9,0,x7) +DLOADC_2V_C_FWD(12,13,x9,0,x7) +DLOADC_2V_C_FWD(14,15,x9,0,x7) +DLOADC_2V_C_FWD(16,17,x9,0,x7) +DSCALE8V(10,11,12,13,14,15,16,17,9,0) +DSCALEA8V(10,11,12,13,14,15,16,17,0,1,2,3,4,5,6,7,8,0) +DSTOREC_2V_C_FWD(10,11,x5,0,x7) +DSTOREC_2V_C_FWD(12,13,x5,0,x7) +DSTOREC_2V_C_FWD(14,15,x5,0,x7) +DSTOREC_2V_C_FWD(16,17,x5,0,x7) +BRANCH(END_WRITE_MEM) +// +// Generic-strided C-storage. +LABEL(WRITE_MEM_G) +// TODO: Implement. +LABEL(END_WRITE_MEM) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x16", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); + +} diff --git a/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d6x8r.c b/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d6x8r.c new file mode 100644 index 0000000000..2fe83438f5 --- /dev/null +++ b/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d6x8r.c @@ -0,0 +1,356 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" + +// Label locality & misc. +#include "armv8a_asm_utils.h" + +// Nanokernel operations. +#include "armv8a_asm_d2x2.h" + +/* Order of row-major DGEMM_6x8's execution in 2x2 blocks: + * + * +---+ +---+ +---+ +---+ + * | 0 | | 1 | | 6 | | 7 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 2 | | 3 | | 8 | | 9 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 4 | | 5 | | 10| | 11| + * +---+ +---+ +---+ +---+ + * + */ +#define DGEMM_6X8_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,C30,C31,C32,C33,C40,C41,C42,C43,C50,C51,C52,C53,A0,A1,A2,B0,B1,B2,B3,AADDR,ASHIFT,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C10,B0,A0) \ + DGEMM_2X2_NANOKERNEL(C01,C11,B1,A0) \ + DGEMM_2X2_NANOKERNEL(C20,C30,B0,A1) \ + DGEMM_2X2_NANOKERNEL(C21,C31,B1,A1) \ + DGEMM_2X2_NANOKERNEL(C40,C50,B0,A2) \ + DGEMM_2X2_NANOKERNEL(C41,C51,B1,A2) \ + DGEMM_LOAD2V_ ##LOADNEXT (B0,B1,BADDR,BSHIFT) \ + DGEMM_2X2_NANOKERNEL(C02,C12,B2,A0) \ + DGEMM_2X2_NANOKERNEL(C03,C13,B3,A0) \ + DGEMM_LOAD1V_ ##LOADNEXT (A0,AADDR,ASHIFT) \ + DGEMM_2X2_NANOKERNEL(C22,C32,B2,A1) \ + DGEMM_2X2_NANOKERNEL(C23,C33,B3,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (A1,AADDR,ASHIFT+16) \ + DGEMM_2X2_NANOKERNEL(C42,C52,B2,A2) \ + DGEMM_2X2_NANOKERNEL(C43,C53,B3,A2) + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +#define DGEMM_LOAD2V_noload(V1,V2,ADDR,IMM) +#define DGEMM_LOAD2V_load(V1,V2,ADDR,IMM) \ + DGEMM_LOAD1V_load(V1,ADDR,IMM) \ + DGEMM_LOAD1V_load(V2,ADDR,IMM+16) + +// For contiguous storage of C. +#define DLOADC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DPRFMC_FWD(CADDR,RSC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For scattered storage of C. +#define DLOADC_GATHER_4V_R_FWD(C0,C1,C2,C3,CADDR,CELEM,CSC,RSC) \ +" mov "#CELEM", "#CADDR" \n\t" \ + DLOAD1V_GATHER_ELMFWD(C0,CELEM,CSC) \ + DLOAD1V_GATHER_ELMFWD(C1,CELEM,CSC) \ + DLOAD1V_GATHER_ELMFWD(C2,CELEM,CSC) \ + DLOAD1V_GATHER_ELMFWD(C3,CELEM,CSC) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +#define DSTOREC_SCATTER_4V_R_FWD(C0,C1,C2,C3,CADDR,CELEM,CSC,RSC) \ +" mov "#CELEM", "#CADDR" \n\t" \ + DSTORE1V_SCATTER_ELMFWD(C0,CELEM,CSC) \ + DSTORE1V_SCATTER_ELMFWD(C1,CELEM,CSC) \ + DSTORE1V_SCATTER_ELMFWD(C2,CELEM,CSC) \ + DSTORE1V_SCATTER_ELMFWD(C3,CELEM,CSC) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + + +void bli_dgemm_armv8a_asm_6x8r + ( + dim_t k0, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + __asm__ volatile + ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" mov x2, #6 \n\t" // Column-skip of A. +" mov x3, #8 \n\t" // Row-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x2, x2, #3 \n\t" // cs_a +" lsl x3, x3, #3 \n\t" // rs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x9, x5 \n\t" +" cmp x7, #8 \n\t" // Do not prefetch C for generic strided. +BNE(C_PREFETCH_END) +DPRFMC_FWD(x9,x6) +DPRFMC_FWD(x9,x6) +DPRFMC_FWD(x9,x6) +DPRFMC_FWD(x9,x6) +DPRFMC_FWD(x9,x6) +DPRFMC_FWD(x9,x6) +LABEL(C_PREFETCH_END) +" \n\t" +" ldr x4, %[k_mker] \n\t" // Number of loops. +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:23] <- C +// V[24:27] <- A +// V[28:31] <- B +// Under this scheme, the following is defined: +#define DGEMM_6X8_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,AADDR,ASHIFT,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_6X8_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,A0,A1,A2,B0,B1,B2,B3,AADDR,ASHIFT,BADDR,BSHIFT,LOADNEXT) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" ldr q24, [x0, #16*0] \n\t" // Load A. +" ldr q25, [x0, #16*1] \n\t" +" ldr q26, [x0, #16*2] \n\t" +" add x0, x0, x2 \n\t" +" ldr q27, [x0, #16*0] \n\t" +" \n\t" +" ldr q28, [x1, #16*0] \n\t" // Load B. +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +CLEAR8V(16,17,18,19,20,21,22,23) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_6X8_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,x0,1*16,x1,0,load) \ + "add x0, x0, x2 \n\t" \ + "ldr q"#A2", [x0, #16*0] \n\t" \ + "ldr q"#B2", [x1, #16*2] \n\t" \ + "ldr q"#B3", [x1, #16*3] \n\t" \ + "add x1, x1, x3 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,28,29,30,31) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(27,24,25,28,29,30,31) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(26,27,24,28,29,30,31) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(25,26,27,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(26,27,24,28,29,30,31,x0,1*16,x1,0,load) +" add x0, x0, x2 \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(25,26,27,28,29,30,31,xzr,-1,xzr,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" ldr q24, [x0, #16*0] \n\t" // Load A col. +" ldr q25, [x0, #16*1] \n\t" +" ldr q26, [x0, #16*2] \n\t" +" add x0, x0, x2 \n\t" +" ldr q28, [x1, #16*0] \n\t" // Load B row. +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(24,25,26,28,29,30,31,xzr,-1,xzr,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1r {v24.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v25.2d}, [x8] \n\t" +" \n\t" +LABEL(PREFETCH_ABNEXT) +" ldr x0, %[a_next] \n\t" +" ldr x1, %[b_next] \n\t" +" prfm PLDL1STRM, [x0, 64*0] \n\t" // Do not know cache line size, +" prfm PLDL1STRM, [x0, 64*1] \n\t" // issue some number of prfm instructions +" prfm PLDL1STRM, [x0, 64*2] \n\t" // to try to activate hardware prefetcher. +" prfm PLDL1STRM, [x1, 64*0] \n\t" +" prfm PLDL1STRM, [x1, 64*1] \n\t" +" prfm PLDL1STRM, [x1, 64*3] \n\t" +" \n\t" +" fmov d26, #1.0 \n\t" +" fcmp d24, d26 \n\t" +BEQ(UNIT_ALPHA) +DSCALE8V(0,1,2,3,4,5,6,7,24,0) +DSCALE8V(8,9,10,11,12,13,14,15,24,0) +DSCALE8V(16,17,18,19,20,21,22,23,24,0) +LABEL(UNIT_ALPHA) +" \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for generic storage. +BNE(WRITE_MEM_G) +// +// Contiguous C-storage. +LABEL(WRITE_MEM_R) +" fcmp d25, #0.0 \n\t" // Sets conditional flag whether *beta == 0. +" \n\t" // This conditional flag will be used +" \n\t" // multiple times for skipping load. +// Row 0: +BEQ(ZERO_BETA_R_0) +DLOADC_4V_R_FWD(26,27,28,29,x9,0,x6) +DSCALEA4V(0,1,2,3,26,27,28,29,25,0) +LABEL(ZERO_BETA_R_0) +DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +// Row 1 & 2: +BEQ(ZERO_BETA_R_1_2) +DLOADC_4V_R_FWD(26,27,28,29,x9,0,x6) +DLOADC_4V_R_FWD(0,1,2,3,x9,0,x6) +DSCALEA8V(4,5,6,7,8,9,10,11,26,27,28,29,0,1,2,3,25,0) +LABEL(ZERO_BETA_R_1_2) +DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) +DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) +// Row 3 & 4 & 5: +BEQ(ZERO_BETA_R_3_4_5) +DLOADC_4V_R_FWD(0,1,2,3,x9,0,x6) +DLOADC_4V_R_FWD(4,5,6,7,x9,0,x6) +DLOADC_4V_R_FWD(8,9,10,11,x9,0,x6) +DSCALEA8V(12,13,14,15,16,17,18,19,0,1,2,3,4,5,6,7,25,0) +DSCALEA4V(20,21,22,23,8,9,10,11,25,0) +LABEL(ZERO_BETA_R_3_4_5) +DSTOREC_4V_R_FWD(12,13,14,15,x5,0,x6) +DSTOREC_4V_R_FWD(16,17,18,19,x5,0,x6) +DSTOREC_4V_R_FWD(20,21,22,23,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// Generic-strided C-storage. +LABEL(WRITE_MEM_G) +" fcmp d25, #0.0 \n\t" // Sets conditional flag whether *beta == 0. +" \n\t" +// Row 0: +BEQ(ZERO_BETA_G_0) +DLOADC_GATHER_4V_R_FWD(26,27,28,29,x9,x0,x7,x6) +DSCALEA4V(0,1,2,3,26,27,28,29,25,0) +LABEL(ZERO_BETA_G_0) +DSTOREC_SCATTER_4V_R_FWD(0,1,2,3,x5,x1,x7,x6) +// Row 1 & 2: +BEQ(ZERO_BETA_G_1_2) +DLOADC_GATHER_4V_R_FWD(26,27,28,29,x9,x0,x7,x6) +DLOADC_GATHER_4V_R_FWD(0,1,2,3,x9,x0,x7,x6) +DSCALEA8V(4,5,6,7,8,9,10,11,26,27,28,29,0,1,2,3,25,0) +LABEL(ZERO_BETA_G_1_2) +DSTOREC_SCATTER_4V_R_FWD(4,5,6,7,x5,x1,x7,x6) +DSTOREC_SCATTER_4V_R_FWD(8,9,10,11,x5,x1,x7,x6) +// Row 3 & 4 & 5: +BEQ(ZERO_BETA_G_3_4_5) +DLOADC_GATHER_4V_R_FWD(0,1,2,3,x9,x0,x7,x6) +DLOADC_GATHER_4V_R_FWD(4,5,6,7,x9,x0,x7,x6) +DLOADC_GATHER_4V_R_FWD(8,9,10,11,x9,x0,x7,x6) +DSCALEA8V(12,13,14,15,16,17,18,19,0,1,2,3,4,5,6,7,25,0) +DSCALEA4V(20,21,22,23,8,9,10,11,25,0) +LABEL(ZERO_BETA_G_3_4_5) +DSTOREC_SCATTER_4V_R_FWD(12,13,14,15,x5,x1,x7,x6) +DSTOREC_SCATTER_4V_R_FWD(16,17,18,19,x5,x1,x7,x6) +DSTOREC_SCATTER_4V_R_FWD(20,21,22,23,x5,x1,x7,x6) +LABEL(END_WRITE_MEM) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8","x9", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); +} + diff --git a/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d8x4.c b/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d8x4.c new file mode 100644 index 0000000000..129c3613ac --- /dev/null +++ b/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d8x4.c @@ -0,0 +1,294 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" +#include "assert.h" + +// Label locality & misc. +#include "armv8a_asm_utils.h" + +// Nanokernel operations. +#include "armv8a_asm_d2x2.h" + +/* Order of DGEMM_8x4's execution in 2x2 blocks: + * + * +---+ +---+ + * | 0 | | 2 | + * +---+ +---+ + * +---+ +---+ + * | 1 | | 3 | + * +---+ +---+ + * +---+ +---+ + * | 4 | | 6 | + * +---+ +---+ + * +---+ +---+ + * | 5 | | 7 | + * +---+ +---+ + * + */ +#define DGEMM_8X4_MKER_LOOP_PLAIN(C00,C10,C20,C30,C01,C11,C21,C31,C02,C12,C22,C32,C03,C13,C23,C33,A0,A1,A2,A3,B0,B1,AADDR,ASHIFT,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C01,A0,B0) \ + DGEMM_2X2_NANOKERNEL(C10,C11,A1,B0) \ + DGEMM_2X2_NANOKERNEL(C02,C03,A0,B1) \ + DGEMM_2X2_NANOKERNEL(C12,C13,A1,B1) \ + DGEMM_LOAD2V_ ##LOADNEXT (A0,A1,AADDR,ASHIFT) \ + DGEMM_2X2_NANOKERNEL(C20,C21,A2,B0) \ + DGEMM_2X2_NANOKERNEL(C30,C31,A3,B0) \ + DGEMM_LOAD1V_ ##LOADNEXT (B0,BADDR,BSHIFT) \ + DGEMM_2X2_NANOKERNEL(C22,C23,A2,B1) \ + DGEMM_2X2_NANOKERNEL(C32,C33,A3,B1) + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +#define DGEMM_LOAD2V_noload(V1,V2,ADDR,IMM) +#define DGEMM_LOAD2V_load(V1,V2,ADDR,IMM) \ + DGEMM_LOAD1V_load(V1,ADDR,IMM) \ + DGEMM_LOAD1V_load(V2,ADDR,IMM+16) + +// For contiguous storage of C. +#define DLOADC_4V_C_FWD(C0,C1,C2,C3,CADDR,CSHIFT,LDC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#LDC" \n\t" +#define DSTOREC_4V_C_FWD(C0,C1,C2,C3,CADDR,CSHIFT,LDC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#LDC" \n\t" + +void bli_dgemm_armv8a_asm_8x4 + ( + dim_t k0, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // This kernel is a WIP. + // I have no generic stride support at this moment. + assert( rs_c0 == 1 ); + // if ( rs_c0 != 1 ) return ; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 6; + uint64_t k_left = k0 % 6; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + __asm__ volatile + ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" mov x2, #8 \n\t" // Column-skip of A. +" mov x3, #4 \n\t" // Row-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" mov x8, #8 \n\t" // Multiply some address skips by sizeof(double). +" madd x2, x8, x2, xzr \n\t" // cs_a +" madd x3, x8, x3, xzr \n\t" // rs_b +" madd x7, x8, x7, xzr \n\t" // cs_c +" \n\t" +" ldr x4, %[k_mker] \n\t" // Number of loops. +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:15] <- C +// V[16:21] <- B +// V[22:29] <- A +// Under this scheme, the following is defined: +#define DGEMM_8X4_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,B0,B1,AADDR,ASHIFT,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_8X4_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,A0,A1,A2,A3,B0,B1,AADDR,ASHIFT,BADDR,BSHIFT,LOADNEXT) +// TODO: Prefetch C. +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" ldr q22, [x0, #16*0] \n\t" +" ldr q23, [x0, #16*1] \n\t" +" ldr q24, [x0, #16*2] \n\t" +" ldr q25, [x0, #16*3] \n\t" +" add x0, x0, x2 \n\t" +" ldr q26, [x0, #16*0] \n\t" +" ldr q27, [x0, #16*1] \n\t" +" ldr q28, [x0, #16*2] \n\t" +" ldr q29, [x0, #16*3] \n\t" +" add x0, x0, x2 \n\t" +" \n\t" +" ldr q16, [x1, #16*0] \n\t" +" ldr q17, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" ldr q18, [x1, #16*0] \n\t" +" ldr q19, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" ldr q20, [x1, #16*0] \n\t" +" ldr q21, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,A3,B0,B1) \ + DGEMM_8X4_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,B0,B1,x0,0,x1,0,load) \ + "ldr q"#B1", [x1, #16*1] \n\t" \ + "ldr q"#A2", [x0, #16*2] \n\t" \ + "ldr q"#A3", [x0, #16*3] \n\t" \ + "add x1, x1, x3 \n\t" \ + "add x0, x0, x2 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(22,23,24,25,16,17) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(26,27,28,29,18,19) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(22,23,24,25,20,21) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(26,27,28,29,16,17) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(22,23,24,25,18,19) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(26,27,28,29,20,21) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(26,27,28,29,16,17,x0,0,x1,0,noload) +" ldr q26, [x0, #16*0] \n\t" +" ldr q27, [x0, #16*1] \n\t" +" ldr q28, [x0, #16*2] \n\t" +" ldr q29, [x0, #16*3] \n\t" +" add x0, x0, x2 \n\t" +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(22,23,24,25,18,19,xzr,-1,xzr,-1,noload) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(26,27,28,29,20,21,xzr,-1,xzr,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" ldr q22, [x0, #16*0] \n\t" // Load A col. +" ldr q23, [x0, #16*1] \n\t" +" ldr q24, [x0, #16*2] \n\t" +" ldr q25, [x0, #16*3] \n\t" +" add x0, x0, x2 \n\t" +" ldr q16, [x1, #16*0] \n\t" // Load B col. +" ldr q17, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(22,23,24,25,16,17,xzr,-1,xzr,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ldr d16, [x4] \n\t" // Load alpha & beta (value). +" ldr d17, [x8] \n\t" +" \n\t" +LABEL(PREFETCH_ABNEXT) +" ldr x0, %[a_next] \n\t" +" ldr x1, %[b_next] \n\t" +" prfm PLDL1STRM, [x0, 64*0] \n\t" // Do not know cache line size, +" prfm PLDL1STRM, [x0, 64*1] \n\t" // issue some number of prfm instructions +" prfm PLDL1STRM, [x0, 64*2] \n\t" // to try to activate hardware prefetcher. +" prfm PLDL1STRM, [x1, 64*0] \n\t" +" prfm PLDL1STRM, [x1, 64*1] \n\t" +" prfm PLDL1STRM, [x1, 64*3] \n\t" +" \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x6, #1 \n\t" // Check for generic storage. +BNE(WRITE_MEM_G) +// +// Contiguous C-storage. +LABEL(WRITE_MEM_C) +DLOADC_4V_C_FWD(20,21,22,23,x9,0,x7) +DLOADC_4V_C_FWD(24,25,26,27,x9,0,x7) +DSCALE8V(20,21,22,23,24,25,26,27,17,0) +DSCALEA8V(20,21,22,23,24,25,26,27,0,1,2,3,4,5,6,7,16,0) +// +DLOADC_4V_C_FWD(0,1,2,3,x9,0,x7) +DLOADC_4V_C_FWD(4,5,6,7,x9,0,x7) +DSCALE8V(0,1,2,3,4,5,6,7,17,0) +DSCALEA8V(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,0) +// +DSTOREC_4V_C_FWD(20,21,22,23,x5,0,x7) +DSTOREC_4V_C_FWD(24,25,26,27,x5,0,x7) +DSTOREC_4V_C_FWD(0,1,2,3,x5,0,x7) +DSTOREC_4V_C_FWD(4,5,6,7,x5,0,x7) +BRANCH(END_WRITE_MEM) +// +// Generic-strided C-storage. +LABEL(WRITE_MEM_G) +// TODO: Implement. +LABEL(END_WRITE_MEM) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x16", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); + +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_armv8a_ref.c b/kernels/armv8a/3/sup/bli_gemmsup_armv8a_ref.c new file mode 100644 index 0000000000..c87ff1feb6 --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_armv8a_ref.c @@ -0,0 +1,450 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// Separate instantiation for Armv8-A reference kernels. +// Temporary workaround. Will be removed after upstream has switched to a better way +// of exposing gemmsup interface. + +// +// -- Row storage case --------------------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* NOTE: This microkernel can actually handle arbitrarily large + values of m, n, and k. */ \ +\ + if ( bli_is_noconj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_noconj( conja ) && bli_is_conj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,axpyjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_conj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dotjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_conj( conja ) && bli_is_conj( conjb ) ) */ \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* Conjugate the result to simulate conj(a^T) * conj(b). */ \ + PASTEMAC(ch,conjs)( ab ); \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_r, _armv8a, _ref2 ) + +// +// -- Column storage case ------------------------------------------------------ +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* NOTE: This microkernel can actually handle arbitrarily large + values of m, n, and k. */ \ +\ + if ( bli_is_noconj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_noconj( conja ) && bli_is_conj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,axpyjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_conj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dotjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_conj( conja ) && bli_is_conj( conjb ) ) */ \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* Conjugate the result to simulate conj(a^T) * conj(b). */ \ + PASTEMAC(ch,conjs)( ab ); \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_c, _armv8a, _ref2 ) + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8m.c b/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8m.c new file mode 100644 index 0000000000..630459db73 --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8m.c @@ -0,0 +1,509 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +#define DGEMM_3X1X2_NKER_SUBLOOP(C0,C1,C2,A0,A1,A2,B) \ +" fmla v"#C0".2d, v"#A0".2d, v"#B".2d \n\t" \ +" fmla v"#C1".2d, v"#A1".2d, v"#B".2d \n\t" \ +" fmla v"#C2".2d, v"#A2".2d, v"#B".2d \n\t" + +#define DGEMM_3X8X2_K_MKER_LOOP_PLAIN(C00,C01,C02,C03,C04,C05,C06,C07,C10,C11,C12,C13,C14,C15,C16,C17,C20,C21,C22,C23,C24,C25,C26,C27,A0,A1,A2,B0,B1,B2,B3,BADDR,BELEMADDR,BELEMST,LOADNEXT) \ + /* Always load before forwarding to the next line. */ \ + DGEMM_3X1X2_NKER_SUBLOOP(C00,C10,C20,A0,A1,A2,B0) \ + DGEMM_LOAD1V_K_load(B0,BELEMADDR,BELEMST) \ + DGEMM_3X1X2_NKER_SUBLOOP(C01,C11,C21,A0,A1,A2,B1) \ + DGEMM_LOAD1V_K_load(B1,BELEMADDR,BELEMST) \ + DGEMM_3X1X2_NKER_SUBLOOP(C02,C12,C22,A0,A1,A2,B2) \ + DGEMM_LOAD1V_K_load(B2,BELEMADDR,BELEMST) \ + DGEMM_3X1X2_NKER_SUBLOOP(C03,C13,C23,A0,A1,A2,B3) \ + DGEMM_LOAD1V_K_load(B3,BELEMADDR,BELEMST) \ + \ +" add "#BADDR", "#BADDR", #16 \n\t" \ +" mov "#BELEMADDR", "#BADDR" \n\t" \ + DGEMM_3X1X2_NKER_SUBLOOP(C04,C14,C24,A0,A1,A2,B0) \ + DGEMM_LOAD1V_K_ ##LOADNEXT (B0,BELEMADDR,BELEMST) \ + DGEMM_3X1X2_NKER_SUBLOOP(C05,C15,C25,A0,A1,A2,B1) \ + DGEMM_LOAD1V_K_ ##LOADNEXT (B1,BELEMADDR,BELEMST) \ + DGEMM_3X1X2_NKER_SUBLOOP(C06,C16,C26,A0,A1,A2,B2) \ + DGEMM_LOAD1V_K_ ##LOADNEXT (B2,BELEMADDR,BELEMST) \ + DGEMM_3X1X2_NKER_SUBLOOP(C07,C17,C27,A0,A1,A2,B3) \ + DGEMM_LOAD1V_K_ ##LOADNEXT (B3,BELEMADDR,BELEMST) + +#define DGEMM_LOAD1V_K_noload(V,ELEMADDR,ELEMST) +#define DGEMM_LOAD1V_K_load(V,ELEMADDR,ELEMST) \ +" ldr q"#V", [ "#ELEMADDR" ] \n\t" \ +" add "#ELEMADDR", "#ELEMADDR", "#ELEMST" \n\t" + +// For row-storage of C. +#define DLOADC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For column-storage of C. +#define DLOADC_1V_1ELM_C_FWD(C0,CSCALAR,CIDX,CADDR,CSHIFT,CSC) \ + DLOAD1V(C0,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" ld1 {v"#CSCALAR".d}["#CIDX"], ["#CADDR"] \n\t" \ +" sub "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_1V_1ELM_C_FWD(C0,CSCALAR,CIDX,CADDR,CSHIFT,CSC) \ + DSTORE1V(C0,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" st1 {v"#CSCALAR".d}["#CIDX"], ["#CADDR"] \n\t" \ +" sub "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + +#define DSCALE12V(V0,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE4V(V4,V5,V6,V7,A,IDX) \ + DSCALE4V(V8,V9,V10,V11,A,IDX) +#define DSCALEA12V(D0,D1,D2,D3,D4,D5,D6,D7,D8,D9,D10,D11,S0,S1,S2,S3,S4,S5,S6,S7,S8,S9,S10,S11,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA4V(D4,D5,D6,D7,S4,S5,S6,S7,A,IDX) \ + DSCALEA4V(D8,D9,D10,D11,S8,S9,S10,S11,A,IDX) + +#define DPRFMC_FWD(CADDR,DLONGC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#DLONGC" \n\t" + +void bli_dgemmsup_rd_armv8a_asm_6x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + if ( n0 != 8 ) + { + if ( n0 < 8 ) + { + for ( ; n0 >= 4; n0 -= 4 ) + { + dim_t m = m0; + double *a_loc = a; + double *c_loc = c; + + for ( ; m >= 3; m -= 3 ) + { + bli_dgemmsup_rd_armv8a_asm_3x4 + ( + conja, conjb, 3, 4, k0, + alpha, a_loc, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c_loc, rs_c0, cs_c0, data, cntx + ); + a_loc += 3 * rs_a0; + c_loc += 3 * rs_c0; + } + + if ( m > 0 ) + { + bli_dgemmsup_rd_armv8a_int_3x4 + ( + conja, conjb, m, 4, k0, + alpha, a_loc, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c_loc, rs_c0, cs_c0, data, cntx + ); + } + b += 4 * cs_b0; + c += 4 * cs_c0; + } + + for ( ; m0 > 0; m0 -= 3 ) + { + dim_t m_loc = ( m0 < 3 ) ? m0 : 3; + + bli_dgemmsup_rd_armv8a_int_3x4 + ( + conja, conjb, m_loc, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + + a += 3 * rs_a0; + c += 3 * rs_c0; + } + } + else + { + assert( FALSE ); + } + return; + } + + // LLVM has very bad routing ability for inline asm. + // Limit number of registers in case of Clang compilation. +#ifndef __clang__ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); +#endif + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + int64_t m_iter = m0 / 3; + int64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + assert( cs_a0 == 1 ); + assert( rs_b0 == 1 ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[a] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[m_iter] \n\t" +" ldr x2, %[rs_a] \n\t" // Row-skip of A. +" ldr x3, %[cs_b] \n\t" // Column-skip of B. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x2, x2, #3 \n\t" // rs_a +" lsl x3, x3, #3 \n\t" // cs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x1, x5 \n\t" +" cmp x7, #8 \n\t" // Prefetch column-strided C. +BEQ(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +BRANCH(C_PREFETCH_END) +LABEL(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +LABEL(C_PREFETCH_END) +// +// Millikernel. +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x0, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x1, %[b] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:23] <- C +// V[24:26] <- A +// V[28:31] <- B +// V[ 27 ] <- Not used. +// Under this scheme, the following is defined: +#define DGEMM_3X8X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,BADDR,BELEMADDR,BELEMST,LOADNEXT) \ + DGEMM_3X8X2_K_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,A0,A1,A2,B0,B1,B2,B3,BADDR,BELEMADDR,BELEMST,LOADNEXT) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x11, x1 \n\t" // Load B. +" ldr q28, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q29, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q30, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q31, [x11] \n\t" +" add x11, x11, x3 \n\t" +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ldr q24, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q25, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q26, [x14] \n\t" +// " add x14, x14, x2 \n\t" +" add x0, x0, #16 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +CLEAR8V(16,17,18,19,20,21,22,23) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_3X8X2_K_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_3X8X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,x1,x11,x3,load) \ + "mov x14, x0 \n\t" \ + "ldr q24, [x14] \n\t" \ + "add x14, x14, x2 \n\t" \ + "ldr q25, [x14] \n\t" \ + "add x14, x14, x2 \n\t" \ + "ldr q26, [x14] \n\t" \ + /*"add x14, x14, x2 \n\t"*/ \ + "add x0, x0, #16 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_3X8X2_K_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,28,29,30,31) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_3X8X2_K_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_3X8X2_K_MKER_LOOP_PLAIN_LOC(24,25,26,28,29,30,31,x1,x11,x3,noload) +// +// If major kernel is executed, +// an additional depth-summation is required. +" faddp v0.2d, v0.2d, v1.2d \n\t" // Line 0. +" faddp v1.2d, v2.2d, v3.2d \n\t" +" faddp v2.2d, v4.2d, v5.2d \n\t" +" faddp v3.2d, v6.2d, v7.2d \n\t" +" faddp v4.2d, v8.2d, v9.2d \n\t" // Line 1. +" faddp v5.2d, v10.2d, v11.2d \n\t" +" faddp v6.2d, v12.2d, v13.2d \n\t" +" faddp v7.2d, v14.2d, v15.2d \n\t" +" faddp v8.2d, v16.2d, v17.2d \n\t" // Line 2. +" faddp v9.2d, v18.2d, v19.2d \n\t" +" faddp v10.2d, v20.2d, v21.2d \n\t" +" faddp v11.2d, v22.2d, v23.2d \n\t" +" \n\t" +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x11, x1 \n\t" // Load B row. +" ld1 {v28.d}[0], [x11], x3 \n\t" +" ld1 {v28.d}[1], [x11], x3 \n\t" +" ld1 {v29.d}[0], [x11], x3 \n\t" +" ld1 {v29.d}[1], [x11], x3 \n\t" +" ld1 {v30.d}[0], [x11], x3 \n\t" +" ld1 {v30.d}[1], [x11], x3 \n\t" +" ld1 {v31.d}[0], [x11], x3 \n\t" +" ld1 {v31.d}[1], [x11], x3 \n\t" +" add x1, x1, #8 \n\t" +" mov x14, x0 \n\t" // Load A column. +" ld1 {v24.d}[0], [x14], x2 \n\t" +" ld1 {v24.d}[1], [x14], x2 \n\t" +" ld1 {v25.d}[0], [x14], x2 \n\t" +" add x0, x0, #8 \n\t" +" fmla v0.2d, v28.2d, v24.d[0] \n\t" +" fmla v1.2d, v29.2d, v24.d[0] \n\t" +" fmla v2.2d, v30.2d, v24.d[0] \n\t" +" fmla v3.2d, v31.2d, v24.d[0] \n\t" +" fmla v4.2d, v28.2d, v24.d[1] \n\t" +" fmla v5.2d, v29.2d, v24.d[1] \n\t" +" fmla v6.2d, v30.2d, v24.d[1] \n\t" +" fmla v7.2d, v31.2d, v24.d[1] \n\t" +" fmla v8.2d, v28.2d, v25.d[0] \n\t" +" fmla v9.2d, v29.2d, v25.d[0] \n\t" +" fmla v10.2d, v30.2d, v25.d[0] \n\t" +" fmla v11.2d, v31.2d, v25.d[0] \n\t" +" sub x8, x8, #1 \n\t" +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1r {v30.2d}, [x4] \n\t" // Load alpha & beta (value). +" ld1r {v31.2d}, [x8] \n\t" +" \n\t" +" fmov d28, #1.0 \n\t" // Don't scale for unit alpha. +" fcmp d30, d28 \n\t" +BEQ(UNIT_ALPHA) +DSCALE12V(0,1,2,3,4,5,6,7,8,9,10,11,30,0) +LABEL(UNIT_ALPHA) +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" fcmp d31, #0.0 \n\t" // Don't load for zero beta. +BEQ(ZERO_BETA_R) +DLOADC_4V_R_FWD(12,13,14,15,x1,0,x6) +DLOADC_4V_R_FWD(16,17,18,19,x1,0,x6) +DLOADC_4V_R_FWD(20,21,22,23,x1,0,x6) +DSCALEA12V(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,31,0) +LABEL(ZERO_BETA_R) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_R) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_R) +#endif +DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) +DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +" trn1 v12.2d, v0.2d, v4.2d \n\t" +" trn2 v13.2d, v0.2d, v4.2d \n\t" +" trn1 v14.2d, v1.2d, v5.2d \n\t" +" trn2 v15.2d, v1.2d, v5.2d \n\t" +" trn1 v16.2d, v2.2d, v6.2d \n\t" +" trn2 v17.2d, v2.2d, v6.2d \n\t" +" trn1 v18.2d, v3.2d, v7.2d \n\t" +" trn2 v19.2d, v3.2d, v7.2d \n\t" +" fcmp d31, #0.0 \n\t" // Don't load for zero beta. +BEQ(ZERO_BETA_C) +DLOADC_1V_1ELM_C_FWD(0,20,0,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(1,20,1,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(2,21,0,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(3,21,1,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(4,22,0,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(5,22,1,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(6,23,0,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(7,23,1,x1,0,x7) +DSCALEA12V(12,13,14,15,16,17,18,19,8,9,10,11,0,1,2,3,4,5,6,7,20,21,22,23,31,0) +LABEL(ZERO_BETA_C) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_C) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_C) +#endif +DSTOREC_1V_1ELM_C_FWD(12,8,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(13,8,1,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(14,9,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(15,9,1,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(16,10,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(17,10,1,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(18,11,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(19,11,1,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #3 \n\t" +" madd x13, x6, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" madd x10, x2, x8, x10 \n\t" // Forward A's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_b] "m" (cs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + // In Clang, even "m"-passed parameter takes 1 register. + // Have to disable prefetching to pass compilation. +#ifndef __clang__ + [a_next] "r" (a_next), + [b_next] "r" (b_next), +#endif + [m_iter] "m" (m_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + // TODO: Implement optimized kernel for this. + // + // Forward address. + a = a + m_iter * 3 * rs_a; + c = c + m_iter * 3 * rs_c; + for ( ; m_left > 0; m_left -= 2 ) + { + dim_t m_loc = ( m_left < 2 ) ? m_left : 2; + + bli_dgemmsup_rd_armv8a_int_2x8 + ( + conja, conjb, m_loc, 8, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + a += 2 * rs_a0; + c += 2 * rs_c0; + } +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8n.c b/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8n.c new file mode 100644 index 0000000000..e13dd668ea --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8n.c @@ -0,0 +1,586 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +#define DGEMM_1X4X2_NKER_SUBLOOP(C0,C1,C2,C3,A,B0,B1,B2,B3) \ +" fmla v"#C0".2d, v"#A".2d, v"#B0".2d \n\t" \ +" fmla v"#C1".2d, v"#A".2d, v"#B1".2d \n\t" \ +" fmla v"#C2".2d, v"#A".2d, v"#B2".2d \n\t" \ +" fmla v"#C3".2d, v"#A".2d, v"#B3".2d \n\t" + +#define DGEMM_6X4X2_K_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,C30,C31,C32,C33,C40,C41,C42,C43,C50,C51,C52,C53,A0,A1,A2,A3,B0,B1,B2,B3,AADDR,AELEMADDR,AELEMST,LOADNEXT) \ + /* Always load before forwarding to the next line. */ \ + DGEMM_1X4X2_NKER_SUBLOOP(C00,C01,C02,C03,A0,B0,B1,B2,B3) \ + DGEMM_LOAD1V_K_load(A0,AELEMADDR,AELEMST) \ + DGEMM_1X4X2_NKER_SUBLOOP(C10,C11,C12,C13,A1,B0,B1,B2,B3) \ + DGEMM_LOAD1V_K_load(A1,AELEMADDR,AELEMST) \ +" add "#AADDR", "#AADDR", #16 \n\t" \ +" mov "#AELEMADDR", "#AADDR" \n\t" \ + DGEMM_1X4X2_NKER_SUBLOOP(C20,C21,C22,C23,A2,B0,B1,B2,B3) \ + DGEMM_LOAD1V_K_load(A2,AELEMADDR,AELEMST) \ + DGEMM_1X4X2_NKER_SUBLOOP(C30,C31,C32,C33,A3,B0,B1,B2,B3) \ + DGEMM_LOAD1V_K_load(A3,AELEMADDR,AELEMST) \ + \ + DGEMM_1X4X2_NKER_SUBLOOP(C40,C41,C42,C43,A0,B0,B1,B2,B3) \ + DGEMM_LOAD1V_K_ ##LOADNEXT (A0,AELEMADDR,AELEMST) \ + DGEMM_1X4X2_NKER_SUBLOOP(C50,C51,C52,C53,A1,B0,B1,B2,B3) \ + DGEMM_LOAD1V_K_ ##LOADNEXT (A1,AELEMADDR,AELEMST) + +#define DGEMM_LOAD1V_K_noload(V,ELEMADDR,ELEMST) +#define DGEMM_LOAD1V_K_load(V,ELEMADDR,ELEMST) \ +" ldr q"#V", [ "#ELEMADDR" ] \n\t" \ +" add "#ELEMADDR", "#ELEMADDR", "#ELEMST" \n\t" + +// For row-storage of C. +#define DLOADC_2V_R_FWD(C0,C1,CADDR,CSHIFT,RSC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_2V_R_FWD(C0,C1,CADDR,CSHIFT,RSC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For column-storage of C. +#define DLOADC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ + DLOAD1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ + DSTORE1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + +#define DSCALE12V(V0,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE4V(V4,V5,V6,V7,A,IDX) \ + DSCALE4V(V8,V9,V10,V11,A,IDX) +#define DSCALEA12V(D0,D1,D2,D3,D4,D5,D6,D7,D8,D9,D10,D11,S0,S1,S2,S3,S4,S5,S6,S7,S8,S9,S10,S11,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA4V(D4,D5,D6,D7,S4,S5,S6,S7,A,IDX) \ + DSCALEA4V(D8,D9,D10,D11,S8,S9,S10,S11,A,IDX) + +#define DPRFMC_FWD(CADDR,DLONGC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#DLONGC" \n\t" + +void bli_dgemmsup_rd_armv8a_asm_6x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + if ( m0 != 6 ) + { + if ( m0 < 6 ) + { + if ( m0 == 5 ) + { + // 3xk calls. + dim_t n = n0; + double *b_loc = b; + double *c_loc = c; + for ( ; n >= 4; n -= 4 ) + { + bli_dgemmsup_rd_armv8a_asm_3x4 + ( + conja, conjb, 3, 4, k0, + alpha, a, rs_a0, cs_a0, b_loc, rs_b0, cs_b0, + beta, c_loc, rs_c0, cs_c0, data, cntx + ); + b_loc += 4 * cs_b0; + c_loc += 4 * cs_c0; + } + if ( n > 0 ) + { + bli_dgemmsup_rd_armv8a_int_3x4 + ( + conja, conjb, 3, n, k0, + alpha, a, rs_a0, cs_a0, b_loc, rs_b0, cs_b0, + beta, c_loc, rs_c0, cs_c0, data, cntx + ); + } + a += 3 * rs_a0; + c += 3 * rs_c0; + + // 2xk calls. + for ( ; n0 > 0; n0 -= 8 ) + { + dim_t n_loc = ( n0 < 8 ) ? n0 : 8; + bli_dgemmsup_rd_armv8a_int_2x8 + ( + conja, conjb, 2, n_loc, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + b += 8 * cs_b0; + c += 8 * cs_c0; + } + return; + } + else if ( m0 == 4 ) + { + for ( ; n0 > 0; n0 -= 8 ) + { + dim_t n_loc = ( n0 < 8 ) ? n0 : 8; + bli_dgemmsup_rd_armv8a_int_2x8 + ( + conja, conjb, 2, n_loc, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + bli_dgemmsup_rd_armv8a_int_2x8 + ( + conja, conjb, 2, n_loc, k0, + alpha, a + 2 * rs_a0, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c + 2 * rs_c0, rs_c0, cs_c0, data, cntx + ); + b += 8 * cs_b0; + c += 8 * cs_c0; + } + } + else if ( m0 == 3 ) + { + for ( ; n0 >= 4; n0 -= 4 ) + { + bli_dgemmsup_rd_armv8a_asm_3x4 + ( + conja, conjb, 3, 4, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + b += 4 * cs_b0; + c += 4 * cs_c0; + } + if ( n0 > 0 ) + { + bli_dgemmsup_rd_armv8a_int_3x4 + ( + conja, conjb, 3, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + } + } + else // m0 == 2 or 1. + { + for ( ; n0 > 0; n0 -= 8 ) + { + dim_t n_loc = ( n0 < 8 ) ? n0 : 8; + bli_dgemmsup_rd_armv8a_int_2x8 + ( + conja, conjb, m0, n_loc, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + b += 8 * cs_b0; + c += 8 * cs_c0; + } + } + } + else + { + assert( FALSE ); + } + return; + } + + // LLVM has very bad routing ability for inline asm. + // Limit number of registers in case of Clang compilation. +#ifndef __clang__ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); +#endif + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + int64_t n_iter = n0 / 4; + int64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + assert( cs_a0 == 1 ); + assert( rs_b0 == 1 ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[b] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[n_iter] \n\t" +" ldr x2, %[rs_a] \n\t" // Row-skip of A. +" ldr x3, %[cs_b] \n\t" // Column-skip of B. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x2, x2, #3 \n\t" // rs_a +" lsl x3, x3, #3 \n\t" // cs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x1, x5 \n\t" +" cmp x7, #8 \n\t" // Prefetch column-strided C. +BEQ(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +BRANCH(C_PREFETCH_END) +LABEL(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +LABEL(C_PREFETCH_END) +// +// Millikernel. +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x1, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x0, %[a] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:23] <- C +// V[24:27] <- A +// V[28:31] <- B +// Under this scheme, the following is defined: +#define DGEMM_6X4X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,B0,B1,B2,B3,AADDR,AELEMADDR,AELEMST,LOADNEXT) \ + DGEMM_6X4X2_K_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,A0,A1,A2,A3,B0,B1,B2,B3,AADDR,AELEMADDR,AELEMST,LOADNEXT) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x11, x1 \n\t" // Load B. +" ldr q28, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q29, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q30, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q31, [x11] \n\t" +// " add x11, x11, x3 \n\t" +" add x1, x1, #16 \n\t" +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ldr q24, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q25, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q26, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q27, [x14] \n\t" +" add x14, x14, x2 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +CLEAR8V(16,17,18,19,20,21,22,23) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_6X4X2_K_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,A3,B0,B1,B2,B3) \ + DGEMM_6X4X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,B0,B1,B2,B3,x0,x14,x2,load) \ + /* A already loaded and forwarded. Process B only. */ \ + "mov x11, x1 \n\t" \ + "ldr q28, [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q29, [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q30, [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q31, [x11] \n\t" \ + /*"add x11, x11, x3 \n\t"*/ \ + "add x1, x1, #16 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_6X4X2_K_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,27,28,29,30,31) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_6X4X2_K_MKER_LOOP_PLAIN_LOC_FWD(26,27,24,25,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_6X4X2_K_MKER_LOOP_PLAIN_LOC(26,27,24,25,28,29,30,31,x0,x14,x2,noload) +// +// If major kernel is executed, +// an additional depth-summation is required. +" faddp v0.2d, v0.2d, v1.2d \n\t" // Line 0. +" faddp v1.2d, v2.2d, v3.2d \n\t" +" faddp v2.2d, v4.2d, v5.2d \n\t" // Line 1. +" faddp v3.2d, v6.2d, v7.2d \n\t" +" faddp v4.2d, v8.2d, v9.2d \n\t" // Line 2. +" faddp v5.2d, v10.2d, v11.2d \n\t" +" faddp v6.2d, v12.2d, v13.2d \n\t" // Line 3. +" faddp v7.2d, v14.2d, v15.2d \n\t" +" faddp v8.2d, v16.2d, v17.2d \n\t" // Line 4. +" faddp v9.2d, v18.2d, v19.2d \n\t" +" faddp v10.2d, v20.2d, v21.2d \n\t" // Line 5. +" faddp v11.2d, v22.2d, v23.2d \n\t" +" \n\t" +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x11, x1 \n\t" // Load B row. +" ld1 {v28.d}[0], [x11], x3 \n\t" +" ld1 {v28.d}[1], [x11], x3 \n\t" +" ld1 {v29.d}[0], [x11], x3 \n\t" +" ld1 {v29.d}[1], [x11], x3 \n\t" +" add x1, x1, #8 \n\t" +" mov x14, x0 \n\t" // Load A column. +" ld1 {v24.d}[0], [x14], x2 \n\t" +" ld1 {v24.d}[1], [x14], x2 \n\t" +" ld1 {v25.d}[0], [x14], x2 \n\t" +" ld1 {v25.d}[1], [x14], x2 \n\t" +" ld1 {v26.d}[0], [x14], x2 \n\t" +" ld1 {v26.d}[1], [x14], x2 \n\t" +" add x0, x0, #8 \n\t" +" fmla v0.2d, v28.2d, v24.d[0] \n\t" +" fmla v1.2d, v29.2d, v24.d[0] \n\t" +" fmla v2.2d, v28.2d, v24.d[1] \n\t" +" fmla v3.2d, v29.2d, v24.d[1] \n\t" +" fmla v4.2d, v28.2d, v25.d[0] \n\t" +" fmla v5.2d, v29.2d, v25.d[0] \n\t" +" fmla v6.2d, v28.2d, v25.d[1] \n\t" +" fmla v7.2d, v29.2d, v25.d[1] \n\t" +" fmla v8.2d, v28.2d, v26.d[0] \n\t" +" fmla v9.2d, v29.2d, v26.d[0] \n\t" +" fmla v10.2d, v28.2d, v26.d[1] \n\t" +" fmla v11.2d, v29.2d, v26.d[1] \n\t" +" sub x8, x8, #1 \n\t" +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1r {v30.2d}, [x4] \n\t" // Load alpha & beta (value). +" ld1r {v31.2d}, [x8] \n\t" +" \n\t" +" fmov d28, #1.0 \n\t" // Don't scale for unit alpha. +" fcmp d30, d28 \n\t" +BEQ(UNIT_ALPHA) +DSCALE12V(0,1,2,3,4,5,6,7,8,9,10,11,30,0) +LABEL(UNIT_ALPHA) +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" fcmp d31, #0.0 \n\t" // Don't load for zero beta. +BEQ(ZERO_BETA_R) +DLOADC_2V_R_FWD(12,13,x1,0,x6) +DLOADC_2V_R_FWD(14,15,x1,0,x6) +DLOADC_2V_R_FWD(16,17,x1,0,x6) +DLOADC_2V_R_FWD(18,19,x1,0,x6) +DLOADC_2V_R_FWD(20,21,x1,0,x6) +DLOADC_2V_R_FWD(22,23,x1,0,x6) +DSCALEA12V(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,31,0) +LABEL(ZERO_BETA_R) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_R) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_R) +#endif +DSTOREC_2V_R_FWD(0,1,x5,0,x6) +DSTOREC_2V_R_FWD(2,3,x5,0,x6) +DSTOREC_2V_R_FWD(4,5,x5,0,x6) +DSTOREC_2V_R_FWD(6,7,x5,0,x6) +DSTOREC_2V_R_FWD(8,9,x5,0,x6) +DSTOREC_2V_R_FWD(10,11,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +" trn1 v12.2d, v0.2d, v2.2d \n\t" +" trn1 v13.2d, v4.2d, v6.2d \n\t" +" trn1 v14.2d, v8.2d, v10.2d \n\t" +" trn2 v15.2d, v0.2d, v2.2d \n\t" +" trn2 v16.2d, v4.2d, v6.2d \n\t" +" trn2 v17.2d, v8.2d, v10.2d \n\t" +" trn1 v18.2d, v1.2d, v3.2d \n\t" +" trn1 v19.2d, v5.2d, v7.2d \n\t" +" trn1 v20.2d, v9.2d, v11.2d \n\t" +" trn2 v21.2d, v1.2d, v3.2d \n\t" +" trn2 v22.2d, v5.2d, v7.2d \n\t" +" trn2 v23.2d, v9.2d, v11.2d \n\t" +" fcmp d31, #0.0 \n\t" // Don't load for zero beta. +BEQ(ZERO_BETA_C) +DLOADC_3V_C_FWD(0,1,2,x1,0,x7) +DLOADC_3V_C_FWD(3,4,5,x1,0,x7) +DLOADC_3V_C_FWD(6,7,8,x1,0,x7) +DLOADC_3V_C_FWD(9,10,11,x1,0,x7) +DSCALEA12V(12,13,14,15,16,17,18,19,20,21,22,23,0,1,2,3,4,5,6,7,8,9,10,11,31,0) +LABEL(ZERO_BETA_C) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_C) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_C) +#endif +DSTOREC_3V_C_FWD(12,13,14,x5,0,x7) +DSTOREC_3V_C_FWD(15,16,17,x5,0,x7) +DSTOREC_3V_C_FWD(18,19,20,x5,0,x7) +DSTOREC_3V_C_FWD(21,22,23,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #4 \n\t" +" madd x13, x7, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" madd x10, x3, x8, x10 \n\t" // Forward B's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_b] "m" (cs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + // In Clang, even "m"-passed parameter takes 1 register. + // Have to disable prefetching to pass compilation. +#ifndef __clang__ + [a_next] "r" (a_next), + [b_next] "r" (b_next), +#endif + [n_iter] "m" (n_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + // TODO: Implement optimized kernel for this. + // + // Forward address. + b = b + n_iter * 4 * cs_b; + c = c + n_iter * 4 * cs_c; + if ( n_left >= 3 ) + { + bli_dgemmsup_rd_armv8a_asm_6x3 + ( + conja, conjb, 6, 3, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + b = b + 3 * cs_b; + c = c + 3 * cs_c; + n_left -= 3; + } + + if ( n_left ) + { + // n_left < 3; + // + // Slice in rows. + bli_dgemmsup_rd_armv8a_int_3x4 + ( + conja, conjb, 3, n_left, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + a = a + 3 * rs_a; + c = c + 3 * rs_c; + + bli_dgemmsup_rd_armv8a_int_3x4 + ( + conja, conjb, 3, n_left, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + } +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d4x8m.c b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d4x8m.c new file mode 100644 index 0000000000..16001a73ce --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d4x8m.c @@ -0,0 +1,455 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +// Nanokernel operations. +#include "../armv8a_asm_d2x2.h" + +/* + * +---+ +---+ +---+ +---+ + * | 0 | | 2 | | 4 | | 6 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 1 | | 3 | | 5 | | 7 | + * +---+ +---+ +---+ +---+ + */ +#define DGEMM_4X8_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,C30,C31,C32,C33,A0,A1,B0,B1,B2,B3,BADDR,BSHIFT0,BSHIFT1,BSHIFT2,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C10,B0,A0) \ + DGEMM_2X2_NANOKERNEL(C20,C30,B0,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (B0,BADDR,BSHIFT0) \ + DGEMM_2X2_NANOKERNEL(C01,C11,B1,A0) \ + DGEMM_2X2_NANOKERNEL(C21,C31,B1,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (B1,BADDR,BSHIFT1) \ + DGEMM_2X2_NANOKERNEL(C02,C12,B2,A0) \ + DGEMM_2X2_NANOKERNEL(C22,C32,B2,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (B2,BADDR,BSHIFT2) \ + DGEMM_2X2_NANOKERNEL(C03,C13,B3,A0) \ + DGEMM_2X2_NANOKERNEL(C23,C33,B3,A1) + + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +// Prefetch C in the long direction. +#define DPRFMC_FWD(CADDR,DLONGC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#DLONGC" \n\t" + +#define DLOADC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +#define DLOADC_4V_C_FWD(C00,C10,C01,C11,CADDR,CSHIFT,CSC) \ + DLOAD2V(C00,C10,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" \ + DLOAD2V(C01,C11,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_4V_C_FWD(C00,C10,C01,C11,CADDR,CSHIFT,CSC) \ + DSTORE2V(C00,C10,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" \ + DSTORE2V(C01,C11,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + + +/* + * 4x8 dgemmsup kernel with extending 1st dimension. + * + * Recommanded usage case: + * o 16 < (L1 cache latency) * (Num. FPU) < 25. + * o L1 cache has a bandwidth not too low (true in most cases). + * o (FMLA latency) * (Num. FPU) < 32 cycles (true in almost all cases). + */ +void bli_dgemmsup_rv_armv8a_asm_4x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Fixme: This uker has no dispatching for unalighed sizes. + // Currently it only serves as a dispatch target for other kernels + // and cannot be registered in configurations. + assert( n0 == 8 ); + + // LLVM has very bad routing ability for inline asm. + // Limit number of registers in case of Clang compilation. +#ifndef __clang__ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); +#endif + uint64_t ps_a = bli_auxinfo_ps_a( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + int64_t m_iter = m0 / 4; + int64_t m_left = m0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // uint64_t cs_b = cs_b0; + assert( cs_b0 == 1 ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[a] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[m_iter] \n\t" +" ldr x11, %[ps_a] \n\t" // Panel-skip of A. +" ldr x9, %[rs_a] \n\t" // Row-skip of A. +" ldr x2, %[cs_a] \n\t" // Column-skip of A. +" ldr x3, %[rs_b] \n\t" // Row-skip of B. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x11, x11, #3 \n\t" // ps_a +" lsl x9, x9, #3 \n\t" // rs_a +" lsl x2, x2, #3 \n\t" // cs_a +" lsl x3, x3, #3 \n\t" // rs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x1, x5 \n\t" +" cmp x7, #8 \n\t" // Prefetch column-strided C. +BEQ(C_PREFETCH_COLS) +// This prefetch will not cover further mker perts. Skip. +// +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +BRANCH(C_PREFETCH_END) +LABEL(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +LABEL(C_PREFETCH_END) +// +// Millikernel. +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x0, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x1, %[b] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:15] <- C +// V[16:23] <- A; Allowed latency: 48 cycles / # of FPUs. +// V[24:31] <- B; Allowed latency: 28 cycles / # of FPUs. +// Under this scheme, the following is defined: +#define DGEMM_4X8_MKER_LOOP_PLAIN_LOC(A0,A1,B0,B1,B2,B3,BADDR,BSHIFT0,BSHIFT1,BSHIFT2,LOADNEXT) \ + DGEMM_4X8_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,A0,A1,B0,B1,B2,B3,BADDR,BSHIFT0,BSHIFT1,BSHIFT2,LOADNEXT) +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ld1 {v16.d}[0], [x14], x9 \n\t" +" ld1 {v16.d}[1], [x14], x9 \n\t" +" ld1 {v17.d}[0], [x14], x9 \n\t" +" ld1 {v17.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v18.d}[0], [x14], x9 \n\t" +" ld1 {v18.d}[1], [x14], x9 \n\t" +" ld1 {v19.d}[0], [x14], x9 \n\t" +" ld1 {v19.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v20.d}[0], [x14], x9 \n\t" +" ld1 {v20.d}[1], [x14], x9 \n\t" +" ld1 {v21.d}[0], [x14], x9 \n\t" +" ld1 {v21.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v22.d}[0], [x14], x9 \n\t" +" ld1 {v22.d}[1], [x14], x9 \n\t" +" ld1 {v23.d}[0], [x14], x9 \n\t" +" ld1 {v23.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" \n\t" +" ldr q24, [x1, #16*0] \n\t" // Load B. +" ldr q25, [x1, #16*1] \n\t" +" ldr q26, [x1, #16*2] \n\t" +" ldr q27, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" ldr q28, [x1, #16*0] \n\t" +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,B0,B1,B2,B3) \ + DGEMM_4X8_MKER_LOOP_PLAIN_LOC(A0,A1,B0,B1,B2,B3,x1,0,16*1,16*2,load) \ + "ldr q"#B3", [x1, #16*3] \n\t" \ + "mov x14, x0 \n\t" \ + "ld1 {v"#A0".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A0".d}[1], [x14], x9 \n\t" \ + "ld1 {v"#A1".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A1".d}[1], [x14], x9 \n\t" \ + "add x0, x0, x2 \n\t" \ + "add x1, x1, x3 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(16,17,24,25,26,27) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(18,19,28,29,30,31) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(20,21,24,25,26,27) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(22,23,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(16,17,24,25,26,27,x1,0,16*1,16*2,load) +" ldr q27, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(18,19,28,29,30,31,x1,0,16*1,16*2,load) +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(20,21,24,25,26,27,xzr,-1,-1,-1,noload) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(22,23,28,29,30,31,xzr,-1,-1,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x14, x0 \n\t" // Load A col. +" ld1 {v16.d}[0], [x14], x9 \n\t" +" ld1 {v16.d}[1], [x14], x9 \n\t" +" ld1 {v17.d}[0], [x14], x9 \n\t" +" ld1 {v17.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" ldr q24, [x1, #16*0] \n\t" // Load B row. +" ldr q25, [x1, #16*1] \n\t" +" ldr q26, [x1, #16*2] \n\t" +" ldr q27, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(16,17,24,25,26,27,xzr,-1,-1,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" ld1r {v16.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v17.2d}, [x8] \n\t" +" fcmp d17, #0.0 \n\t" +DSCALE8V(0,1,2,3,4,5,6,7,16,0) +DSCALE8V(8,9,10,11,12,13,14,15,16,0) +BEQ(ZERO_BETA_R) +DLOADC_4V_R_FWD(20,21,22,23,x1,0,x6) +DLOADC_4V_R_FWD(24,25,26,27,x1,0,x6) +DSCALEA8V(0,1,2,3,4,5,6,7,20,21,22,23,24,25,26,27,17,0) +// +DLOADC_4V_R_FWD(20,21,22,23,x1,0,x6) +DLOADC_4V_R_FWD(24,25,26,27,x1,0,x6) +DSCALEA8V(8,9,10,11,12,13,14,15,20,21,22,23,24,25,26,27,17,0) +LABEL(ZERO_BETA_R) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_R) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_R) +#endif +// +DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) +DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) +DSTOREC_4V_R_FWD(12,13,14,15,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +// In-register transpose. +" trn1 v16.2d, v0.2d, v4.2d \n\t" // Column 0. +" trn1 v17.2d, v8.2d, v12.2d \n\t" +" trn2 v18.2d, v0.2d, v4.2d \n\t" // Column 1. +" trn2 v19.2d, v8.2d, v12.2d \n\t" +" trn1 v20.2d, v1.2d, v5.2d \n\t" // Column 2. +" trn1 v21.2d, v9.2d, v13.2d \n\t" +" trn2 v22.2d, v1.2d, v5.2d \n\t" // Column 3. +" trn2 v23.2d, v9.2d, v13.2d \n\t" +" trn1 v24.2d, v2.2d, v6.2d \n\t" // Column 4. +" trn1 v25.2d, v10.2d, v14.2d \n\t" +" trn2 v26.2d, v2.2d, v6.2d \n\t" // Column 5. +" trn2 v27.2d, v10.2d, v14.2d \n\t" +" trn1 v28.2d, v3.2d, v7.2d \n\t" // Column 6. +" trn1 v29.2d, v11.2d, v15.2d \n\t" +" trn2 v30.2d, v3.2d, v7.2d \n\t" // Column 7. +" trn2 v31.2d, v11.2d, v15.2d \n\t" +" ld1r {v14.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v15.2d}, [x8] \n\t" +DSCALE8V(16,17,18,19,20,21,22,23,14,0) +DSCALE8V(24,25,26,27,28,29,30,31,14,0) +DLOADC_4V_C_FWD(0,1,2,3,x1,0,x7) +DLOADC_4V_C_FWD(4,5,6,7,x1,0,x7) +DSCALEA8V(16,17,18,19,20,21,22,23,0,1,2,3,4,5,6,7,15,0) +// +DLOADC_4V_C_FWD(0,1,2,3,x1,0,x7) +DLOADC_4V_C_FWD(4,5,6,7,x1,0,x7) +DSCALEA8V(24,25,26,27,28,29,30,31,0,1,2,3,4,5,6,7,15,0) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_C) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_C) +#endif +// +DSTOREC_4V_C_FWD(16,17,18,19,x5,0,x7) +DSTOREC_4V_C_FWD(20,21,22,23,x5,0,x7) +DSTOREC_4V_C_FWD(24,25,26,27,x5,0,x7) +DSTOREC_4V_C_FWD(28,29,30,31,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #4 \n\t" +" madd x13, x6, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" add x10, x10, x11 \n\t" // Forward A's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a] "m" (ps_a), + [rs_b] "m" (rs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + // In Clang, even "m"-passed parameter takes 1 register. + // Have to disable prefetching to pass compilation. +#ifndef __clang__ + [a_next] "r" (a_next), + [b_next] "r" (b_next), +#endif + [m_iter] "m" (m_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + // TODO: Implement optimized kernel for this. + // + // Forward address. + a = a + m_iter * ps_a; + c = c + m_iter * 4 * rs_c; + if ( m_left ) + { + bli_dgemmsup_r_armv8a_ref2 + ( + conja, conjb, m_left, 8, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + } + +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d4x8n.c b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d4x8n.c new file mode 100644 index 0000000000..43913cd38d --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d4x8n.c @@ -0,0 +1,458 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +// Nanokernel operations. +#include "../armv8a_asm_d2x2.h" + +/* + * +---+ +---+ +---+ +---+ + * | 0 | | 2 | | 4 | | 6 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 1 | | 3 | | 5 | | 7 | + * +---+ +---+ +---+ +---+ + */ +#define DGEMM_4X8_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,C30,C31,C32,C33,A0,A1,B0,B1,B2,B3,BADDR,BSHIFT0,BSHIFT1,BSHIFT2,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C10,B0,A0) \ + DGEMM_2X2_NANOKERNEL(C20,C30,B0,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (B0,BADDR,BSHIFT0) \ + DGEMM_2X2_NANOKERNEL(C01,C11,B1,A0) \ + DGEMM_2X2_NANOKERNEL(C21,C31,B1,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (B1,BADDR,BSHIFT1) \ + DGEMM_2X2_NANOKERNEL(C02,C12,B2,A0) \ + DGEMM_2X2_NANOKERNEL(C22,C32,B2,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (B2,BADDR,BSHIFT2) \ + DGEMM_2X2_NANOKERNEL(C03,C13,B3,A0) \ + DGEMM_2X2_NANOKERNEL(C23,C33,B3,A1) + + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +// Prefetch C in the long direction. +#define DPRFMC_FWD(CADDR,DLONGC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#DLONGC" \n\t" + +#define DLOADC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +#define DLOADC_4V_C_FWD(C00,C10,C01,C11,CADDR,CSHIFT,CSC) \ + DLOAD2V(C00,C10,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" \ + DLOAD2V(C01,C11,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_4V_C_FWD(C00,C10,C01,C11,CADDR,CSHIFT,CSC) \ + DSTORE2V(C00,C10,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" \ + DSTORE2V(C01,C11,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + + +/* + * 4x8 dgemmsup kernel with extending 2nd dimension. + * + * Recommanded usage case: + * o 16 < (L1 cache latency) * (Num. FPU) < 25. + * o L1 cache has a bandwidth not too low (true in most cases). + * o (FMLA latency) * (Num. FPU) < 32 cycles (true in almost all cases). + */ +void bli_dgemmsup_rv_armv8a_asm_4x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Fixme: This uker has no dispatching for unalighed sizes. + // Currently it only serves as a dispatch target for other kernels + // and cannot be registered in configurations. + assert( m0 == 4 ); + + // LLVM has very bad routing ability for inline asm. + // Limit number of registers in case of Clang compilation. +#ifndef __clang__ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); +#endif + uint64_t ps_b = bli_auxinfo_ps_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + int64_t n_iter = n0 / 8; + int64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // uint64_t cs_b = cs_b0; + assert( cs_b0 == 1 ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[b] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[n_iter] \n\t" +" ldr x11, %[ps_b] \n\t" // Panel-skip of B. +" ldr x3, %[rs_b] \n\t" // Row-skip of B. +" ldr x9, %[rs_a] \n\t" // Row-skip of A. +" ldr x2, %[cs_a] \n\t" // Column-skip of A. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x11, x11, #3 \n\t" // ps_b +" lsl x9, x9, #3 \n\t" // rs_a +" lsl x2, x2, #3 \n\t" // cs_a +" lsl x3, x3, #3 \n\t" // rs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x1, x5 \n\t" +" cmp x7, #8 \n\t" // Prefetch column-strided C. +BEQ(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +BRANCH(C_PREFETCH_END) +LABEL(C_PREFETCH_COLS) +// This prefetch will not cover further mker perts. Skip. +// +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +LABEL(C_PREFETCH_END) +// +// Millikernel. +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x1, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x0, %[a] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:15] <- C +// V[16:23] <- A; Allowed latency: 48 cycles / # of FPUs. +// V[24:31] <- B; Allowed latency: 28 cycles / # of FPUs. +// Under this scheme, the following is defined: +#define DGEMM_4X8_MKER_LOOP_PLAIN_LOC(A0,A1,B0,B1,B2,B3,BADDR,BSHIFT0,BSHIFT1,BSHIFT2,LOADNEXT) \ + DGEMM_4X8_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,A0,A1,B0,B1,B2,B3,BADDR,BSHIFT0,BSHIFT1,BSHIFT2,LOADNEXT) +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" ldr q24, [x1, #16*0] \n\t" // Load B first. +" ldr q25, [x1, #16*1] \n\t" +" ldr q26, [x1, #16*2] \n\t" +" ldr q27, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" ldr q28, [x1, #16*0] \n\t" +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ld1 {v16.d}[0], [x14], x9 \n\t" // We want A to be kept in L1. +" ld1 {v16.d}[1], [x14], x9 \n\t" +" ld1 {v17.d}[0], [x14], x9 \n\t" +" ld1 {v17.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v18.d}[0], [x14], x9 \n\t" +" ld1 {v18.d}[1], [x14], x9 \n\t" +" ld1 {v19.d}[0], [x14], x9 \n\t" +" ld1 {v19.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v20.d}[0], [x14], x9 \n\t" +" ld1 {v20.d}[1], [x14], x9 \n\t" +" ld1 {v21.d}[0], [x14], x9 \n\t" +" ld1 {v21.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v22.d}[0], [x14], x9 \n\t" +" ld1 {v22.d}[1], [x14], x9 \n\t" +" ld1 {v23.d}[0], [x14], x9 \n\t" +" ld1 {v23.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,B0,B1,B2,B3) \ + DGEMM_4X8_MKER_LOOP_PLAIN_LOC(A0,A1,B0,B1,B2,B3,x1,0,16*1,16*2,load) \ + "ldr q"#B3", [x1, #16*3] \n\t" \ + "mov x14, x0 \n\t" \ + "ld1 {v"#A0".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A0".d}[1], [x14], x9 \n\t" \ + "ld1 {v"#A1".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A1".d}[1], [x14], x9 \n\t" \ + "add x0, x0, x2 \n\t" \ + "add x1, x1, x3 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(16,17,24,25,26,27) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(18,19,28,29,30,31) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(20,21,24,25,26,27) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(22,23,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(16,17,24,25,26,27,x1,0,16*1,16*2,load) +" ldr q27, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(18,19,28,29,30,31,x1,0,16*1,16*2,load) +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(20,21,24,25,26,27,xzr,-1,-1,-1,noload) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(22,23,28,29,30,31,xzr,-1,-1,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" ldr q24, [x1, #16*0] \n\t" // Load B row. +" ldr q25, [x1, #16*1] \n\t" +" ldr q26, [x1, #16*2] \n\t" +" ldr q27, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" mov x14, x0 \n\t" // Load A col. +" ld1 {v16.d}[0], [x14], x9 \n\t" +" ld1 {v16.d}[1], [x14], x9 \n\t" +" ld1 {v17.d}[0], [x14], x9 \n\t" +" ld1 {v17.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(16,17,24,25,26,27,xzr,-1,-1,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" ld1r {v16.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v17.2d}, [x8] \n\t" +" fcmp d17, #0.0 \n\t" +DSCALE8V(0,1,2,3,4,5,6,7,16,0) +DSCALE8V(8,9,10,11,12,13,14,15,16,0) +BEQ(ZERO_BETA_R) +DLOADC_4V_R_FWD(20,21,22,23,x1,0,x6) +DLOADC_4V_R_FWD(24,25,26,27,x1,0,x6) +DSCALEA8V(0,1,2,3,4,5,6,7,20,21,22,23,24,25,26,27,17,0) +// +DLOADC_4V_R_FWD(20,21,22,23,x1,0,x6) +DLOADC_4V_R_FWD(24,25,26,27,x1,0,x6) +DSCALEA8V(8,9,10,11,12,13,14,15,20,21,22,23,24,25,26,27,17,0) +LABEL(ZERO_BETA_R) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_R) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_R) +#endif +// +DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) +DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) +DSTOREC_4V_R_FWD(12,13,14,15,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +// In-register transpose. +" trn1 v16.2d, v0.2d, v4.2d \n\t" // Column 0. +" trn1 v17.2d, v8.2d, v12.2d \n\t" +" trn2 v18.2d, v0.2d, v4.2d \n\t" // Column 1. +" trn2 v19.2d, v8.2d, v12.2d \n\t" +" trn1 v20.2d, v1.2d, v5.2d \n\t" // Column 2. +" trn1 v21.2d, v9.2d, v13.2d \n\t" +" trn2 v22.2d, v1.2d, v5.2d \n\t" // Column 3. +" trn2 v23.2d, v9.2d, v13.2d \n\t" +" trn1 v24.2d, v2.2d, v6.2d \n\t" // Column 4. +" trn1 v25.2d, v10.2d, v14.2d \n\t" +" trn2 v26.2d, v2.2d, v6.2d \n\t" // Column 5. +" trn2 v27.2d, v10.2d, v14.2d \n\t" +" trn1 v28.2d, v3.2d, v7.2d \n\t" // Column 6. +" trn1 v29.2d, v11.2d, v15.2d \n\t" +" trn2 v30.2d, v3.2d, v7.2d \n\t" // Column 7. +" trn2 v31.2d, v11.2d, v15.2d \n\t" +" ld1r {v14.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v15.2d}, [x8] \n\t" +DSCALE8V(16,17,18,19,20,21,22,23,14,0) +DSCALE8V(24,25,26,27,28,29,30,31,14,0) +DLOADC_4V_C_FWD(0,1,2,3,x1,0,x7) +DLOADC_4V_C_FWD(4,5,6,7,x1,0,x7) +DSCALEA8V(16,17,18,19,20,21,22,23,0,1,2,3,4,5,6,7,15,0) +// +DLOADC_4V_C_FWD(0,1,2,3,x1,0,x7) +DLOADC_4V_C_FWD(4,5,6,7,x1,0,x7) +DSCALEA8V(24,25,26,27,28,29,30,31,0,1,2,3,4,5,6,7,15,0) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_C) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_C) +#endif +// +DSTOREC_4V_C_FWD(16,17,18,19,x5,0,x7) +DSTOREC_4V_C_FWD(20,21,22,23,x5,0,x7) +DSTOREC_4V_C_FWD(24,25,26,27,x5,0,x7) +DSTOREC_4V_C_FWD(28,29,30,31,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #8 \n\t" +" madd x13, x7, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" add x10, x10, x11 \n\t" // Forward B's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_b] "m" (ps_b), + [rs_b] "m" (rs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + // In Clang, even "m"-passed parameter takes 1 register. + // Have to disable prefetching to pass compilation. +#ifndef __clang__ + [a_next] "r" (a_next), + [b_next] "r" (b_next), +#endif + [n_iter] "m" (n_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + // TODO: Implement optimized kernel for this. + // + // Forward address. + b = b + n_iter * ps_b; + c = c + n_iter * 8 * cs_c; + if ( n_left ) + { + auxinfo_t data_d6x4mn = *data; + bli_auxinfo_set_ps_b( 4 * cs_b0, &data_d6x4mn ); + + bli_dgemmsup_rv_armv8a_int_6x4mn + ( + conja, conjb, 4, n_left, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, &data_d6x4mn, cntx + ); + } + +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d6x8m.c b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d6x8m.c new file mode 100644 index 0000000000..3100112d3f --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d6x8m.c @@ -0,0 +1,575 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +// Nanokernel operations. +#include "../armv8a_asm_d2x2.h" + +/* Order of row-major DGEMM_6x8's execution in 2x2 blocks: + * + * +---+ +---+ +---+ +---+ + * | 0 | | 1 | | 6 | | 7 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 2 | | 3 | | 8 | | 9 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 4 | | 5 | | 10| | 11| + * +---+ +---+ +---+ +---+ + * + */ +#define DGEMM_6X8_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,C30,C31,C32,C33,C40,C41,C42,C43,C50,C51,C52,C53,A0,A1,A2,B0,B1,B2,B3,AELEMADDR,AELEMST,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C10,B0,A0) \ + DGEMM_2X2_NANOKERNEL(C01,C11,B1,A0) \ + DGEMM_2X2_NANOKERNEL(C20,C30,B0,A1) \ + DGEMM_2X2_NANOKERNEL(C21,C31,B1,A1) \ + DGEMM_2X2_NANOKERNEL(C40,C50,B0,A2) \ + DGEMM_2X2_NANOKERNEL(C41,C51,B1,A2) \ + DGEMM_LOAD2V_ ##LOADNEXT (B0,B1,BADDR,BSHIFT) \ + DGEMM_2X2_NANOKERNEL(C02,C12,B2,A0) \ + DGEMM_2X2_NANOKERNEL(C03,C13,B3,A0) \ + DGEMM_LOAD1V_G_ ##LOADNEXT (A0,AELEMADDR,AELEMST) \ + DGEMM_2X2_NANOKERNEL(C22,C32,B2,A1) \ + DGEMM_2X2_NANOKERNEL(C23,C33,B3,A1) \ + DGEMM_LOAD1V_G_ ##LOADNEXT (A1,AELEMADDR,AELEMST) \ + DGEMM_2X2_NANOKERNEL(C42,C52,B2,A2) \ + DGEMM_2X2_NANOKERNEL(C43,C53,B3,A2) + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +#define DGEMM_LOAD2V_noload(V1,V2,ADDR,IMM) +#define DGEMM_LOAD2V_load(V1,V2,ADDR,IMM) \ + DGEMM_LOAD1V_load(V1,ADDR,IMM) \ + DGEMM_LOAD1V_load(V2,ADDR,IMM+16) + +#define DGEMM_LOAD1V_G_noload(V1,ADDR,ST) +#define DGEMM_LOAD1V_G_load(V1,ADDR,ST) \ +" ld1 {v"#V1".d}[0], ["#ADDR"], "#ST" \n\t" \ +" ld1 {v"#V1".d}[1], ["#ADDR"], "#ST" \n\t" + +// Prefetch C in the long direction. +#define DPRFMC_FWD(CADDR,DLONGC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#DLONGC" \n\t" + +// For row-storage of C. +#define DLOADC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For column-storage of C. +#define DLOADC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ + DLOAD1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ + DSTORE1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + +#define DSCALE6V(V0,V1,V2,V3,V4,V5,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE2V(V4,V5,A,IDX) +#define DSCALEA6V(D0,D1,D2,D3,D4,D5,S0,S1,S2,S3,S4,S5,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA2V(D4,D5,S4,S5,A,IDX) + + +/* + * 6x8 dgemmsup kernel with extending 1st dimension. + * + * Recommanded usage case: (L1 cache latency) * (Num. FPU) < 17 cycles. + * + * Calls 4x8 for edge cases. + */ +void bli_dgemmsup_rv_armv8a_asm_6x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + if ( n0 != 8 ) + { + if ( n0 < 8 ) + { + for ( ; n0 >= 4; n0 -= 4 ) + { + dgemmsup_ker_ft ukr_fp; + auxinfo_t data_d8xkm = *data; + if ( bli_auxinfo_ps_a( data ) == 6 * rs_a0 ) + { + // Use 8x4 Asm kernel for the unpacked case. + bli_auxinfo_set_ps_a( 8 * rs_a0, &data_d8xkm ); + ukr_fp = bli_dgemmsup_rv_armv8a_asm_8x4m; + } + else + { + // Cannot change dimension for m when A is packed. + ukr_fp = bli_dgemmsup_rv_armv8a_int_6x4mn; + } + + ukr_fp + ( + conja, conjb, m0, 4, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, &data_d8xkm, cntx + ); + b += 4 * cs_b0; + c += 4 * cs_c0; + } + if ( n0 > 0 ) + { + bli_dgemmsup_rv_armv8a_int_6x4mn + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + } + } + else + { + assert( FALSE ); + } + return; + } + + // LLVM has very bad routing ability for inline asm. + // Limit number of registers in case of Clang compilation. +#ifndef __clang__ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); +#endif + uint64_t ps_a = bli_auxinfo_ps_a( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + int64_t m_iter = m0 / 6; + int64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // uint64_t cs_b = cs_b0; + assert( cs_b0 == 1 ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[a] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[m_iter] \n\t" +" ldr x11, %[ps_a] \n\t" // Panel-skip of A. +" ldr x9, %[rs_a] \n\t" // Row-skip of A. +" ldr x2, %[cs_a] \n\t" // Column-skip of A. +" ldr x3, %[rs_b] \n\t" // Row-skip of B. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x11, x11, #3 \n\t" // ps_a +" lsl x9, x9, #3 \n\t" // rs_a +" lsl x2, x2, #3 \n\t" // cs_a +" lsl x3, x3, #3 \n\t" // rs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x1, x5 \n\t" +" cmp x7, #8 \n\t" // Prefetch column-strided C. +BEQ(C_PREFETCH_COLS) +// This prefetch will not cover further mker perts. Skip. +// +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +BRANCH(C_PREFETCH_END) +LABEL(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +LABEL(C_PREFETCH_END) +// +// Millikernel. +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x0, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x1, %[b] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:23] <- C +// V[24:27] <- A +// V[28:31] <- B +// Under this scheme, the following is defined: +#define DGEMM_6X8_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,AELEMADDR,AELEMST,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_6X8_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,A0,A1,A2,B0,B1,B2,B3,AELEMADDR,AELEMST,BADDR,BSHIFT,LOADNEXT) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ld1 {v24.d}[0], [x14], x9 \n\t" +" ld1 {v24.d}[1], [x14], x9 \n\t" +" ld1 {v25.d}[0], [x14], x9 \n\t" +" ld1 {v25.d}[1], [x14], x9 \n\t" +" ld1 {v26.d}[0], [x14], x9 \n\t" +" ld1 {v26.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v27.d}[0], [x14], x9 \n\t" +" ld1 {v27.d}[1], [x14], x9 \n\t" +" \n\t" +" ldr q28, [x1, #16*0] \n\t" // Load B. +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +CLEAR8V(16,17,18,19,20,21,22,23) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_6X8_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,x14,x9,x1,0,load) \ + "add x0, x0, x2 \n\t" \ + "mov x14, x0 \n\t" \ + "ld1 {v"#A2".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A2".d}[1], [x14], x9 \n\t" \ + "ldr q"#B2", [x1, #16*2] \n\t" \ + "ldr q"#B3", [x1, #16*3] \n\t" \ + "add x1, x1, x3 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,28,29,30,31) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(27,24,25,28,29,30,31) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(26,27,24,28,29,30,31) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(25,26,27,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(26,27,24,28,29,30,31,x14,x9,x1,0,load) +" add x0, x0, x2 \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(25,26,27,28,29,30,31,xzr,-1,xzr,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x14, x0 \n\t" +" ld1 {v24.d}[0], [x14], x9 \n\t" // Load A col. +" ld1 {v24.d}[1], [x14], x9 \n\t" +" ld1 {v25.d}[0], [x14], x9 \n\t" +" ld1 {v25.d}[1], [x14], x9 \n\t" +" ld1 {v26.d}[0], [x14], x9 \n\t" +" ld1 {v26.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" ldr q28, [x1, #16*0] \n\t" // Load B row. +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(24,25,26,28,29,30,31,xzr,-1,xzr,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" ld1r {v24.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v25.2d}, [x8] \n\t" +" fmov d26, #1.0 \n\t" +" fcmp d24, d26 \n\t" +BEQ(UNIT_ALPHA_R) +DSCALE8V(0,1,2,3,4,5,6,7,24,0) +DSCALE8V(8,9,10,11,12,13,14,15,24,0) +DSCALE8V(16,17,18,19,20,21,22,23,24,0) +LABEL(UNIT_ALPHA_R) +" fcmp d25, #0.0 \n\t" +BEQ(ZERO_BETA_R_1) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DSCALEA4V(0,1,2,3,26,27,28,29,25,0) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DSCALEA4V(4,5,6,7,26,27,28,29,25,0) +LABEL(ZERO_BETA_R_1) +DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +BEQ(ZERO_BETA_R_2) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DLOADC_4V_R_FWD(0,1,2,3,x1,0,x6) +DSCALEA8V(8,9,10,11,12,13,14,15,26,27,28,29,0,1,2,3,25,0) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DLOADC_4V_R_FWD(0,1,2,3,x1,0,x6) +DSCALEA8V(16,17,18,19,20,21,22,23,26,27,28,29,0,1,2,3,25,0) +LABEL(ZERO_BETA_R_2) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_R) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_R) +#endif +DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) +DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) +DSTOREC_4V_R_FWD(12,13,14,15,x5,0,x6) +DSTOREC_4V_R_FWD(16,17,18,19,x5,0,x6) +DSTOREC_4V_R_FWD(20,21,22,23,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +// In-register transpose, +// do transposition in row-order. +" trn1 v24.2d, v0.2d, v4.2d \n\t" // Row 0-1. +" trn2 v25.2d, v0.2d, v4.2d \n\t" +" trn1 v26.2d, v1.2d, v5.2d \n\t" +" trn2 v27.2d, v1.2d, v5.2d \n\t" +" trn1 v28.2d, v2.2d, v6.2d \n\t" +" trn2 v29.2d, v2.2d, v6.2d \n\t" +" trn1 v30.2d, v3.2d, v7.2d \n\t" +" trn2 v31.2d, v3.2d, v7.2d \n\t" +" \n\t" +" trn1 v0.2d, v8.2d, v12.2d \n\t" // Row 2-3. +" trn2 v1.2d, v8.2d, v12.2d \n\t" +" trn1 v2.2d, v9.2d, v13.2d \n\t" +" trn2 v3.2d, v9.2d, v13.2d \n\t" +" trn1 v4.2d, v10.2d, v14.2d \n\t" +" trn2 v5.2d, v10.2d, v14.2d \n\t" +" trn1 v6.2d, v11.2d, v15.2d \n\t" +" trn2 v7.2d, v11.2d, v15.2d \n\t" +" \n\t" +" trn1 v8.2d, v16.2d, v20.2d \n\t" // Row 4-5. +" trn2 v9.2d, v16.2d, v20.2d \n\t" +" trn1 v10.2d, v17.2d, v21.2d \n\t" // AMARI +" trn2 v11.2d, v17.2d, v21.2d \n\t" // AMARI +" trn1 v12.2d, v18.2d, v22.2d \n\t" // AMARI +" trn2 v13.2d, v18.2d, v22.2d \n\t" // AMARI +" trn1 v14.2d, v19.2d, v23.2d \n\t" // AMARI +" trn2 v15.2d, v19.2d, v23.2d \n\t" // AMARI +" \n\t" +" ld1r {v16.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v17.2d}, [x8] \n\t" +" fmov d18, #1.0 \n\t" +" fcmp d16, d18 \n\t" +BEQ(UNIT_ALPHA_C) +DSCALE8V(24,25,26,27,28,29,30,31,16,0) +DSCALE8V(0,1,2,3,4,5,6,7,16,0) +DSCALE8V(8,9,10,11,12,13,14,15,16,0) +LABEL(UNIT_ALPHA_C) +" fcmp d17, #0.0 \n\t" +BEQ(ZERO_BETA_C_1) +DLOADC_3V_C_FWD(18,19,20,x1,0,x7) +DLOADC_3V_C_FWD(21,22,23,x1,0,x7) +DSCALEA6V(24,0,8,25,1,9,18,19,20,21,22,23,17,0) +LABEL(ZERO_BETA_C_1) +DSTOREC_3V_C_FWD(24,0,8,x5,0,x7) +DSTOREC_3V_C_FWD(25,1,9,x5,0,x7) +BEQ(ZERO_BETA_C_2) +DLOADC_3V_C_FWD(18,19,20,x1,0,x7) +DLOADC_3V_C_FWD(21,22,23,x1,0,x7) +DLOADC_3V_C_FWD(24,0,8,x1,0,x7) +DLOADC_3V_C_FWD(25,1,9,x1,0,x7) +DSCALEA6V(26,2,10,27,3,11,18,19,20,21,22,23,17,0) +DSCALEA6V(28,4,12,29,5,13,24,0,8,25,1,9,17,0) +LABEL(ZERO_BETA_C_2) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_C) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_C) +" fcmp d17, #0.0 \n\t" // Not the end. Reset branching reg. +#endif +DSTOREC_3V_C_FWD(26,2,10,x5,0,x7) +DSTOREC_3V_C_FWD(27,3,11,x5,0,x7) +BEQ(ZERO_BETA_C_3) +DLOADC_3V_C_FWD(18,19,20,x1,0,x7) +DLOADC_3V_C_FWD(21,22,23,x1,0,x7) +DSCALEA6V(30,6,14,31,7,15,18,19,20,21,22,23,17,0) +LABEL(ZERO_BETA_C_3) +DSTOREC_3V_C_FWD(28,4,12,x5,0,x7) +DSTOREC_3V_C_FWD(29,5,13,x5,0,x7) +DSTOREC_3V_C_FWD(30,6,14,x5,0,x7) +DSTOREC_3V_C_FWD(31,7,15,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #6 \n\t" +" madd x13, x6, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" add x10, x10, x11 \n\t" // Forward A's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a] "m" (ps_a), + [rs_b] "m" (rs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + // In Clang, even "m"-passed parameter takes 1 register. + // Have to disable prefetching to pass compilation. +#ifndef __clang__ + [a_next] "r" (a_next), + [b_next] "r" (b_next), +#endif + [m_iter] "m" (m_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + // Forward address. + a = a + m_iter * ps_a; + c = c + m_iter * 6 * rs_c; +#if 1 + auxinfo_t data_d6x4mn = *data; + bli_auxinfo_set_ps_b( 4 * cs_b0, &data_d6x4mn ); + bli_dgemmsup_rv_armv8a_int_6x4mn + ( + conja, conjb, m_left, 8, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, &data_d6x4mn, cntx + ); +#else + if ( m_left >= 4 ) + { + // Calls 4x8m with only 1 outermost loop. + // As only 1 outermost loop is called, + // ps_a needs not being set here. + // + bli_dgemmsup_rv_armv8a_asm_4x8m + ( + conja, conjb, 4, 8, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + m_left -= 4; + a = a + 4 * rs_a; + c = c + 4 * rs_c; + } + if ( m_left ) + { + bli_dgemmsup_r_armv8a_ref2 + ( + conja, conjb, m_left, 8, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + } +#endif + +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d6x8n.c b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d6x8n.c new file mode 100644 index 0000000000..fb9357c11e --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d6x8n.c @@ -0,0 +1,539 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +// Nanokernel operations. +#include "../armv8a_asm_d2x2.h" + +/* Order of row-major DGEMM_6x8's execution in 2x2 blocks: + * + * +---+ +---+ +---+ +---+ + * | 0 | | 1 | | 6 | | 7 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 2 | | 3 | | 8 | | 9 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 4 | | 5 | | 10| | 11| + * +---+ +---+ +---+ +---+ + * + */ +#define DGEMM_6X8_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,C30,C31,C32,C33,C40,C41,C42,C43,C50,C51,C52,C53,A0,A1,A2,B0,B1,B2,B3,AELEMADDR,AELEMST,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C10,B0,A0) \ + DGEMM_2X2_NANOKERNEL(C01,C11,B1,A0) \ + DGEMM_2X2_NANOKERNEL(C20,C30,B0,A1) \ + DGEMM_2X2_NANOKERNEL(C21,C31,B1,A1) \ + DGEMM_2X2_NANOKERNEL(C40,C50,B0,A2) \ + DGEMM_2X2_NANOKERNEL(C41,C51,B1,A2) \ + DGEMM_LOAD2V_ ##LOADNEXT (B0,B1,BADDR,BSHIFT) \ + DGEMM_2X2_NANOKERNEL(C02,C12,B2,A0) \ + DGEMM_2X2_NANOKERNEL(C03,C13,B3,A0) \ + DGEMM_LOAD1V_G_ ##LOADNEXT (A0,AELEMADDR,AELEMST) \ + DGEMM_2X2_NANOKERNEL(C22,C32,B2,A1) \ + DGEMM_2X2_NANOKERNEL(C23,C33,B3,A1) \ + DGEMM_LOAD1V_G_ ##LOADNEXT (A1,AELEMADDR,AELEMST) \ + DGEMM_2X2_NANOKERNEL(C42,C52,B2,A2) \ + DGEMM_2X2_NANOKERNEL(C43,C53,B3,A2) + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +#define DGEMM_LOAD2V_noload(V1,V2,ADDR,IMM) +#define DGEMM_LOAD2V_load(V1,V2,ADDR,IMM) \ + DGEMM_LOAD1V_load(V1,ADDR,IMM) \ + DGEMM_LOAD1V_load(V2,ADDR,IMM+16) + +#define DGEMM_LOAD1V_G_noload(V1,ADDR,ST) +#define DGEMM_LOAD1V_G_load(V1,ADDR,ST) \ +" ld1 {v"#V1".d}[0], ["#ADDR"], "#ST" \n\t" \ +" ld1 {v"#V1".d}[1], ["#ADDR"], "#ST" \n\t" + +// Prefetch C in the long direction. +#define DPRFMC_FWD(CADDR,DLONGC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#DLONGC" \n\t" + +// For row-storage of C. +#define DLOADC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For column-storage of C. +#define DLOADC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ + DLOAD1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ + DSTORE1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + +#define DSCALE6V(V0,V1,V2,V3,V4,V5,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE2V(V4,V5,A,IDX) +#define DSCALEA6V(D0,D1,D2,D3,D4,D5,S0,S1,S2,S3,S4,S5,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA2V(D4,D5,S4,S5,A,IDX) + + +/* + * 6x8 dgemmsup kernel with extending 2nd dimension. + * + * Recommanded usage case: (L1 cache latency) * (Num. FPU) < 17 cycles. + * + * Calls 4x8n for edge cases. + */ +void bli_dgemmsup_rv_armv8a_asm_6x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + if ( m0 != 6 ) + { + // 5 = 4 + 1; + // 4; + // + while ( m0 >= 4 ) + { + bli_dgemmsup_rv_armv8a_asm_4x8n + ( + conja, conjb, 4, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + m0 -= 4; + a += 4 * rs_a0; + c += 4 * rs_c0; + } + + // 3, 2, 1; + // + if ( m0 > 0 ) + { + bli_dgemmsup_rv_armv8a_int_3x8mn + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + } + return; + } + + // LLVM has very bad routing ability for inline asm. + // Limit number of registers in case of Clang compilation. +#ifndef __clang__ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); +#endif + uint64_t ps_b = bli_auxinfo_ps_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + int64_t n_iter = n0 / 8; + int64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // uint64_t cs_b = cs_b0; + assert( cs_b0 == 1 ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[b] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[n_iter] \n\t" +" ldr x11, %[ps_b] \n\t" // Panel-skip of B. +" ldr x3, %[rs_b] \n\t" // Row-skip of B. +" ldr x9, %[rs_a] \n\t" // Row-skip of A. +" ldr x2, %[cs_a] \n\t" // Column-skip of A. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x11, x11, #3 \n\t" // ps_b +" lsl x9, x9, #3 \n\t" // rs_a +" lsl x2, x2, #3 \n\t" // cs_a +" lsl x3, x3, #3 \n\t" // rs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x1, x5 \n\t" +" cmp x7, #8 \n\t" // Prefetch column-strided C. +BEQ(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +BRANCH(C_PREFETCH_END) +LABEL(C_PREFETCH_COLS) +// This prefetch will not cover further mker perts. Skip. +// +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +LABEL(C_PREFETCH_END) +// +// Millikernel. +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x1, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x0, %[a] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:23] <- C +// V[24:27] <- A +// V[28:31] <- B +// Under this scheme, the following is defined: +#define DGEMM_6X8_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,AELEMADDR,AELEMST,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_6X8_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,A0,A1,A2,B0,B1,B2,B3,AELEMADDR,AELEMST,BADDR,BSHIFT,LOADNEXT) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" ldr q28, [x1, #16*0] \n\t" // Load B first. +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ld1 {v24.d}[0], [x14], x9 \n\t" // We want A to be kept in L1. +" ld1 {v24.d}[1], [x14], x9 \n\t" +" ld1 {v25.d}[0], [x14], x9 \n\t" +" ld1 {v25.d}[1], [x14], x9 \n\t" +" ld1 {v26.d}[0], [x14], x9 \n\t" +" ld1 {v26.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v27.d}[0], [x14], x9 \n\t" +" ld1 {v27.d}[1], [x14], x9 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +CLEAR8V(16,17,18,19,20,21,22,23) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_6X8_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,x14,x9,x1,0,load) \ + "add x0, x0, x2 \n\t" \ + "mov x14, x0 \n\t" \ + "ld1 {v"#A2".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A2".d}[1], [x14], x9 \n\t" \ + "ldr q"#B2", [x1, #16*2] \n\t" \ + "ldr q"#B3", [x1, #16*3] \n\t" \ + "add x1, x1, x3 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,28,29,30,31) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(27,24,25,28,29,30,31) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(26,27,24,28,29,30,31) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(25,26,27,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(26,27,24,28,29,30,31,x14,x9,x1,0,load) +" add x0, x0, x2 \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(25,26,27,28,29,30,31,xzr,-1,xzr,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" ldr q28, [x1, #16*0] \n\t" // Load B row. +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" mov x14, x0 \n\t" +" ld1 {v24.d}[0], [x14], x9 \n\t" // Load A col. +" ld1 {v24.d}[1], [x14], x9 \n\t" +" ld1 {v25.d}[0], [x14], x9 \n\t" +" ld1 {v25.d}[1], [x14], x9 \n\t" +" ld1 {v26.d}[0], [x14], x9 \n\t" +" ld1 {v26.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(24,25,26,28,29,30,31,xzr,-1,xzr,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" ld1r {v24.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v25.2d}, [x8] \n\t" +" fmov d26, #1.0 \n\t" +" fcmp d24, d26 \n\t" +BEQ(UNIT_ALPHA_R) +DSCALE8V(0,1,2,3,4,5,6,7,24,0) +DSCALE8V(8,9,10,11,12,13,14,15,24,0) +DSCALE8V(16,17,18,19,20,21,22,23,24,0) +LABEL(UNIT_ALPHA_R) +" fcmp d25, #0.0 \n\t" +BEQ(ZERO_BETA_R_1) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DSCALEA4V(0,1,2,3,26,27,28,29,25,0) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DSCALEA4V(4,5,6,7,26,27,28,29,25,0) +LABEL(ZERO_BETA_R_1) +DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +BEQ(ZERO_BETA_R_2) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DLOADC_4V_R_FWD(0,1,2,3,x1,0,x6) +DSCALEA8V(8,9,10,11,12,13,14,15,26,27,28,29,0,1,2,3,25,0) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DLOADC_4V_R_FWD(0,1,2,3,x1,0,x6) +DSCALEA8V(16,17,18,19,20,21,22,23,26,27,28,29,0,1,2,3,25,0) +LABEL(ZERO_BETA_R_2) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_R) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_R) +#endif +DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) +DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) +DSTOREC_4V_R_FWD(12,13,14,15,x5,0,x6) +DSTOREC_4V_R_FWD(16,17,18,19,x5,0,x6) +DSTOREC_4V_R_FWD(20,21,22,23,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +// In-register transpose, +// do transposition in row-order. +" trn1 v24.2d, v0.2d, v4.2d \n\t" // Row 0-1. +" trn2 v25.2d, v0.2d, v4.2d \n\t" +" trn1 v26.2d, v1.2d, v5.2d \n\t" +" trn2 v27.2d, v1.2d, v5.2d \n\t" +" trn1 v28.2d, v2.2d, v6.2d \n\t" +" trn2 v29.2d, v2.2d, v6.2d \n\t" +" trn1 v30.2d, v3.2d, v7.2d \n\t" +" trn2 v31.2d, v3.2d, v7.2d \n\t" +" \n\t" +" trn1 v0.2d, v8.2d, v12.2d \n\t" // Row 2-3. +" trn2 v1.2d, v8.2d, v12.2d \n\t" +" trn1 v2.2d, v9.2d, v13.2d \n\t" +" trn2 v3.2d, v9.2d, v13.2d \n\t" +" trn1 v4.2d, v10.2d, v14.2d \n\t" +" trn2 v5.2d, v10.2d, v14.2d \n\t" +" trn1 v6.2d, v11.2d, v15.2d \n\t" +" trn2 v7.2d, v11.2d, v15.2d \n\t" +" \n\t" +" trn1 v8.2d, v16.2d, v20.2d \n\t" // Row 4-5. +" trn2 v9.2d, v16.2d, v20.2d \n\t" +" trn1 v10.2d, v17.2d, v21.2d \n\t" // AMARI +" trn2 v11.2d, v17.2d, v21.2d \n\t" // AMARI +" trn1 v12.2d, v18.2d, v22.2d \n\t" // AMARI +" trn2 v13.2d, v18.2d, v22.2d \n\t" // AMARI +" trn1 v14.2d, v19.2d, v23.2d \n\t" // AMARI +" trn2 v15.2d, v19.2d, v23.2d \n\t" // AMARI +" \n\t" +" ld1r {v16.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v17.2d}, [x8] \n\t" +" fmov d18, #1.0 \n\t" +" fcmp d16, d18 \n\t" +BEQ(UNIT_ALPHA_C) +DSCALE8V(24,25,26,27,28,29,30,31,16,0) +DSCALE8V(0,1,2,3,4,5,6,7,16,0) +DSCALE8V(8,9,10,11,12,13,14,15,16,0) +LABEL(UNIT_ALPHA_C) +" fcmp d17, #0.0 \n\t" +BEQ(ZERO_BETA_C_1) +DLOADC_3V_C_FWD(18,19,20,x1,0,x7) +DLOADC_3V_C_FWD(21,22,23,x1,0,x7) +DSCALEA6V(24,0,8,25,1,9,18,19,20,21,22,23,17,0) +LABEL(ZERO_BETA_C_1) +DSTOREC_3V_C_FWD(24,0,8,x5,0,x7) +DSTOREC_3V_C_FWD(25,1,9,x5,0,x7) +BEQ(ZERO_BETA_C_2) +DLOADC_3V_C_FWD(18,19,20,x1,0,x7) +DLOADC_3V_C_FWD(21,22,23,x1,0,x7) +DLOADC_3V_C_FWD(24,0,8,x1,0,x7) +DLOADC_3V_C_FWD(25,1,9,x1,0,x7) +DSCALEA6V(26,2,10,27,3,11,18,19,20,21,22,23,17,0) +DSCALEA6V(28,4,12,29,5,13,24,0,8,25,1,9,17,0) +LABEL(ZERO_BETA_C_2) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_C) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_C) +" fcmp d17, #0.0 \n\t" // Not the end. Reset branching reg. +#endif +DSTOREC_3V_C_FWD(26,2,10,x5,0,x7) +DSTOREC_3V_C_FWD(27,3,11,x5,0,x7) +BEQ(ZERO_BETA_C_3) +DLOADC_3V_C_FWD(18,19,20,x1,0,x7) +DLOADC_3V_C_FWD(21,22,23,x1,0,x7) +DSCALEA6V(30,6,14,31,7,15,18,19,20,21,22,23,17,0) +LABEL(ZERO_BETA_C_3) +DSTOREC_3V_C_FWD(28,4,12,x5,0,x7) +DSTOREC_3V_C_FWD(29,5,13,x5,0,x7) +DSTOREC_3V_C_FWD(30,6,14,x5,0,x7) +DSTOREC_3V_C_FWD(31,7,15,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #8 \n\t" +" madd x13, x7, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" add x10, x10, x11 \n\t" // Forward B's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_b] "m" (ps_b), + [rs_b] "m" (rs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + // In Clang, even "m"-passed parameter takes 1 register. + // Have to disable prefetching to pass compilation. +#ifndef __clang__ + [a_next] "r" (a_next), + [b_next] "r" (b_next), +#endif + [n_iter] "m" (n_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + // Forward address. + b = b + n_iter * ps_b; + c = c + n_iter * 8 * cs_c; + if ( n_left ) + { + // Set panel stride to unpacked mode. + // Only 1 millikernel w.r.t. 6x8 is executed. + auxinfo_t data_d6x4mn = *data; + bli_auxinfo_set_ps_b( 4 * cs_b0, &data_d6x4mn ); + // + bli_dgemmsup_rv_armv8a_int_6x4mn + ( + conja, conjb, 6, n_left, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, &data_d6x4mn, cntx + ); + } + +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d8x4m.c b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d8x4m.c new file mode 100644 index 0000000000..5b0e9b062f --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d8x4m.c @@ -0,0 +1,431 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +// Nanokernel operations. +#include "../armv8a_asm_d2x2.h" + +/* + * +---+ +---+ + * | 0 | | 4 | + * +---+ +---+ + * +---+ +---+ + * | 1 | | 5 | + * +---+ +---+ + * +---+ +---+ + * | 2 | | 6 | + * +---+ +---+ + * +---+ +---+ + * | 3 | | 7 | + * +---+ +---+ + * + */ +#define DGEMM_8X4_MKER_LOOP_PLAIN(C00,C10,C20,C30,C01,C11,C21,C31,C02,C12,C22,C32,C03,C13,C23,C33,A0,A1,A2,A3,B0,B1,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C01,A0,B0) \ + DGEMM_2X2_NANOKERNEL(C10,C11,A1,B0) \ + DGEMM_2X2_NANOKERNEL(C20,C21,A2,B0) \ + DGEMM_2X2_NANOKERNEL(C30,C31,A3,B0) \ + DGEMM_LOAD1V_ ##LOADNEXT (B0,BADDR,BSHIFT) \ + DGEMM_2X2_NANOKERNEL(C02,C03,A0,B1) \ + DGEMM_2X2_NANOKERNEL(C12,C13,A1,B1) \ + DGEMM_2X2_NANOKERNEL(C22,C23,A2,B1) \ + DGEMM_2X2_NANOKERNEL(C32,C33,A3,B1) + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +#define DLOADC_4V_C_FWD(C0,C1,C2,C3,CADDR,CSHIFT,LDC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#LDC" \n\t" +#define DSTOREC_4V_C_FWD(C0,C1,C2,C3,CADDR,CSHIFT,LDC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#LDC" \n\t" + +#define DLOADC_4V_R_FWD(C00,C01,C10,C11,CADDR,CSHIFT,RSC) \ + DLOAD2V(C00,C01,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" \ + DLOAD2V(C10,C11,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C00,C01,C10,C11,CADDR,CSHIFT,RSC) \ + DSTORE2V(C00,C01,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" \ + DSTORE2V(C10,C11,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +/* + * 8x4 kernel for dgemmsup. + * + * R-dimension too short. + * Not recommanded for use. + */ +void bli_dgemmsup_rv_armv8a_asm_8x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Fixme: This uker has no dispatching for unalighed sizes. + // Currently it only serves as a dispatch target for other kernels + // and cannot be registered in configurations. + assert( n0 == 4 ); + + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + uint64_t ps_a = bli_auxinfo_ps_a( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 6; + uint64_t k_left = k0 % 6; + + int64_t m_iter = m0 / 8; + int64_t m_left = m0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // uint64_t cs_b = cs_b0; + assert( cs_b0 == 1 ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[a] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[m_iter] \n\t" +" ldr x11, %[ps_a] \n\t" // Panel-skip of A. +" ldr x2, %[cs_a] \n\t" // Column-skip of A. +" ldr x9, %[rs_a] \n\t" // Row-skip of A. +" ldr x3, %[rs_b] \n\t" // Row-skip of B. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x11, x11, #3 \n\t" // ps_a +" lsl x9, x9, #3 \n\t" // rs_a +" lsl x2, x2, #3 \n\t" // cs_a +" lsl x3, x3, #3 \n\t" // rs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x0, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x1, %[b] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:15] <- C +// V[16:19] <- B; Allowed latency: 24 cycles / # of FPUs. +// V[20:31] <- A; Allowed latency: 32 cycles / # of FPUs. +// Under this scheme, the following is defined: +#define DGEMM_8X4_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,B0,B1,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_8X4_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,A0,A1,A2,A3,B0,B1,BADDR,BSHIFT,LOADNEXT) +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x14, x0 \n\t" +" ld1 {v20.d}[0], [x14], x9 \n\t" +" ld1 {v20.d}[1], [x14], x9 \n\t" +" ld1 {v21.d}[0], [x14], x9 \n\t" +" ld1 {v21.d}[1], [x14], x9 \n\t" +" ld1 {v22.d}[0], [x14], x9 \n\t" +" ld1 {v22.d}[1], [x14], x9 \n\t" +" ld1 {v23.d}[0], [x14], x9 \n\t" +" ld1 {v23.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v24.d}[0], [x14], x9 \n\t" +" ld1 {v24.d}[1], [x14], x9 \n\t" +" ld1 {v25.d}[0], [x14], x9 \n\t" +" ld1 {v25.d}[1], [x14], x9 \n\t" +" ld1 {v26.d}[0], [x14], x9 \n\t" +" ld1 {v26.d}[1], [x14], x9 \n\t" +" ld1 {v27.d}[0], [x14], x9 \n\t" +" ld1 {v27.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v28.d}[0], [x14], x9 \n\t" +" ld1 {v28.d}[1], [x14], x9 \n\t" +" ld1 {v29.d}[0], [x14], x9 \n\t" +" ld1 {v29.d}[1], [x14], x9 \n\t" +" ld1 {v30.d}[0], [x14], x9 \n\t" +" ld1 {v30.d}[1], [x14], x9 \n\t" +" ld1 {v31.d}[0], [x14], x9 \n\t" +" ld1 {v31.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" \n\t" +" ldr q16, [x1, #16*0] \n\t" +" ldr q17, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" ldr q18, [x1, #16*0] \n\t" +" ldr q19, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,A3,B0,B1) \ + DGEMM_8X4_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,B0,B1,x1,0,load) \ + "mov x14, x0 \n\t" \ + "ld1 {v"#A0".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A0".d}[1], [x14], x9 \n\t" \ + "ld1 {v"#A1".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A1".d}[1], [x14], x9 \n\t" \ + "ld1 {v"#A2".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A2".d}[1], [x14], x9 \n\t" \ + "ld1 {v"#A3".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A3".d}[1], [x14], x9 \n\t" \ + "ldr q"#B1", [x1, #16*1] \n\t" \ + "add x1, x1, x3 \n\t" \ + "add x0, x0, x2 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(20,21,22,23,16,17) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,27,18,19) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(28,29,30,31,16,17) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(20,21,22,23,18,19) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,27,16,17) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(28,29,30,31,18,19) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(20,21,22,23,18,19,x1,0,load) +" ldr q19, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(24,25,26,27,16,17,xzr,-1,noload) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(28,29,30,31,18,19,xzr,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x14, x0 \n\t" +" ld1 {v20.d}[0], [x14], x9 \n\t" // Load A col. +" ld1 {v20.d}[1], [x14], x9 \n\t" +" ld1 {v21.d}[0], [x14], x9 \n\t" +" ld1 {v21.d}[1], [x14], x9 \n\t" +" ld1 {v22.d}[0], [x14], x9 \n\t" +" ld1 {v22.d}[1], [x14], x9 \n\t" +" ld1 {v23.d}[0], [x14], x9 \n\t" +" ld1 {v23.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" ldr q16, [x1, #16*0] \n\t" // Load B col. +" ldr q17, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(20,21,22,23,16,17,xzr,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1r {v16.2d}, [x4] \n\t" // Load alpha & beta (value). +" ld1r {v17.2d}, [x8] \n\t" +" fmov d18, #1.0 \n\t" +" fcmp d16, d18 \n\t" +BEQ(UNIT_ALPHA) +DSCALE8V(0,1,2,3,4,5,6,7,16,0) +DSCALE8V(8,9,10,11,12,13,14,15,16,0) +LABEL(UNIT_ALPHA) +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x6, #8 \n\t" // Check for row-storage. +BNE(WRITE_MEM_R) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +" fcmp d17, #0.0 \n\t" +BEQ(ZERO_BETA_C) +DLOADC_4V_C_FWD(20,21,22,23,x1,0,x7) +DLOADC_4V_C_FWD(24,25,26,27,x1,0,x7) +DSCALEA8V(0,1,2,3,4,5,6,7,20,21,22,23,24,25,26,27,17,0) +// +DLOADC_4V_C_FWD(20,21,22,23,x1,0,x7) +DLOADC_4V_C_FWD(24,25,26,27,x1,0,x7) +DSCALEA8V(8,9,10,11,12,13,14,15,20,21,22,23,24,25,26,27,17,0) +LABEL(ZERO_BETA_C) +// +DSTOREC_4V_C_FWD(0,1,2,3,x5,0,x7) +DSTOREC_4V_C_FWD(4,5,6,7,x5,0,x7) +DSTOREC_4V_C_FWD(8,9,10,11,x5,0,x7) +DSTOREC_4V_C_FWD(12,13,14,15,x5,0,x7) +BRANCH(END_WRITE_MEM) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +// In-register transpose. +" trn1 v16.2d, v0.2d, v4.2d \n\t" // Row 0. +" trn1 v17.2d, v8.2d, v12.2d \n\t" +" trn2 v18.2d, v0.2d, v4.2d \n\t" // Row 1. +" trn2 v19.2d, v8.2d, v12.2d \n\t" +" trn1 v20.2d, v1.2d, v5.2d \n\t" // Row 2. +" trn1 v21.2d, v9.2d, v13.2d \n\t" +" trn2 v22.2d, v1.2d, v5.2d \n\t" // Row 3. +" trn2 v23.2d, v9.2d, v13.2d \n\t" +" trn1 v24.2d, v2.2d, v6.2d \n\t" // Row 4. +" trn1 v25.2d, v10.2d, v14.2d \n\t" +" trn2 v26.2d, v2.2d, v6.2d \n\t" // Row 5. +" trn2 v27.2d, v10.2d, v14.2d \n\t" +" trn1 v28.2d, v3.2d, v7.2d \n\t" // Row 6. +" trn1 v29.2d, v11.2d, v15.2d \n\t" +" trn2 v30.2d, v3.2d, v7.2d \n\t" // Row 7. +" trn2 v31.2d, v11.2d, v15.2d \n\t" +// " ld1r {v14.2d}, [x4] \n\t" // Reload alpha & beta (value). +" ld1r {v15.2d}, [x8] \n\t" +" fcmp d15, #0.0 \n\t" +BEQ(ZERO_BETA_R) +DLOADC_4V_R_FWD(0,1,2,3,x1,0,x6) +DLOADC_4V_R_FWD(4,5,6,7,x1,0,x6) +DSCALEA8V(16,17,18,19,20,21,22,23,0,1,2,3,4,5,6,7,15,0) +// +DLOADC_4V_R_FWD(0,1,2,3,x1,0,x6) +DLOADC_4V_R_FWD(4,5,6,7,x1,0,x6) +DSCALEA8V(24,25,26,27,28,29,30,31,0,1,2,3,4,5,6,7,15,0) +LABEL(ZERO_BETA_R) +// +DSTOREC_4V_R_FWD(16,17,18,19,x5,0,x6) +DSTOREC_4V_R_FWD(20,21,22,23,x5,0,x6) +DSTOREC_4V_R_FWD(24,25,26,27,x5,0,x6) +DSTOREC_4V_R_FWD(28,29,30,31,x5,0,x6) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #8 \n\t" +" madd x13, x6, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" add x10, x10, x11 \n\t" // Forward A's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a] "m" (ps_a), + [rs_b] "m" (rs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [m_iter] "m" (m_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + a = a + m_iter * ps_a; + c = c + m_iter * 8 * rs_c; + // Edge case is within 1 millikernel loop of THIS kernel. + // Regarding the 6x?m kernel, the panel stride should be always local. + auxinfo_t data_6xkm = *data; + bli_auxinfo_set_ps_a( 6 * rs_a, &data_6xkm ); + if ( m_left ) + { + bli_dgemmsup_rv_armv8a_int_6x4mn + ( + conja, conjb, m_left, 4, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, &data_6xkm, cntx + ); + } + + // Issue prefetch instructions only after + // execution is done. + __asm__ + ( +" mov x0, %[a_next] \n\t" +" mov x1, %[b_next] \n\t" +" prfm PLDL1STRM, [x0, #16*0] \n\t" +" prfm PLDL1STRM, [x0, #16*1] \n\t" +" prfm PLDL1STRM, [x0, #16*2] \n\t" +" prfm PLDL1KEEP, [x1, #16*0] \n\t" +" prfm PLDL1KEEP, [x1, #16*1] \n\t" +" prfm PLDL1KEEP, [x1, #16*2] \n\t" +: +: [a_next] "r" (a_next), + [b_next] "r" (b_next) +: "x0", "x1" + ); +} + diff --git a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d3x4.c b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d3x4.c new file mode 100644 index 0000000000..84c7c4a7d2 --- /dev/null +++ b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d3x4.c @@ -0,0 +1,309 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Supplimentary fixed-size gemmsup. + +#include "blis.h" +#include "assert.h" + +// Label locality & misc. +#include "../../armv8a_asm_utils.h" + +#define DGEMM_3X1X2_NKER_SUBLOOP(C0,C1,C2,A0,A1,A2,B) \ +" fmla v"#C0".2d, v"#A0".2d, v"#B".2d \n\t" \ +" fmla v"#C1".2d, v"#A1".2d, v"#B".2d \n\t" \ +" fmla v"#C2".2d, v"#A2".2d, v"#B".2d \n\t" + +#define DGEMM_3X4X2_K_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_3X1X2_NKER_SUBLOOP(C00,C10,C20,A0,A1,A2,B0) \ + DGEMM_3X1X2_NKER_SUBLOOP(C01,C11,C21,A0,A1,A2,B1) \ + DGEMM_3X1X2_NKER_SUBLOOP(C02,C12,C22,A0,A1,A2,B2) \ + DGEMM_3X1X2_NKER_SUBLOOP(C03,C13,C23,A0,A1,A2,B3) + +// For row-storage of C. +#define DLOADC_2V_R_FWD(C0,C1,CADDR,CSHIFT,RSC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_2V_R_FWD(C0,C1,CADDR,CSHIFT,RSC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For column-storage of C. +#define DLOADC_1V_1ELM_C_FWD(C0,CSCALAR,CIDX,CADDR,CSHIFT,CSC) \ + DLOAD1V(C0,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" ld1 {v"#CSCALAR".d}["#CIDX"], ["#CADDR"] \n\t" \ +" sub "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_1V_1ELM_C_FWD(C0,CSCALAR,CIDX,CADDR,CSHIFT,CSC) \ + DSTORE1V(C0,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" st1 {v"#CSCALAR".d}["#CIDX"], ["#CADDR"] \n\t" \ +" sub "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + +#define DSCALE6V(V0,V1,V2,V3,V4,V5,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE2V(V4,V5,A,IDX) +#define DSCALEA6V(D0,D1,D2,D3,D4,D5,S0,S1,S2,S3,S4,S5,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA2V(D4,D5,S4,S5,A,IDX) + +void bli_dgemmsup_rd_armv8a_asm_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + assert( m0 == 3 ); + assert( n0 == 4 ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + assert( cs_a0 == 1 ); + assert( rs_b0 == 1 ); + + __asm__ volatile + ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" ldr x2, %[rs_a] \n\t" // Row-skip of A. +" ldr x3, %[cs_b] \n\t" // Column-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x2, x2, #3 \n\t" // rs_a +" lsl x3, x3, #3 \n\t" // cs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:11] <- C +// V[12:14] <- A +// V[16:19] <- B +// Under this scheme, the following is defined: +#define DGEMM_3X4X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_3X4X2_K_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,A0,A1,A2,B0,B1,B2,B3) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x11, x1 \n\t" // Load B. +" ldr q16, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q17, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q18, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q19, [x11] \n\t" +" add x1, x1, #16 \n\t" +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ldr q12, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q13, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q14, [x14] \n\t" +" add x0, x0, #16 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR4V(8,9,10,11) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_3X4X2_K_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_3X4X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3) \ + "mov x11, x1 \n\t" \ + "ldr q"#B0", [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q"#B1", [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q"#B2", [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q"#B3", [x11] \n\t" \ + "add x1, x1, #16 \n\t" \ + "mov x14, x0 \n\t" \ + "ldr q"#A0", [x14] \n\t" \ + "add x14, x14, x2 \n\t" \ + "ldr q"#A1", [x14] \n\t" \ + "add x14, x14, x2 \n\t" \ + "ldr q"#A2", [x14] \n\t" \ + "add x0, x0, #16 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_3X4X2_K_MKER_LOOP_PLAIN_LOC_FWD(12,13,14,16,17,18,19) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_3X4X2_K_MKER_LOOP_PLAIN_LOC_FWD(12,13,14,16,17,18,19) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_3X4X2_K_MKER_LOOP_PLAIN_LOC(12,13,14,16,17,18,19) +// +// If major kernel is executed, +// an additional depth-summation is required. +" faddp v0.2d, v0.2d, v1.2d \n\t" // Line 0. +" faddp v1.2d, v2.2d, v3.2d \n\t" +" faddp v2.2d, v4.2d, v5.2d \n\t" // Line 1. +" faddp v3.2d, v6.2d, v7.2d \n\t" +" faddp v4.2d, v8.2d, v9.2d \n\t" // Line 2. +" faddp v5.2d, v10.2d, v11.2d \n\t" +" \n\t" +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x11, x1 \n\t" // Load B row. +" ld1 {v28.d}[0], [x11], x3 \n\t" +" ld1 {v28.d}[1], [x11], x3 \n\t" +" ld1 {v29.d}[0], [x11], x3 \n\t" +" ld1 {v29.d}[1], [x11], x3 \n\t" +" add x1, x1, #8 \n\t" +" mov x14, x0 \n\t" // Load A column. +" ld1 {v24.d}[0], [x14], x2 \n\t" +" ld1 {v24.d}[1], [x14], x2 \n\t" +" ld1 {v25.d}[0], [x14], x2 \n\t" +" add x0, x0, #8 \n\t" +" fmla v0.2d, v28.2d, v24.d[0] \n\t" +" fmla v1.2d, v29.2d, v24.d[0] \n\t" +" fmla v2.2d, v28.2d, v24.d[1] \n\t" +" fmla v3.2d, v29.2d, v24.d[1] \n\t" +" fmla v4.2d, v28.2d, v25.d[0] \n\t" +" fmla v5.2d, v29.2d, v25.d[0] \n\t" +" sub x8, x8, #1 \n\t" +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1r {v30.2d}, [x4] \n\t" // Load alpha & beta (value). +" ld1r {v31.2d}, [x8] \n\t" +DSCALE6V(0,1,2,3,4,5,30,0) +" \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" fcmp d31, #0.0 \n\t" +BEQ(ZERO_BETA_R) +DLOADC_2V_R_FWD(12,13,x9,0,x6) +DLOADC_2V_R_FWD(14,15,x9,0,x6) +DLOADC_2V_R_FWD(16,17,x9,0,x6) +DSCALEA6V(0,1,2,3,4,5,12,13,14,15,16,17,31,0) +LABEL(ZERO_BETA_R) +DSTOREC_2V_R_FWD(0,1,x5,0,x6) +DSTOREC_2V_R_FWD(2,3,x5,0,x6) +DSTOREC_2V_R_FWD(4,5,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +" trn1 v6.2d, v0.2d, v2.2d \n\t" +" trn2 v7.2d, v0.2d, v2.2d \n\t" +" trn1 v8.2d, v1.2d, v3.2d \n\t" +" trn2 v9.2d, v1.2d, v3.2d \n\t" +" fcmp d31, #0.0 \n\t" +BEQ(ZERO_BETA_C) +DLOADC_1V_1ELM_C_FWD(12,20,0,x9,0,x7) +DLOADC_1V_1ELM_C_FWD(13,20,1,x9,0,x7) +DLOADC_1V_1ELM_C_FWD(14,21,0,x9,0,x7) +DLOADC_1V_1ELM_C_FWD(15,21,1,x9,0,x7) +DSCALEA6V(6,7,8,9,4,5,12,13,14,15,20,21,31,0) +LABEL(ZERO_BETA_C) +DSTOREC_1V_1ELM_C_FWD(6,4,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(7,4,1,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(8,5,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(9,5,1,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_b] "m" (cs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +} + diff --git a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d6x3.c b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d6x3.c new file mode 100644 index 0000000000..abbb6fb4d9 --- /dev/null +++ b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d6x3.c @@ -0,0 +1,359 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Supplimentary fixed-size gemmsup. + +#include "blis.h" +#include "assert.h" + +// Label locality & misc. +#include "../../armv8a_asm_utils.h" + +#define DGEMM_1X3X2_NKER_SUBLOOP(C0,C1,C2,A,B0,B1,B2) \ +" fmla v"#C0".2d, v"#A".2d, v"#B0".2d \n\t" \ +" fmla v"#C1".2d, v"#A".2d, v"#B1".2d \n\t" \ +" fmla v"#C2".2d, v"#A".2d, v"#B2".2d \n\t" + +#define DGEMM_6X3X2_K_MKER_LOOP_PLAIN(C00,C01,C02,C10,C11,C12,C20,C21,C22,C30,C31,C32,C40,C41,C42,C50,C51,C52,A0,A1,A2,A3,A4,A5,B0,B1,B2,AADDR,AELEMADDR,AELEMST,LOAD0,LOAD1) \ + DGEMM_1X3X2_NKER_SUBLOOP(C00,C01,C02,A0,B0,B1,B2) \ + DGEMM_LOAD1V_K_ ##LOAD0 (A0,AELEMADDR,AELEMST) \ + DGEMM_1X3X2_NKER_SUBLOOP(C10,C11,C12,A1,B0,B1,B2) \ + DGEMM_LOAD1V_K_ ##LOAD0 (A1,AELEMADDR,AELEMST) \ + DGEMM_1X3X2_NKER_SUBLOOP(C20,C21,C22,A2,B0,B1,B2) \ + DGEMM_LOAD1V_K_ ##LOAD0 (A2,AELEMADDR,AELEMST) \ + DGEMM_1X3X2_NKER_SUBLOOP(C30,C31,C32,A3,B0,B1,B2) \ + DGEMM_LOAD1V_K_ ##LOAD0 (A3,AELEMADDR,AELEMST) \ + DGEMM_FWDA_K_ ##LOAD0 (AADDR) \ +" mov "#AELEMADDR", "#AADDR" \n\t" \ + DGEMM_1X3X2_NKER_SUBLOOP(C40,C41,C42,A4,B0,B1,B2) \ + DGEMM_LOAD1V_K_ ##LOAD1 (A4,AELEMADDR,AELEMST) \ + DGEMM_1X3X2_NKER_SUBLOOP(C50,C51,C52,A5,B0,B1,B2) \ + DGEMM_LOAD1V_K_ ##LOAD1 (A5,AELEMADDR,AELEMST) + +#define DGEMM_LOAD1V_K_noload(V,ELEMADDR,ELEMST) +#define DGEMM_LOAD1V_K_load(V,ELEMADDR,ELEMST) \ +" ldr q"#V", [ "#ELEMADDR" ] \n\t" \ +" add "#ELEMADDR", "#ELEMADDR", "#ELEMST" \n\t" + +#define DGEMM_FWDA_K_noload(ADDR) +#define DGEMM_FWDA_K_load(ADDR) \ +" add "#ADDR", "#ADDR", #16 \n\t" + +// For row-storage of C. +#define DLOADC_1V_1ELM_R_FWD(C0,CSCALAR,CIDX,CADDR,CSHIFT,RSC) \ + DLOAD1V(C0,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" ld1 {v"#CSCALAR".d}["#CIDX"], ["#CADDR"] \n\t" \ +" sub "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_1V_1ELM_R_FWD(C0,CSCALAR,CIDX,CADDR,CSHIFT,RSC) \ + DSTORE1V(C0,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" st1 {v"#CSCALAR".d}["#CIDX"], ["#CADDR"] \n\t" \ +" sub "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For column-storage of C. +#define DLOADC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ + DLOAD1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ + DSTORE1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + +#define DSCALE9V(V0,V1,V2,V3,V4,V5,V6,V7,V8,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE4V(V4,V5,V6,V7,A,IDX) \ + DSCALE1V(V8,A,IDX) +#define DSCALEA9V(D0,D1,D2,D3,D4,D5,D6,D7,D8,S0,S1,S2,S3,S4,S5,S6,S7,S8,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA4V(D4,D5,D6,D7,S4,S5,S6,S7,A,IDX) \ + DSCALEA1V(D8,S8,A,IDX) + + +void bli_dgemmsup_rd_armv8a_asm_6x3 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + assert( m0 == 6 ); + assert( n0 == 3 ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 8; + uint64_t k_left = k0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + assert( cs_a0 == 1 ); + assert( rs_b0 == 1 ); + + __asm__ volatile + ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" ldr x2, %[rs_a] \n\t" // Row-skip of A. +" ldr x3, %[cs_b] \n\t" // Column-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x2, x2, #3 \n\t" // rs_a +" lsl x3, x3, #3 \n\t" // cs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:17] <- C +// V[18:23] <- B +// V[24:31] <- A +// Under this scheme, the following is defined: +#define DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,A4,A5,B0,B1,B2,AADDR,AELEMADDR,AELEMST,LOAD0,LOAD1) \ + DGEMM_6X3X2_K_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,A0,A1,A2,A3,A4,A5,B0,B1,B2,AADDR,AELEMADDR,AELEMST,LOAD0,LOAD1) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ldr q24, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q25, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q26, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q27, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q28, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q29, [x14] \n\t" +" add x0, x0, #16 \n\t" +" mov x14, x0 \n\t" +" ldr q30, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q31, [x14] \n\t" +" add x14, x14, x2 \n\t" +" \n\t" +" mov x11, x1 \n\t" // Load B. +" ldr q18, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q19, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q20, [x11] \n\t" +" add x1, x1, #16 \n\t" +" mov x11, x1 \n\t" +" ldr q21, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q22, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q23, [x11] \n\t" +" add x1, x1, #16 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +CLEAR2V(16,17) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,A3,A4,A5,B0,B1,B2) \ + DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,A4,A5,B0,B1,B2,x0,x14,x2,load,load) \ + "mov x11, x1 \n\t" \ + "ldr q"#B0", [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q"#B1", [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q"#B2", [x11] \n\t" \ + "add x1, x1, #16 \n\t" \ +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,27,28,29,18,19,20) +DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC_FWD(30,31,24,25,26,27,21,22,23) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC_FWD(28,29,30,31,24,25,18,19,20) +DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC_FWD(26,27,28,29,30,31,21,22,23) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC(28,29,30,31,24,25,18,19,20,x0,x14,x2,load,noload) +DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC(26,27,28,29,30,31,21,22,23,xzr,xzr,xzr,noload,noload) +// +// If major kernel is executed, +// an additional depth-summation is required. +" faddp v0.2d, v0.2d, v3.2d \n\t" // Column 0 Prt 0. +" faddp v1.2d, v1.2d, v4.2d \n\t" // Column 1 Prt 0. +" faddp v2.2d, v2.2d, v5.2d \n\t" // Column 2 Prt 0. +" faddp v3.2d, v6.2d, v9.2d \n\t" // Column 0 Prt 1. +" faddp v4.2d, v7.2d, v10.2d \n\t" // Column 1 Prt 1. +" faddp v5.2d, v8.2d, v11.2d \n\t" // Column 2 Prt 1. +" faddp v6.2d, v12.2d, v15.2d \n\t" // Column 0 Prt 2. +" faddp v7.2d, v13.2d, v16.2d \n\t" // Column 1 Prt 2. +" faddp v8.2d, v14.2d, v17.2d \n\t" // Column 2 Prt 2. +" \n\t" +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x14, x0 \n\t" // Load A column. +" ld1 {v24.d}[0], [x14], x2 \n\t" +" ld1 {v24.d}[1], [x14], x2 \n\t" +" ld1 {v25.d}[0], [x14], x2 \n\t" +" ld1 {v25.d}[1], [x14], x2 \n\t" +" ld1 {v26.d}[0], [x14], x2 \n\t" +" ld1 {v26.d}[1], [x14], x2 \n\t" +" add x0, x0, #8 \n\t" +" mov x11, x1 \n\t" // Load B row. +" ld1 {v28.d}[0], [x11], x3 \n\t" +" ld1 {v28.d}[1], [x11], x3 \n\t" +" ld1 {v29.d}[0], [x11], x3 \n\t" +" add x1, x1, #8 \n\t" +" fmla v0.2d, v24.2d, v28.d[0] \n\t" +" fmla v3.2d, v25.2d, v28.d[0] \n\t" +" fmla v6.2d, v26.2d, v28.d[0] \n\t" +" fmla v1.2d, v24.2d, v28.d[1] \n\t" +" fmla v4.2d, v25.2d, v28.d[1] \n\t" +" fmla v7.2d, v26.2d, v28.d[1] \n\t" +" fmla v2.2d, v24.2d, v29.d[0] \n\t" +" fmla v5.2d, v25.2d, v29.d[0] \n\t" +" fmla v8.2d, v26.2d, v29.d[0] \n\t" +" sub x8, x8, #1 \n\t" +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1r {v30.2d}, [x4] \n\t" // Load alpha & beta (value). +" ld1r {v31.2d}, [x8] \n\t" +DSCALE9V(0,1,2,3,4,5,6,7,8,30,0) +" \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" trn1 v20.2d, v0.2d, v1.2d \n\t" +" trn2 v21.2d, v0.2d, v1.2d \n\t" +" trn1 v22.2d, v3.2d, v4.2d \n\t" +" trn2 v23.2d, v3.2d, v4.2d \n\t" +" trn1 v24.2d, v6.2d, v7.2d \n\t" +" trn2 v25.2d, v6.2d, v7.2d \n\t" +" fcmp d31, #0.0 \n\t" +BEQ(ZERO_BETA_R) +DLOADC_1V_1ELM_R_FWD(10,26,0,x9,0,x6) +DLOADC_1V_1ELM_R_FWD(11,26,1,x9,0,x6) +DLOADC_1V_1ELM_R_FWD(12,27,0,x9,0,x6) +DLOADC_1V_1ELM_R_FWD(13,27,1,x9,0,x6) +DLOADC_1V_1ELM_R_FWD(14,28,0,x9,0,x6) +DLOADC_1V_1ELM_R_FWD(15,28,1,x9,0,x6) +DSCALEA9V(20,21,22,23,24,25,2,5,8,10,11,12,13,14,15,26,27,28,31,0) +LABEL(ZERO_BETA_R) +DSTOREC_1V_1ELM_R_FWD(20,2,0,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(21,2,1,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(22,5,0,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(23,5,1,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(24,8,0,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(25,8,1,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +" fcmp d31, #0.0 \n\t" +BEQ(ZERO_BETA_C) +DLOADC_3V_C_FWD(12,15,18,x9,0,x7) +DLOADC_3V_C_FWD(13,16,19,x9,0,x7) +DLOADC_3V_C_FWD(14,17,20,x9,0,x7) +DSCALEA9V(0,1,2,3,4,5,6,7,8,12,13,14,15,16,17,18,19,20,31,0) +LABEL(ZERO_BETA_C) +DSTOREC_3V_C_FWD(0,3,6,x5,0,x7) +DSTOREC_3V_C_FWD(1,4,7,x5,0,x7) +DSTOREC_3V_C_FWD(2,5,8,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_b] "m" (cs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +} + + diff --git a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d2x8.c b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d2x8.c new file mode 100644 index 0000000000..43880063eb --- /dev/null +++ b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d2x8.c @@ -0,0 +1,383 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Supplimentary dynamic-size gemmsup. + +#include "blis.h" +#include "assert.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL _Pragma("unroll") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL +#endif + +/* + * As these kernels requires num. of vregs about half of the total 32, + * it should be all right to implement w/ intrinsics. + * + * c.f. https://www.youtube.com/watch?v=R2hQOVjRwVE . + */ +void bli_dgemmsup_rd_armv8a_int_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + assert( m0 <= 2 ); + assert( n0 <= 8 ); + + double *a_loc = a; + double *b_loc = b; + double *c_loc = c; + + uint64_t k_mker = k0 / 2; + uint64_t k_left = k0 % 2; + uint64_t b_iszr = ( *beta == 0.0 ); + + assert( cs_a == 1 ); + assert( rs_b == 1 ); + + // Registers used to store a 2x8x2 block of C (summing the last dimension). + // Total: 22 specified. + float64x2_t vc_00, vc_01, vc_02, vc_03, vc_04, vc_05, vc_06, vc_07; + float64x2_t vc_10, vc_11, vc_12, vc_13, vc_14, vc_15, vc_16, vc_17; + float64x2_t va_0, va_1; + float64x2_t vb_0, vb_1, vb_2, vb_3; + + vc_00 = (float64x2_t)vdupq_n_f64( 0 ); + vc_01 = (float64x2_t)vdupq_n_f64( 0 ); + vc_02 = (float64x2_t)vdupq_n_f64( 0 ); + vc_03 = (float64x2_t)vdupq_n_f64( 0 ); + vc_04 = (float64x2_t)vdupq_n_f64( 0 ); + vc_05 = (float64x2_t)vdupq_n_f64( 0 ); + vc_06 = (float64x2_t)vdupq_n_f64( 0 ); + vc_07 = (float64x2_t)vdupq_n_f64( 0 ); + vc_10 = (float64x2_t)vdupq_n_f64( 0 ); + vc_11 = (float64x2_t)vdupq_n_f64( 0 ); + vc_12 = (float64x2_t)vdupq_n_f64( 0 ); + vc_13 = (float64x2_t)vdupq_n_f64( 0 ); + vc_14 = (float64x2_t)vdupq_n_f64( 0 ); + vc_15 = (float64x2_t)vdupq_n_f64( 0 ); + vc_16 = (float64x2_t)vdupq_n_f64( 0 ); + vc_17 = (float64x2_t)vdupq_n_f64( 0 ); + + PRAGMA_UNROLL + for ( ; k_mker > 0; --k_mker ) + { + // if ( m0 > 0 ) + va_0 = vld1q_f64( a_loc + rs_a * 0 ); + if ( m0 > 1 ) va_1 = vld1q_f64( a_loc + rs_a * 1 ); + // if ( n0 > 0 ) + vb_0 = vld1q_f64( b_loc + cs_b * 0 ); + if ( n0 > 1 ) vb_1 = vld1q_f64( b_loc + cs_b * 1 ); + if ( n0 > 2 ) vb_2 = vld1q_f64( b_loc + cs_b * 2 ); + if ( n0 > 3 ) vb_3 = vld1q_f64( b_loc + cs_b * 3 ); + + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + if ( m0 > 1 ) + { + vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); + vc_11 = vfmaq_f64( vc_11, va_1, vb_1 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_2 ); + vc_13 = vfmaq_f64( vc_13, va_1, vb_3 ); + } + + if ( n0 > 4 ) { + vb_0 = vld1q_f64( b_loc + cs_b * 4 ); + if ( n0 > 5 ) vb_1 = vld1q_f64( b_loc + cs_b * 5 ); + if ( n0 > 6 ) vb_2 = vld1q_f64( b_loc + cs_b * 6 ); + if ( n0 > 7 ) vb_3 = vld1q_f64( b_loc + cs_b * 7 ); + + vc_04 = vfmaq_f64( vc_04, va_0, vb_0 ); + vc_05 = vfmaq_f64( vc_05, va_0, vb_1 ); + if ( n0 > 6 ) + { + vc_06 = vfmaq_f64( vc_06, va_0, vb_2 ); + vc_07 = vfmaq_f64( vc_07, va_0, vb_3 ); + } + if ( m0 > 1 ) + { + vc_14 = vfmaq_f64( vc_14, va_1, vb_0 ); + vc_15 = vfmaq_f64( vc_15, va_1, vb_1 ); + if ( n0 > 6 ) + { + vc_16 = vfmaq_f64( vc_16, va_1, vb_2 ); + vc_17 = vfmaq_f64( vc_17, va_1, vb_3 ); + } + } + } + + a_loc += 2; + b_loc += 2; + } + + // Pay no care for O(1) details. + va_0 = (float64x2_t)vdupq_n_f64( 0 ); + va_1 = (float64x2_t)vdupq_n_f64( 0 ); + vb_0 = (float64x2_t)vdupq_n_f64( 0 ); + vb_1 = (float64x2_t)vdupq_n_f64( 0 ); + vb_2 = (float64x2_t)vdupq_n_f64( 0 ); + vb_3 = (float64x2_t)vdupq_n_f64( 0 ); + PRAGMA_NOUNROLL + for ( ; k_left > 0; --k_left ) + { + // if ( m0 > 0 ) + va_0 = vld1q_lane_f64( a_loc + rs_a * 0, va_0, 0 ); + if ( m0 > 1 ) va_1 = vld1q_lane_f64( a_loc + rs_a * 1, va_1, 0 ); + // if ( n0 > 0 ) + vb_0 = vld1q_lane_f64( b_loc + cs_b * 0, vb_0, 0 ); + if ( n0 > 1 ) vb_1 = vld1q_lane_f64( b_loc + cs_b * 1, vb_1, 0 ); + if ( n0 > 2 ) vb_2 = vld1q_lane_f64( b_loc + cs_b * 2, vb_2, 0 ); + if ( n0 > 3 ) vb_3 = vld1q_lane_f64( b_loc + cs_b * 3, vb_3, 0 ); + + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); + vc_11 = vfmaq_f64( vc_11, va_1, vb_1 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_2 ); + vc_13 = vfmaq_f64( vc_13, va_1, vb_3 ); + + if ( n0 > 4 ) vb_0 = vld1q_lane_f64( b_loc + cs_b * 4, vb_0, 0 ); + if ( n0 > 5 ) vb_1 = vld1q_lane_f64( b_loc + cs_b * 5, vb_1, 0 ); + if ( n0 > 6 ) vb_2 = vld1q_lane_f64( b_loc + cs_b * 6, vb_2, 0 ); + if ( n0 > 7 ) vb_3 = vld1q_lane_f64( b_loc + cs_b * 7, vb_3, 0 ); + + vc_04 = vfmaq_f64( vc_04, va_0, vb_0 ); + vc_05 = vfmaq_f64( vc_05, va_0, vb_1 ); + vc_06 = vfmaq_f64( vc_06, va_0, vb_2 ); + vc_07 = vfmaq_f64( vc_07, va_0, vb_3 ); + vc_14 = vfmaq_f64( vc_14, va_1, vb_0 ); + vc_15 = vfmaq_f64( vc_15, va_1, vb_1 ); + vc_16 = vfmaq_f64( vc_16, va_1, vb_2 ); + vc_17 = vfmaq_f64( vc_17, va_1, vb_3 ); + + a_loc += 1; + b_loc += 1; + } + + // Load alpha and beta. + // Note that here vb is used for alpha, in contrast to other kernels. + vb_0 = vld1q_dup_f64( alpha ); + va_0 = vld1q_dup_f64( beta ); + + // Scale. + vc_00 = vmulq_f64( vc_00, vb_0 ); + vc_01 = vmulq_f64( vc_01, vb_0 ); + vc_02 = vmulq_f64( vc_02, vb_0 ); + vc_03 = vmulq_f64( vc_03, vb_0 ); + vc_04 = vmulq_f64( vc_04, vb_0 ); + vc_05 = vmulq_f64( vc_05, vb_0 ); + vc_06 = vmulq_f64( vc_06, vb_0 ); + vc_07 = vmulq_f64( vc_07, vb_0 ); + vc_10 = vmulq_f64( vc_10, vb_0 ); + vc_11 = vmulq_f64( vc_11, vb_0 ); + vc_12 = vmulq_f64( vc_12, vb_0 ); + vc_13 = vmulq_f64( vc_13, vb_0 ); + vc_14 = vmulq_f64( vc_14, vb_0 ); + vc_15 = vmulq_f64( vc_15, vb_0 ); + vc_16 = vmulq_f64( vc_16, vb_0 ); + vc_17 = vmulq_f64( vc_17, vb_0 ); + + if ( cs_c == 1 ) + { + // Row-storage. + vc_00 = vpaddq_f64( vc_00, vc_01 ); + vc_02 = vpaddq_f64( vc_02, vc_03 ); + vc_04 = vpaddq_f64( vc_04, vc_05 ); + vc_06 = vpaddq_f64( vc_06, vc_07 ); + vc_10 = vpaddq_f64( vc_10, vc_11 ); + vc_12 = vpaddq_f64( vc_12, vc_13 ); + vc_14 = vpaddq_f64( vc_14, vc_15 ); + vc_16 = vpaddq_f64( vc_16, vc_17 ); + + if ( n0 > 1 ) vb_0 = vld1q_f64 ( c_loc + 0 * rs_c + 0 ); + else if ( n0 > 0 ) vb_0 = vld1q_lane_f64( c_loc + 0 * rs_c + 0, vb_0, 0 ); + if ( n0 > 3 ) vb_1 = vld1q_f64 ( c_loc + 0 * rs_c + 2 ); + else if ( n0 > 2 ) vb_1 = vld1q_lane_f64( c_loc + 0 * rs_c + 2, vb_1, 0 ); + if ( n0 > 5 ) vb_2 = vld1q_f64 ( c_loc + 0 * rs_c + 4 ); + else if ( n0 > 4 ) vb_2 = vld1q_lane_f64( c_loc + 0 * rs_c + 4, vb_2, 0 ); + if ( n0 > 7 ) vb_3 = vld1q_f64 ( c_loc + 0 * rs_c + 6 ); + else if ( n0 > 6 ) vb_3 = vld1q_lane_f64( c_loc + 0 * rs_c + 6, vb_3, 0 ); + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_1 ); + vc_04 = vfmaq_f64( vc_04, va_0, vb_2 ); + vc_06 = vfmaq_f64( vc_06, va_0, vb_3 ); + } + if ( n0 > 1 ) vst1q_f64 ( c_loc + 0 * rs_c + 0, vc_00 ); + else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 0 * rs_c + 0, vc_00, 0 ); + if ( n0 > 3 ) vst1q_f64 ( c_loc + 0 * rs_c + 2, vc_02 ); + else if ( n0 > 2 ) vst1q_lane_f64( c_loc + 0 * rs_c + 2, vc_02, 0 ); + if ( n0 > 5 ) vst1q_f64 ( c_loc + 0 * rs_c + 4, vc_04 ); + else if ( n0 > 4 ) vst1q_lane_f64( c_loc + 0 * rs_c + 4, vc_04, 0 ); + if ( n0 > 7 ) vst1q_f64 ( c_loc + 0 * rs_c + 6, vc_06 ); + else if ( n0 > 6 ) vst1q_lane_f64( c_loc + 0 * rs_c + 6, vc_06, 0 ); + + if ( m0 > 1 ) + { + if ( n0 > 1 ) vb_0 = vld1q_f64 ( c_loc + 1 * rs_c + 0 ); + else if ( n0 > 0 ) vb_0 = vld1q_lane_f64( c_loc + 1 * rs_c + 0, vb_0, 0 ); + if ( n0 > 3 ) vb_1 = vld1q_f64 ( c_loc + 1 * rs_c + 2 ); + else if ( n0 > 2 ) vb_1 = vld1q_lane_f64( c_loc + 1 * rs_c + 2, vb_1, 0 ); + if ( n0 > 5 ) vb_2 = vld1q_f64 ( c_loc + 1 * rs_c + 4 ); + else if ( n0 > 4 ) vb_2 = vld1q_lane_f64( c_loc + 1 * rs_c + 4, vb_2, 0 ); + if ( n0 > 7 ) vb_3 = vld1q_f64 ( c_loc + 1 * rs_c + 6 ); + else if ( n0 > 6 ) vb_3 = vld1q_lane_f64( c_loc + 1 * rs_c + 6, vb_3, 0 ); + if ( !b_iszr ) + { + vc_10 = vfmaq_f64( vc_10, va_0, vb_0 ); + vc_12 = vfmaq_f64( vc_12, va_0, vb_1 ); + vc_14 = vfmaq_f64( vc_14, va_0, vb_2 ); + vc_16 = vfmaq_f64( vc_16, va_0, vb_3 ); + } + if ( n0 > 1 ) vst1q_f64 ( c_loc + 1 * rs_c + 0, vc_10 ); + else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 1 * rs_c + 0, vc_10, 0 ); + if ( n0 > 3 ) vst1q_f64 ( c_loc + 1 * rs_c + 2, vc_12 ); + else if ( n0 > 2 ) vst1q_lane_f64( c_loc + 1 * rs_c + 2, vc_12, 0 ); + if ( n0 > 5 ) vst1q_f64 ( c_loc + 1 * rs_c + 4, vc_14 ); + else if ( n0 > 4 ) vst1q_lane_f64( c_loc + 1 * rs_c + 4, vc_14, 0 ); + if ( n0 > 7 ) vst1q_f64 ( c_loc + 1 * rs_c + 6, vc_16 ); + else if ( n0 > 6 ) vst1q_lane_f64( c_loc + 1 * rs_c + 6, vc_16, 0 ); + } + } + else + { + // Column-storage. + vc_00 = vpaddq_f64( vc_00, vc_10 ); + vc_01 = vpaddq_f64( vc_01, vc_11 ); + vc_02 = vpaddq_f64( vc_02, vc_12 ); + vc_03 = vpaddq_f64( vc_03, vc_13 ); + vc_04 = vpaddq_f64( vc_04, vc_14 ); + vc_05 = vpaddq_f64( vc_05, vc_15 ); + vc_06 = vpaddq_f64( vc_06, vc_16 ); + vc_07 = vpaddq_f64( vc_07, vc_17 ); + + if ( m0 > 1 ) + { + // if ( n0 > 0 ) + vb_0 = vld1q_f64( c_loc + 0 + 0 * cs_c ); + if ( n0 > 1 ) vb_1 = vld1q_f64( c_loc + 0 + 1 * cs_c ); + if ( n0 > 2 ) vb_2 = vld1q_f64( c_loc + 0 + 2 * cs_c ); + if ( n0 > 3 ) vb_3 = vld1q_f64( c_loc + 0 + 3 * cs_c ); + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + } + vst1q_f64( c_loc + 0 + 0 * cs_c, vc_00 ); + if ( n0 > 1 ) vst1q_f64( c_loc + 0 + 1 * cs_c, vc_01 ); + if ( n0 > 2 ) vst1q_f64( c_loc + 0 + 2 * cs_c, vc_02 ); + if ( n0 > 3 ) vst1q_f64( c_loc + 0 + 3 * cs_c, vc_03 ); + + if ( n0 > 4 ) vb_0 = vld1q_f64( c_loc + 0 + 4 * cs_c ); + if ( n0 > 5 ) vb_1 = vld1q_f64( c_loc + 0 + 5 * cs_c ); + if ( n0 > 6 ) vb_2 = vld1q_f64( c_loc + 0 + 6 * cs_c ); + if ( n0 > 7 ) vb_3 = vld1q_f64( c_loc + 0 + 7 * cs_c ); + if ( !b_iszr ) + { + vc_04 = vfmaq_f64( vc_04, va_0, vb_0 ); + vc_05 = vfmaq_f64( vc_05, va_0, vb_1 ); + vc_06 = vfmaq_f64( vc_06, va_0, vb_2 ); + vc_07 = vfmaq_f64( vc_07, va_0, vb_3 ); + } + if ( n0 > 4 ) vst1q_f64( c_loc + 0 + 4 * cs_c, vc_04 ); + if ( n0 > 5 ) vst1q_f64( c_loc + 0 + 5 * cs_c, vc_05 ); + if ( n0 > 6 ) vst1q_f64( c_loc + 0 + 6 * cs_c, vc_06 ); + if ( n0 > 7 ) vst1q_f64( c_loc + 0 + 7 * cs_c, vc_07 ); + } + else + { + // if ( n0 > 0 ) + vb_0 = vld1q_lane_f64( c_loc + 0 + 0 * cs_c, vb_0, 0 ); + if ( n0 > 1 ) vb_1 = vld1q_lane_f64( c_loc + 0 + 1 * cs_c, vb_1, 0 ); + if ( n0 > 2 ) vb_2 = vld1q_lane_f64( c_loc + 0 + 2 * cs_c, vb_2, 0 ); + if ( n0 > 3 ) vb_3 = vld1q_lane_f64( c_loc + 0 + 3 * cs_c, vb_3, 0 ); + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + } + vst1q_lane_f64( c_loc + 0 + 0 * cs_c, vc_00, 0 ); + if ( n0 > 1 ) vst1q_lane_f64( c_loc + 0 + 1 * cs_c, vc_01, 0 ); + if ( n0 > 2 ) vst1q_lane_f64( c_loc + 0 + 2 * cs_c, vc_02, 0 ); + if ( n0 > 3 ) vst1q_lane_f64( c_loc + 0 + 3 * cs_c, vc_03, 0 ); + + if ( n0 > 4 ) vb_0 = vld1q_lane_f64( c_loc + 0 + 4 * cs_c, vb_0, 0 ); + if ( n0 > 5 ) vb_1 = vld1q_lane_f64( c_loc + 0 + 5 * cs_c, vb_1, 0 ); + if ( n0 > 6 ) vb_2 = vld1q_lane_f64( c_loc + 0 + 6 * cs_c, vb_2, 0 ); + if ( n0 > 7 ) vb_3 = vld1q_lane_f64( c_loc + 0 + 7 * cs_c, vb_3, 0 ); + if ( !b_iszr ) + { + vc_04 = vfmaq_f64( vc_04, va_0, vb_0 ); + vc_05 = vfmaq_f64( vc_05, va_0, vb_1 ); + vc_06 = vfmaq_f64( vc_06, va_0, vb_2 ); + vc_07 = vfmaq_f64( vc_07, va_0, vb_3 ); + } + if ( n0 > 4 ) vst1q_lane_f64( c_loc + 0 + 4 * cs_c, vc_04, 0 ); + if ( n0 > 5 ) vst1q_lane_f64( c_loc + 0 + 5 * cs_c, vc_05, 0 ); + if ( n0 > 6 ) vst1q_lane_f64( c_loc + 0 + 6 * cs_c, vc_06, 0 ); + if ( n0 > 7 ) vst1q_lane_f64( c_loc + 0 + 7 * cs_c, vc_07, 0 ); + } + } + +} diff --git a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d3x4.c b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d3x4.c new file mode 100644 index 0000000000..73e5f20fb7 --- /dev/null +++ b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d3x4.c @@ -0,0 +1,341 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Supplimentary dynamic-size gemmsup. + +#include "blis.h" +#include "assert.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL _Pragma("unroll") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL +#endif + +/* + * As these kernels requires num. of vregs about half of the total 32, + * it should be all right to implement w/ intrinsics. + * + * c.f. https://www.youtube.com/watch?v=R2hQOVjRwVE . + */ +void bli_dgemmsup_rd_armv8a_int_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // if ( m0 == 3 && n0 == 4 ) + // { + // // Use fixed-size version if it is full 3x4. + // bli_dgemmsup_rd_armv8a_asm_3x4 + // ( + // conja, conjb, m0, n0, k0, + // alpha, a, rs_a, cs_a, b, rs_b, cs_b, + // beta, c, rs_c, cs_c, data, cntx + // ); + // return; + // } + + assert( m0 <= 3 ); + assert( n0 <= 4 ); + + double *a_loc = a; + double *b_loc = b; + double *c_loc = c; + + uint64_t k_mker = k0 / 2; + uint64_t k_left = k0 % 2; + uint64_t b_iszr = ( *beta == 0.0 ); + + assert( cs_a == 1 ); + assert( rs_b == 1 ); + + // Registers used to store a 3x4x2 block of C (summing the last dimension). + float64x2_t vc_00, vc_01, vc_02, vc_03; + float64x2_t vc_10, vc_11, vc_12, vc_13; + float64x2_t vc_20, vc_21, vc_22, vc_23; + float64x2_t va_0, va_1, va_2; + float64x2_t vb_0, vb_1, vb_2, vb_3; + + vc_00 = (float64x2_t)vdupq_n_f64( 0 ); + vc_01 = (float64x2_t)vdupq_n_f64( 0 ); + vc_02 = (float64x2_t)vdupq_n_f64( 0 ); + vc_03 = (float64x2_t)vdupq_n_f64( 0 ); + vc_10 = (float64x2_t)vdupq_n_f64( 0 ); + vc_11 = (float64x2_t)vdupq_n_f64( 0 ); + vc_12 = (float64x2_t)vdupq_n_f64( 0 ); + vc_13 = (float64x2_t)vdupq_n_f64( 0 ); + vc_20 = (float64x2_t)vdupq_n_f64( 0 ); + vc_21 = (float64x2_t)vdupq_n_f64( 0 ); + vc_22 = (float64x2_t)vdupq_n_f64( 0 ); + vc_23 = (float64x2_t)vdupq_n_f64( 0 ); + + PRAGMA_UNROLL + for ( ; k_mker > 0; --k_mker ) + { + // if ( m0 > 0 ) + va_0 = vld1q_f64( a_loc + rs_a * 0 ); + if ( m0 > 1 ) va_1 = vld1q_f64( a_loc + rs_a * 1 ); + if ( m0 > 2 ) va_2 = vld1q_f64( a_loc + rs_a * 2 ); + // if ( n0 > 0 ) + vb_0 = vld1q_f64( b_loc + cs_b * 0 ); + if ( n0 > 1 ) vb_1 = vld1q_f64( b_loc + cs_b * 1 ); + if ( n0 > 2 ) vb_2 = vld1q_f64( b_loc + cs_b * 2 ); + if ( n0 > 3 ) vb_3 = vld1q_f64( b_loc + cs_b * 3 ); + a_loc += 2; + b_loc += 2; + + // 1-column case. + if ( n0 == 1 ) { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); + vc_20 = vfmaq_f64( vc_20, va_2, vb_0 ); + continue; + } + + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + if ( m0 > 1 ) + { + vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); + vc_11 = vfmaq_f64( vc_11, va_1, vb_1 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_2 ); + vc_13 = vfmaq_f64( vc_13, va_1, vb_3 ); + } + if ( m0 > 2 ) { + vc_20 = vfmaq_f64( vc_20, va_2, vb_0 ); + vc_21 = vfmaq_f64( vc_21, va_2, vb_1 ); + vc_22 = vfmaq_f64( vc_22, va_2, vb_2 ); + vc_23 = vfmaq_f64( vc_23, va_2, vb_3 ); + } + } + + // Pay no care for O(1) details. + va_0 = (float64x2_t)vdupq_n_f64( 0 ); + va_1 = (float64x2_t)vdupq_n_f64( 0 ); + va_2 = (float64x2_t)vdupq_n_f64( 0 ); + vb_0 = (float64x2_t)vdupq_n_f64( 0 ); + vb_1 = (float64x2_t)vdupq_n_f64( 0 ); + vb_2 = (float64x2_t)vdupq_n_f64( 0 ); + vb_3 = (float64x2_t)vdupq_n_f64( 0 ); + PRAGMA_NOUNROLL + for ( ; k_left > 0; --k_left ) + { + // if ( m0 > 0 ) + va_0 = vld1q_lane_f64( a_loc + rs_a * 0, va_0, 0 ); + if ( m0 > 1 ) va_1 = vld1q_lane_f64( a_loc + rs_a * 1, va_1, 0 ); + if ( m0 > 2 ) va_2 = vld1q_lane_f64( a_loc + rs_a * 2, va_2, 0 ); + // if ( n0 > 0 ) + vb_0 = vld1q_lane_f64( b_loc + cs_b * 0, vb_0, 0 ); + if ( n0 > 1 ) vb_1 = vld1q_lane_f64( b_loc + cs_b * 1, vb_1, 0 ); + if ( n0 > 2 ) vb_2 = vld1q_lane_f64( b_loc + cs_b * 2, vb_2, 0 ); + if ( n0 > 3 ) vb_3 = vld1q_lane_f64( b_loc + cs_b * 3, vb_3, 0 ); + + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); + vc_11 = vfmaq_f64( vc_11, va_1, vb_1 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_2 ); + vc_13 = vfmaq_f64( vc_13, va_1, vb_3 ); + vc_20 = vfmaq_f64( vc_20, va_2, vb_0 ); + vc_21 = vfmaq_f64( vc_21, va_2, vb_1 ); + vc_22 = vfmaq_f64( vc_22, va_2, vb_2 ); + vc_23 = vfmaq_f64( vc_23, va_2, vb_3 ); + + a_loc += 1; + b_loc += 1; + } + + // Reduce. + vc_00 = vpaddq_f64( vc_00, vc_01 ); + vc_02 = vpaddq_f64( vc_02, vc_03 ); + vc_10 = vpaddq_f64( vc_10, vc_11 ); + vc_12 = vpaddq_f64( vc_12, vc_13 ); + vc_20 = vpaddq_f64( vc_20, vc_21 ); + vc_22 = vpaddq_f64( vc_22, vc_23 ); + + // Load alpha and beta. + va_0 = vld1q_dup_f64( alpha ); + vb_0 = vld1q_dup_f64( beta ); + + // Scale. + vc_00 = vmulq_f64( vc_00, va_0 ); + vc_02 = vmulq_f64( vc_02, va_0 ); + vc_10 = vmulq_f64( vc_10, va_0 ); + vc_12 = vmulq_f64( vc_12, va_0 ); + vc_20 = vmulq_f64( vc_20, va_0 ); + vc_22 = vmulq_f64( vc_22, va_0 ); + + if ( cs_c == 1 ) + { + // Row-storage. + // if ( m0 > 0 ) + { + if ( n0 > 1 ) va_0 = vld1q_f64 ( c_loc + 0 * rs_c + 0 ); + else if ( n0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 0 * rs_c + 0, va_0, 0 ); + if ( n0 > 3 ) va_1 = vld1q_f64 ( c_loc + 0 * rs_c + 2 ); + else if ( n0 > 2 ) va_1 = vld1q_lane_f64( c_loc + 0 * rs_c + 2, va_1, 0 ); + + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_02 = vfmaq_f64( vc_02, va_1, vb_0 ); + } + + if ( n0 > 1 ) vst1q_f64 ( c_loc + 0 * rs_c + 0, vc_00 ); + else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 0 * rs_c + 0, vc_00, 0 ); + if ( n0 > 3 ) vst1q_f64 ( c_loc + 0 * rs_c + 2, vc_02 ); + else if ( n0 > 2 ) vst1q_lane_f64( c_loc + 0 * rs_c + 2, vc_02, 0 ); + } + if ( m0 > 1 ) + { + if ( n0 > 1 ) va_0 = vld1q_f64 ( c_loc + 1 * rs_c + 0 ); + else if ( n0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 1 * rs_c + 0, va_0, 0 ); + if ( n0 > 3 ) va_1 = vld1q_f64 ( c_loc + 1 * rs_c + 2 ); + else if ( n0 > 2 ) va_1 = vld1q_lane_f64( c_loc + 1 * rs_c + 2, va_1, 0 ); + + if ( !b_iszr ) + { + vc_10 = vfmaq_f64( vc_10, va_0, vb_0 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_0 ); + } + + if ( n0 > 1 ) vst1q_f64 ( c_loc + 1 * rs_c + 0, vc_10 ); + else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 1 * rs_c + 0, vc_10, 0 ); + if ( n0 > 3 ) vst1q_f64 ( c_loc + 1 * rs_c + 2, vc_12 ); + else if ( n0 > 2 ) vst1q_lane_f64( c_loc + 1 * rs_c + 2, vc_12, 0 ); + } + if ( m0 > 2 ) + { + if ( n0 > 1 ) va_0 = vld1q_f64 ( c_loc + 2 * rs_c + 0 ); + else if ( n0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 2 * rs_c + 0, va_0, 0 ); + if ( n0 > 3 ) va_1 = vld1q_f64 ( c_loc + 2 * rs_c + 2 ); + else if ( n0 > 2 ) va_1 = vld1q_lane_f64( c_loc + 2 * rs_c + 2, va_1, 0 ); + + if ( !b_iszr ) + { + vc_20 = vfmaq_f64( vc_20, va_0, vb_0 ); + vc_22 = vfmaq_f64( vc_22, va_1, vb_0 ); + } + + if ( n0 > 1 ) vst1q_f64 ( c_loc + 2 * rs_c + 0, vc_20 ); + else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 2 * rs_c + 0, vc_20, 0 ); + if ( n0 > 3 ) vst1q_f64 ( c_loc + 2 * rs_c + 2, vc_22 ); + else if ( n0 > 2 ) vst1q_lane_f64( c_loc + 2 * rs_c + 2, vc_22, 0 ); + } + } + else + { + // Column-storage. + if ( m0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 0 + 0 * cs_c, va_0, 0 ); + if ( m0 > 1 ) va_1 = vld1q_lane_f64( c_loc + 1 + 0 * cs_c, va_1, 0 ); + if ( m0 > 2 ) va_2 = vld1q_lane_f64( c_loc + 2 + 0 * cs_c, va_2, 0 ); + if ( n0 > 1 ) + { + if ( m0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 0 + 1 * cs_c, va_0, 1 ); + if ( m0 > 1 ) va_1 = vld1q_lane_f64( c_loc + 1 + 1 * cs_c, va_1, 1 ); + if ( m0 > 2 ) va_2 = vld1q_lane_f64( c_loc + 2 + 1 * cs_c, va_2, 1 ); + } + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); + vc_20 = vfmaq_f64( vc_20, va_2, vb_0 ); + } + if ( m0 > 0 ) vst1q_lane_f64( c_loc + 0 + 0 * cs_c, vc_00, 0 ); + if ( m0 > 1 ) vst1q_lane_f64( c_loc + 1 + 0 * cs_c, vc_10, 0 ); + if ( m0 > 2 ) vst1q_lane_f64( c_loc + 2 + 0 * cs_c, vc_20, 0 ); + if ( n0 > 1 ) + { + if ( m0 > 0 ) vst1q_lane_f64( c_loc + 0 + 1 * cs_c, vc_00, 1 ); + if ( m0 > 1 ) vst1q_lane_f64( c_loc + 1 + 1 * cs_c, vc_10, 1 ); + if ( m0 > 2 ) vst1q_lane_f64( c_loc + 2 + 1 * cs_c, vc_20, 1 ); + } + + if ( n0 > 2 ) + { + if ( m0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 0 + 2 * cs_c, va_0, 0 ); + if ( m0 > 1 ) va_1 = vld1q_lane_f64( c_loc + 1 + 2 * cs_c, va_1, 0 ); + if ( m0 > 2 ) va_2 = vld1q_lane_f64( c_loc + 2 + 2 * cs_c, va_2, 0 ); + } + if ( n0 > 3 ) + { + if ( m0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 0 + 3 * cs_c, va_0, 1 ); + if ( m0 > 1 ) va_1 = vld1q_lane_f64( c_loc + 1 + 3 * cs_c, va_1, 1 ); + if ( m0 > 2 ) va_2 = vld1q_lane_f64( c_loc + 2 + 3 * cs_c, va_2, 1 ); + } + if ( !b_iszr ) + { + vc_02 = vfmaq_f64( vc_02, va_0, vb_0 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_0 ); + vc_22 = vfmaq_f64( vc_22, va_2, vb_0 ); + } + if ( n0 > 2 ) + { + if ( m0 > 0 ) vst1q_lane_f64( c_loc + 0 + 2 * cs_c, vc_02, 0 ); + if ( m0 > 1 ) vst1q_lane_f64( c_loc + 1 + 2 * cs_c, vc_12, 0 ); + if ( m0 > 2 ) vst1q_lane_f64( c_loc + 2 + 2 * cs_c, vc_22, 0 ); + } + if ( n0 > 3 ) + { + if ( m0 > 0 ) vst1q_lane_f64( c_loc + 0 + 3 * cs_c, vc_02, 1 ); + if ( m0 > 1 ) vst1q_lane_f64( c_loc + 1 + 3 * cs_c, vc_12, 1 ); + if ( m0 > 2 ) vst1q_lane_f64( c_loc + 2 + 3 * cs_c, vc_22, 1 ); + } + } + +} + diff --git a/kernels/armv8a/3/sup/d6x4/bli_gemmsup_rv_armv8a_int_d3x8mn.c b/kernels/armv8a/3/sup/d6x4/bli_gemmsup_rv_armv8a_int_d3x8mn.c new file mode 100644 index 0000000000..16af42ade6 --- /dev/null +++ b/kernels/armv8a/3/sup/d6x4/bli_gemmsup_rv_armv8a_int_d3x8mn.c @@ -0,0 +1,393 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Supplimentary dynamic-size gemmsup. + +#include "blis.h" +#include "assert.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL _Pragma("unroll") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL +#endif + +/* + * As these kernels requires num. of vregs about half of the total 32, + * it should be all right to implement w/ intrinsics. + * + * c.f. https://www.youtube.com/watch?v=R2hQOVjRwVE . + */ +void bli_dgemmsup_rv_armv8a_int_3x8mn + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a0, inc_t rs_a, inc_t cs_a, + double* restrict b0, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c0, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Unlike the rd case, this rv case does not impose restriction upon + // maximal m & n. + + double *a_loc; + double *b_loc, *b_in; + double *c_loc, *c_in; + + dim_t n; + dim_t k; + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t b_iszr = ( *beta == 0.0 ); + assert( cs_b == 1 ); + + // Registers used to store a 3x8 block of C. + float64x2_t vc_00, vc_01, vc_02, vc_03; + float64x2_t vc_10, vc_11, vc_12, vc_13; + float64x2_t vc_20, vc_21, vc_22, vc_23; + float64x2_t va_0, va_1; + float64x2_t vb_0, vb_1, vb_2, vb_3; + + PRAGMA_NOUNROLL + for ( ; m0 > 0; m0 -= 3 ) + { + n = n0; + b_in = b0; + c_in = c0; + + PRAGMA_NOUNROLL + for ( ; n > 0; n -= 8 ) + { + a_loc = a0; + b_loc = b_in; + c_loc = c_in; + k = k0; + + vc_00 = (float64x2_t)vdupq_n_f64( 0 ); + vc_01 = (float64x2_t)vdupq_n_f64( 0 ); + vc_02 = (float64x2_t)vdupq_n_f64( 0 ); + vc_03 = (float64x2_t)vdupq_n_f64( 0 ); + vc_10 = (float64x2_t)vdupq_n_f64( 0 ); + vc_11 = (float64x2_t)vdupq_n_f64( 0 ); + vc_12 = (float64x2_t)vdupq_n_f64( 0 ); + vc_13 = (float64x2_t)vdupq_n_f64( 0 ); + vc_20 = (float64x2_t)vdupq_n_f64( 0 ); + vc_21 = (float64x2_t)vdupq_n_f64( 0 ); + vc_22 = (float64x2_t)vdupq_n_f64( 0 ); + vc_23 = (float64x2_t)vdupq_n_f64( 0 ); + + PRAGMA_UNROLL + for ( ; k > 0; --k ) + { + // A columns. + // if ( m0 > 0 ) + va_0 = vld1q_lane_f64( a_loc + rs_a * 0, va_0, 0 ); + if ( m0 > 1 ) va_0 = vld1q_lane_f64( a_loc + rs_a * 1, va_0, 1 ); + if ( m0 > 2 ) va_1 = vld1q_lane_f64( a_loc + rs_a * 2, va_1, 0 ); + // B rows. + if ( n > 1 ) vb_0 = vld1q_f64 ( b_loc + 0 ); + else vb_0 = vld1q_lane_f64( b_loc + 0, vb_0, 0 ); + if ( n > 3 ) vb_1 = vld1q_f64 ( b_loc + 2 ); + else if ( n > 2 ) vb_1 = vld1q_lane_f64( b_loc + 2, vb_1, 0 ); + if ( n > 5 ) vb_2 = vld1q_f64 ( b_loc + 4 ); + else if ( n > 4 ) vb_2 = vld1q_lane_f64( b_loc + 4, vb_2, 0 ); + if ( n > 7 ) vb_3 = vld1q_f64 ( b_loc + 6 ); + else if ( n > 6 ) vb_3 = vld1q_lane_f64( b_loc + 6, vb_3, 0 ); + a_loc += cs_a; + b_loc += rs_b; + + // if ( m0 > 0 ) + { + vc_00 = vfmaq_laneq_f64( vc_00, vb_0, va_0, 0 ); + vc_01 = vfmaq_laneq_f64( vc_01, vb_1, va_0, 0 ); + vc_02 = vfmaq_laneq_f64( vc_02, vb_2, va_0, 0 ); + vc_03 = vfmaq_laneq_f64( vc_03, vb_3, va_0, 0 ); + } + if ( m0 > 1 ) + { + vc_10 = vfmaq_laneq_f64( vc_10, vb_0, va_0, 1 ); + vc_11 = vfmaq_laneq_f64( vc_11, vb_1, va_0, 1 ); + vc_12 = vfmaq_laneq_f64( vc_12, vb_2, va_0, 1 ); + vc_13 = vfmaq_laneq_f64( vc_13, vb_3, va_0, 1 ); + } + if ( m0 > 2 ) + { + vc_20 = vfmaq_laneq_f64( vc_20, vb_0, va_1, 0 ); + vc_21 = vfmaq_laneq_f64( vc_21, vb_1, va_1, 0 ); + vc_22 = vfmaq_laneq_f64( vc_22, vb_2, va_1, 0 ); + vc_23 = vfmaq_laneq_f64( vc_23, vb_3, va_1, 0 ); + } + } + + // Load alpha and beta. + // Note that here vb is used for alpha, in contrast to other kernels. + vb_0 = vld1q_dup_f64( alpha ); + va_0 = vld1q_dup_f64( beta ); + + // Scale. + vc_00 = vmulq_f64( vc_00, vb_0 ); + vc_01 = vmulq_f64( vc_01, vb_0 ); + vc_02 = vmulq_f64( vc_02, vb_0 ); + vc_03 = vmulq_f64( vc_03, vb_0 ); + vc_10 = vmulq_f64( vc_10, vb_0 ); + vc_11 = vmulq_f64( vc_11, vb_0 ); + vc_12 = vmulq_f64( vc_12, vb_0 ); + vc_13 = vmulq_f64( vc_13, vb_0 ); + vc_20 = vmulq_f64( vc_20, vb_0 ); + vc_21 = vmulq_f64( vc_21, vb_0 ); + vc_22 = vmulq_f64( vc_22, vb_0 ); + vc_23 = vmulq_f64( vc_23, vb_0 ); + + if ( cs_c == 1 ) + { + // Store in rows. + // + // if ( m0 > 0 ) + { + // Load. + if ( n > 1 ) vb_0 = vld1q_f64 ( c_loc + 0 * rs_c + 0 ); + else vb_0 = vld1q_lane_f64( c_loc + 0 * rs_c + 0, vb_0, 0 ); + if ( n > 3 ) vb_1 = vld1q_f64 ( c_loc + 0 * rs_c + 2 ); + else if ( n > 2 ) vb_1 = vld1q_lane_f64( c_loc + 0 * rs_c + 2, vb_1, 0 ); + if ( n > 5 ) vb_2 = vld1q_f64 ( c_loc + 0 * rs_c + 4 ); + else if ( n > 4 ) vb_2 = vld1q_lane_f64( c_loc + 0 * rs_c + 4, vb_2, 0 ); + if ( n > 7 ) vb_3 = vld1q_f64 ( c_loc + 0 * rs_c + 6 ); + else if ( n > 6 ) vb_3 = vld1q_lane_f64( c_loc + 0 * rs_c + 6, vb_3, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, vb_0, va_0 ); + vc_01 = vfmaq_f64( vc_01, vb_1, va_0 ); + vc_02 = vfmaq_f64( vc_02, vb_2, va_0 ); + vc_03 = vfmaq_f64( vc_03, vb_3, va_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 0 * rs_c + 0, vc_00 ); + else vst1q_lane_f64( c_loc + 0 * rs_c + 0, vc_00, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 0 * rs_c + 2, vc_01 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 0 * rs_c + 2, vc_01, 0 ); + if ( n > 5 ) vst1q_f64 ( c_loc + 0 * rs_c + 4, vc_02 ); + else if ( n > 4 ) vst1q_lane_f64( c_loc + 0 * rs_c + 4, vc_02, 0 ); + if ( n > 7 ) vst1q_f64 ( c_loc + 0 * rs_c + 6, vc_03 ); + else if ( n > 6 ) vst1q_lane_f64( c_loc + 0 * rs_c + 6, vc_03, 0 ); + } + if ( m0 > 1 ) + { + // Load. + if ( n > 1 ) vb_0 = vld1q_f64 ( c_loc + 1 * rs_c + 0 ); + else vb_0 = vld1q_lane_f64( c_loc + 1 * rs_c + 0, vb_0, 0 ); + if ( n > 3 ) vb_1 = vld1q_f64 ( c_loc + 1 * rs_c + 2 ); + else if ( n > 2 ) vb_1 = vld1q_lane_f64( c_loc + 1 * rs_c + 2, vb_1, 0 ); + if ( n > 5 ) vb_2 = vld1q_f64 ( c_loc + 1 * rs_c + 4 ); + else if ( n > 4 ) vb_2 = vld1q_lane_f64( c_loc + 1 * rs_c + 4, vb_2, 0 ); + if ( n > 7 ) vb_3 = vld1q_f64 ( c_loc + 1 * rs_c + 6 ); + else if ( n > 6 ) vb_3 = vld1q_lane_f64( c_loc + 1 * rs_c + 6, vb_3, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_10 = vfmaq_f64( vc_10, vb_0, va_0 ); + vc_11 = vfmaq_f64( vc_11, vb_1, va_0 ); + vc_12 = vfmaq_f64( vc_12, vb_2, va_0 ); + vc_13 = vfmaq_f64( vc_13, vb_3, va_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 1 * rs_c + 0, vc_10 ); + else vst1q_lane_f64( c_loc + 1 * rs_c + 0, vc_10, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 1 * rs_c + 2, vc_11 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 1 * rs_c + 2, vc_11, 0 ); + if ( n > 5 ) vst1q_f64 ( c_loc + 1 * rs_c + 4, vc_12 ); + else if ( n > 4 ) vst1q_lane_f64( c_loc + 1 * rs_c + 4, vc_12, 0 ); + if ( n > 7 ) vst1q_f64 ( c_loc + 1 * rs_c + 6, vc_13 ); + else if ( n > 6 ) vst1q_lane_f64( c_loc + 1 * rs_c + 6, vc_13, 0 ); + } + if ( m0 > 2 ) + { + // Load. + if ( n > 1 ) vb_0 = vld1q_f64 ( c_loc + 2 * rs_c + 0 ); + else vb_0 = vld1q_lane_f64( c_loc + 2 * rs_c + 0, vb_0, 0 ); + if ( n > 3 ) vb_1 = vld1q_f64 ( c_loc + 2 * rs_c + 2 ); + else if ( n > 2 ) vb_1 = vld1q_lane_f64( c_loc + 2 * rs_c + 2, vb_1, 0 ); + if ( n > 5 ) vb_2 = vld1q_f64 ( c_loc + 2 * rs_c + 4 ); + else if ( n > 4 ) vb_2 = vld1q_lane_f64( c_loc + 2 * rs_c + 4, vb_2, 0 ); + if ( n > 7 ) vb_3 = vld1q_f64 ( c_loc + 2 * rs_c + 6 ); + else if ( n > 6 ) vb_3 = vld1q_lane_f64( c_loc + 2 * rs_c + 6, vb_3, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_20 = vfmaq_f64( vc_20, vb_0, va_0 ); + vc_21 = vfmaq_f64( vc_21, vb_1, va_0 ); + vc_22 = vfmaq_f64( vc_22, vb_2, va_0 ); + vc_23 = vfmaq_f64( vc_23, vb_3, va_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 2 * rs_c + 0, vc_20 ); + else vst1q_lane_f64( c_loc + 2 * rs_c + 0, vc_20, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 2 * rs_c + 2, vc_21 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 2 * rs_c + 2, vc_21, 0 ); + if ( n > 5 ) vst1q_f64 ( c_loc + 2 * rs_c + 4, vc_22 ); + else if ( n > 4 ) vst1q_lane_f64( c_loc + 2 * rs_c + 4, vc_22, 0 ); + if ( n > 7 ) vst1q_f64 ( c_loc + 2 * rs_c + 6, vc_23 ); + else if ( n > 6 ) vst1q_lane_f64( c_loc + 2 * rs_c + 6, vc_23, 0 ); + } + } + else + { + // Store in columns. + // No in-reg transpose here. + // + // if ( m0 > 0 ) + { + // Load. + if ( n > 0 ) vb_0 = vld1q_lane_f64( c_loc + 0 + 0 * cs_c, vb_0, 0 ); + if ( n > 1 ) vb_0 = vld1q_lane_f64( c_loc + 0 + 1 * cs_c, vb_0, 1 ); + if ( n > 2 ) vb_1 = vld1q_lane_f64( c_loc + 0 + 2 * cs_c, vb_1, 0 ); + if ( n > 3 ) vb_1 = vld1q_lane_f64( c_loc + 0 + 3 * cs_c, vb_1, 1 ); + if ( n > 4 ) vb_2 = vld1q_lane_f64( c_loc + 0 + 4 * cs_c, vb_2, 0 ); + if ( n > 5 ) vb_2 = vld1q_lane_f64( c_loc + 0 + 5 * cs_c, vb_2, 1 ); + if ( n > 6 ) vb_3 = vld1q_lane_f64( c_loc + 0 + 6 * cs_c, vb_3, 0 ); + if ( n > 7 ) vb_3 = vld1q_lane_f64( c_loc + 0 + 7 * cs_c, vb_3, 1 ); + + // Scale. + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, vb_0, va_0 ); + vc_01 = vfmaq_f64( vc_01, vb_1, va_0 ); + vc_02 = vfmaq_f64( vc_02, vb_2, va_0 ); + vc_03 = vfmaq_f64( vc_03, vb_3, va_0 ); + } + + // Store. + if ( n > 0 ) vst1q_lane_f64( c_loc + 0 + 0 * cs_c, vc_00, 0 ); + if ( n > 1 ) vst1q_lane_f64( c_loc + 0 + 1 * cs_c, vc_00, 1 ); + if ( n > 2 ) vst1q_lane_f64( c_loc + 0 + 2 * cs_c, vc_01, 0 ); + if ( n > 3 ) vst1q_lane_f64( c_loc + 0 + 3 * cs_c, vc_01, 1 ); + if ( n > 4 ) vst1q_lane_f64( c_loc + 0 + 4 * cs_c, vc_02, 0 ); + if ( n > 5 ) vst1q_lane_f64( c_loc + 0 + 5 * cs_c, vc_02, 1 ); + if ( n > 6 ) vst1q_lane_f64( c_loc + 0 + 6 * cs_c, vc_03, 0 ); + if ( n > 7 ) vst1q_lane_f64( c_loc + 0 + 7 * cs_c, vc_03, 1 ); + } + if ( m0 > 1 ) + { + // Load. + if ( n > 0 ) vb_0 = vld1q_lane_f64( c_loc + 1 + 0 * cs_c, vb_0, 0 ); + if ( n > 1 ) vb_0 = vld1q_lane_f64( c_loc + 1 + 1 * cs_c, vb_0, 1 ); + if ( n > 2 ) vb_1 = vld1q_lane_f64( c_loc + 1 + 2 * cs_c, vb_1, 0 ); + if ( n > 3 ) vb_1 = vld1q_lane_f64( c_loc + 1 + 3 * cs_c, vb_1, 1 ); + if ( n > 4 ) vb_2 = vld1q_lane_f64( c_loc + 1 + 4 * cs_c, vb_2, 0 ); + if ( n > 5 ) vb_2 = vld1q_lane_f64( c_loc + 1 + 5 * cs_c, vb_2, 1 ); + if ( n > 6 ) vb_3 = vld1q_lane_f64( c_loc + 1 + 6 * cs_c, vb_3, 0 ); + if ( n > 7 ) vb_3 = vld1q_lane_f64( c_loc + 1 + 7 * cs_c, vb_3, 1 ); + + // Scale. + if ( !b_iszr ) + { + vc_10 = vfmaq_f64( vc_10, vb_0, va_0 ); + vc_11 = vfmaq_f64( vc_11, vb_1, va_0 ); + vc_12 = vfmaq_f64( vc_12, vb_2, va_0 ); + vc_13 = vfmaq_f64( vc_13, vb_3, va_0 ); + } + + // Store. + if ( n > 0 ) vst1q_lane_f64( c_loc + 1 + 0 * cs_c, vc_10, 0 ); + if ( n > 1 ) vst1q_lane_f64( c_loc + 1 + 1 * cs_c, vc_10, 1 ); + if ( n > 2 ) vst1q_lane_f64( c_loc + 1 + 2 * cs_c, vc_11, 0 ); + if ( n > 3 ) vst1q_lane_f64( c_loc + 1 + 3 * cs_c, vc_11, 1 ); + if ( n > 4 ) vst1q_lane_f64( c_loc + 1 + 4 * cs_c, vc_12, 0 ); + if ( n > 5 ) vst1q_lane_f64( c_loc + 1 + 5 * cs_c, vc_12, 1 ); + if ( n > 6 ) vst1q_lane_f64( c_loc + 1 + 6 * cs_c, vc_13, 0 ); + if ( n > 7 ) vst1q_lane_f64( c_loc + 1 + 7 * cs_c, vc_13, 1 ); + } + if ( m0 > 2 ) + { + // Load. + if ( n > 0 ) vb_0 = vld1q_lane_f64( c_loc + 2 + 0 * cs_c, vb_0, 0 ); + if ( n > 1 ) vb_0 = vld1q_lane_f64( c_loc + 2 + 1 * cs_c, vb_0, 1 ); + if ( n > 2 ) vb_1 = vld1q_lane_f64( c_loc + 2 + 2 * cs_c, vb_1, 0 ); + if ( n > 3 ) vb_1 = vld1q_lane_f64( c_loc + 2 + 3 * cs_c, vb_1, 1 ); + if ( n > 4 ) vb_2 = vld1q_lane_f64( c_loc + 2 + 4 * cs_c, vb_2, 0 ); + if ( n > 5 ) vb_2 = vld1q_lane_f64( c_loc + 2 + 5 * cs_c, vb_2, 1 ); + if ( n > 6 ) vb_3 = vld1q_lane_f64( c_loc + 2 + 6 * cs_c, vb_3, 0 ); + if ( n > 7 ) vb_3 = vld1q_lane_f64( c_loc + 2 + 7 * cs_c, vb_3, 1 ); + + // Scale. + if ( !b_iszr ) + { + vc_20 = vfmaq_f64( vc_20, vb_0, va_0 ); + vc_21 = vfmaq_f64( vc_21, vb_1, va_0 ); + vc_22 = vfmaq_f64( vc_22, vb_2, va_0 ); + vc_23 = vfmaq_f64( vc_23, vb_3, va_0 ); + } + + // Store. + if ( n > 0 ) vst1q_lane_f64( c_loc + 2 + 0 * cs_c, vc_20, 0 ); + if ( n > 1 ) vst1q_lane_f64( c_loc + 2 + 1 * cs_c, vc_20, 1 ); + if ( n > 2 ) vst1q_lane_f64( c_loc + 2 + 2 * cs_c, vc_21, 0 ); + if ( n > 3 ) vst1q_lane_f64( c_loc + 2 + 3 * cs_c, vc_21, 1 ); + if ( n > 4 ) vst1q_lane_f64( c_loc + 2 + 4 * cs_c, vc_22, 0 ); + if ( n > 5 ) vst1q_lane_f64( c_loc + 2 + 5 * cs_c, vc_22, 1 ); + if ( n > 6 ) vst1q_lane_f64( c_loc + 2 + 6 * cs_c, vc_23, 0 ); + if ( n > 7 ) vst1q_lane_f64( c_loc + 2 + 7 * cs_c, vc_23, 1 ); + } + } + + b_in += ps_b; + c_in += 8 * cs_c; + } + + a0 += ps_a; + c0 += 3 * rs_c; + } +} + diff --git a/kernels/armv8a/3/sup/d6x4/bli_gemmsup_rv_armv8a_int_d6x4mn.c b/kernels/armv8a/3/sup/d6x4/bli_gemmsup_rv_armv8a_int_d6x4mn.c new file mode 100644 index 0000000000..8bbd87f1f6 --- /dev/null +++ b/kernels/armv8a/3/sup/d6x4/bli_gemmsup_rv_armv8a_int_d6x4mn.c @@ -0,0 +1,481 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Supplimentary dynamic-size gemmsup. + +#include "blis.h" +#include "assert.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL _Pragma("unroll") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL +#endif + +/* + * As these kernels requires num. of vregs about half of the total 32, + * it should be all right to implement w/ intrinsics. + * + * c.f. https://www.youtube.com/watch?v=R2hQOVjRwVE . + */ +void bli_dgemmsup_rv_armv8a_int_6x4mn + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a0, inc_t rs_a, inc_t cs_a, + double* restrict b0, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c0, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Unlike the rd case, this rv case does not impose restriction upon + // maximal m & n. + + double *a_loc; + double *b_loc, *b_in; + double *c_loc, *c_in; + + dim_t n; + dim_t k; + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t b_iszr = ( *beta == 0.0 ); + assert( cs_b == 1 ); + + // Registers used to store a 6x4 block of C. + float64x2_t vc_00, vc_01; + float64x2_t vc_10, vc_11; + float64x2_t vc_20, vc_21; + float64x2_t vc_30, vc_31; + float64x2_t vc_40, vc_41; + float64x2_t vc_50, vc_51; + float64x2_t va_0, va_1, va_2; + float64x2_t vb_0, vb_1; + + PRAGMA_NOUNROLL + for ( ; m0 > 0; m0 -= 6 ) + { + n = n0; + b_in = b0; + c_in = c0; + + PRAGMA_NOUNROLL + for ( ; n > 0; n -= 4 ) + { + a_loc = a0; + b_loc = b_in; + c_loc = c_in; + k = k0; + + vc_00 = (float64x2_t)vdupq_n_f64( 0 ); vc_01 = (float64x2_t)vdupq_n_f64( 0 ); + vc_10 = (float64x2_t)vdupq_n_f64( 0 ); vc_11 = (float64x2_t)vdupq_n_f64( 0 ); + vc_20 = (float64x2_t)vdupq_n_f64( 0 ); vc_21 = (float64x2_t)vdupq_n_f64( 0 ); + vc_30 = (float64x2_t)vdupq_n_f64( 0 ); vc_31 = (float64x2_t)vdupq_n_f64( 0 ); + vc_40 = (float64x2_t)vdupq_n_f64( 0 ); vc_41 = (float64x2_t)vdupq_n_f64( 0 ); + vc_50 = (float64x2_t)vdupq_n_f64( 0 ); vc_51 = (float64x2_t)vdupq_n_f64( 0 ); + + PRAGMA_UNROLL + for ( ; k > 0; --k ) + { + // A columns. + // if ( m0 > 0 ) + va_0 = vld1q_lane_f64( a_loc + rs_a * 0, va_0, 0 ); + if ( m0 > 1 ) va_0 = vld1q_lane_f64( a_loc + rs_a * 1, va_0, 1 ); + if ( m0 > 2 ) va_1 = vld1q_lane_f64( a_loc + rs_a * 2, va_1, 0 ); + if ( m0 > 3 ) va_1 = vld1q_lane_f64( a_loc + rs_a * 3, va_1, 1 ); + if ( m0 > 4 ) va_2 = vld1q_lane_f64( a_loc + rs_a * 4, va_2, 0 ); + if ( m0 > 5 ) va_2 = vld1q_lane_f64( a_loc + rs_a * 5, va_2, 1 ); + // B rows. + if ( n > 1 ) vb_0 = vld1q_f64 ( b_loc + 0 ); + else vb_0 = vld1q_lane_f64( b_loc + 0, vb_0, 0 ); + if ( n > 3 ) vb_1 = vld1q_f64 ( b_loc + 2 ); + else if ( n > 2 ) vb_1 = vld1q_lane_f64( b_loc + 2, vb_1, 0 ); + a_loc += cs_a; + b_loc += rs_b; + + // One or two-column case. + if ( n <= 2 ) + { + // if ( m0 > 0 ) + { + vc_00 = vfmaq_laneq_f64( vc_00, vb_0, va_0, 0 ); + vc_10 = vfmaq_laneq_f64( vc_10, vb_0, va_0, 1 ); + vc_20 = vfmaq_laneq_f64( vc_20, vb_0, va_1, 0 ); + } + if ( m0 > 3 ) + { + vc_30 = vfmaq_laneq_f64( vc_30, vb_0, va_1, 1 ); + vc_40 = vfmaq_laneq_f64( vc_40, vb_0, va_2, 0 ); + vc_50 = vfmaq_laneq_f64( vc_50, vb_0, va_2, 1 ); + } + continue; + } + + // Three or four-column case. Moderately decrease num. of FMLA instructions + // according to m and n. + // if ( m0 > 0 ) + { + vc_00 = vfmaq_laneq_f64( vc_00, vb_0, va_0, 0 ); + vc_01 = vfmaq_laneq_f64( vc_01, vb_1, va_0, 0 ); + vc_10 = vfmaq_laneq_f64( vc_10, vb_0, va_0, 1 ); + vc_11 = vfmaq_laneq_f64( vc_11, vb_1, va_0, 1 ); + } + if ( m0 > 2 ) + { + vc_20 = vfmaq_laneq_f64( vc_20, vb_0, va_1, 0 ); + vc_21 = vfmaq_laneq_f64( vc_21, vb_1, va_1, 0 ); + vc_30 = vfmaq_laneq_f64( vc_30, vb_0, va_1, 1 ); + vc_31 = vfmaq_laneq_f64( vc_31, vb_1, va_1, 1 ); + } + if ( m0 > 4 ) + { + vc_40 = vfmaq_laneq_f64( vc_40, vb_0, va_2, 0 ); + vc_41 = vfmaq_laneq_f64( vc_41, vb_1, va_2, 0 ); + vc_50 = vfmaq_laneq_f64( vc_50, vb_0, va_2, 1 ); + vc_51 = vfmaq_laneq_f64( vc_51, vb_1, va_2, 1 ); + } + } + + // Load alpha and beta. + va_0 = vld1q_dup_f64( alpha ); + vb_0 = vld1q_dup_f64( beta ); + + // Scale. + vc_00 = vmulq_f64( vc_00, va_0 ); vc_01 = vmulq_f64( vc_01, va_0 ); + vc_10 = vmulq_f64( vc_10, va_0 ); vc_11 = vmulq_f64( vc_11, va_0 ); + vc_20 = vmulq_f64( vc_20, va_0 ); vc_21 = vmulq_f64( vc_21, va_0 ); + vc_30 = vmulq_f64( vc_30, va_0 ); vc_31 = vmulq_f64( vc_31, va_0 ); + vc_40 = vmulq_f64( vc_40, va_0 ); vc_41 = vmulq_f64( vc_41, va_0 ); + vc_50 = vmulq_f64( vc_50, va_0 ); vc_51 = vmulq_f64( vc_51, va_0 ); + + if ( cs_c == 1 ) + { + // Store in rows. + // if ( m0 > 0 ) + { + // Load. + if ( n > 1 ) va_0 = vld1q_f64 ( c_loc + 0 * rs_c + 0 ); + else va_0 = vld1q_lane_f64( c_loc + 0 * rs_c + 0, va_0, 0 ); + if ( n > 3 ) va_1 = vld1q_f64 ( c_loc + 0 * rs_c + 2 ); + else if ( n > 2 ) va_1 = vld1q_lane_f64( c_loc + 0 * rs_c + 2, va_1, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_1, vb_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 0 * rs_c + 0, vc_00 ); + else vst1q_lane_f64( c_loc + 0 * rs_c + 0, vc_00, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 0 * rs_c + 2, vc_01 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 0 * rs_c + 2, vc_01, 0 ); + } + if ( m0 > 1 ) + { + // Load. + if ( n > 1 ) va_0 = vld1q_f64 ( c_loc + 1 * rs_c + 0 ); + else va_0 = vld1q_lane_f64( c_loc + 1 * rs_c + 0, va_0, 0 ); + if ( n > 3 ) va_1 = vld1q_f64 ( c_loc + 1 * rs_c + 2 ); + else if ( n > 2 ) va_1 = vld1q_lane_f64( c_loc + 1 * rs_c + 2, va_1, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_10 = vfmaq_f64( vc_10, va_0, vb_0 ); + vc_11 = vfmaq_f64( vc_11, va_1, vb_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 1 * rs_c + 0, vc_10 ); + else vst1q_lane_f64( c_loc + 1 * rs_c + 0, vc_10, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 1 * rs_c + 2, vc_11 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 1 * rs_c + 2, vc_11, 0 ); + } + if ( m0 > 2 ) + { + // Load. + if ( n > 1 ) va_0 = vld1q_f64 ( c_loc + 2 * rs_c + 0 ); + else va_0 = vld1q_lane_f64( c_loc + 2 * rs_c + 0, va_0, 0 ); + if ( n > 3 ) va_1 = vld1q_f64 ( c_loc + 2 * rs_c + 2 ); + else if ( n > 2 ) va_1 = vld1q_lane_f64( c_loc + 2 * rs_c + 2, va_1, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_20 = vfmaq_f64( vc_20, va_0, vb_0 ); + vc_21 = vfmaq_f64( vc_21, va_1, vb_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 2 * rs_c + 0, vc_20 ); + else vst1q_lane_f64( c_loc + 2 * rs_c + 0, vc_20, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 2 * rs_c + 2, vc_21 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 2 * rs_c + 2, vc_21, 0 ); + } + if ( m0 > 3 ) + { + // Load. + if ( n > 1 ) va_0 = vld1q_f64 ( c_loc + 3 * rs_c + 0 ); + else va_0 = vld1q_lane_f64( c_loc + 3 * rs_c + 0, va_0, 0 ); + if ( n > 3 ) va_1 = vld1q_f64 ( c_loc + 3 * rs_c + 2 ); + else if ( n > 2 ) va_1 = vld1q_lane_f64( c_loc + 3 * rs_c + 2, va_1, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_30 = vfmaq_f64( vc_30, va_0, vb_0 ); + vc_31 = vfmaq_f64( vc_31, va_1, vb_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 3 * rs_c + 0, vc_30 ); + else vst1q_lane_f64( c_loc + 3 * rs_c + 0, vc_30, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 3 * rs_c + 2, vc_31 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 3 * rs_c + 2, vc_31, 0 ); + } + if ( m0 > 4 ) + { + // Load. + if ( n > 1 ) va_0 = vld1q_f64 ( c_loc + 4 * rs_c + 0 ); + else va_0 = vld1q_lane_f64( c_loc + 4 * rs_c + 0, va_0, 0 ); + if ( n > 3 ) va_1 = vld1q_f64 ( c_loc + 4 * rs_c + 2 ); + else if ( n > 2 ) va_1 = vld1q_lane_f64( c_loc + 4 * rs_c + 2, va_1, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_40 = vfmaq_f64( vc_40, va_0, vb_0 ); + vc_41 = vfmaq_f64( vc_41, va_1, vb_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 4 * rs_c + 0, vc_40 ); + else vst1q_lane_f64( c_loc + 4 * rs_c + 0, vc_40, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 4 * rs_c + 2, vc_41 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 4 * rs_c + 2, vc_41, 0 ); + } + if ( m0 > 5 ) + { + // Load. + if ( n > 1 ) va_0 = vld1q_f64 ( c_loc + 5 * rs_c + 0 ); + else va_0 = vld1q_lane_f64( c_loc + 5 * rs_c + 0, va_0, 0 ); + if ( n > 3 ) va_1 = vld1q_f64 ( c_loc + 5 * rs_c + 2 ); + else if ( n > 2 ) va_1 = vld1q_lane_f64( c_loc + 5 * rs_c + 2, va_1, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_50 = vfmaq_f64( vc_50, va_0, vb_0 ); + vc_51 = vfmaq_f64( vc_51, va_1, vb_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 5 * rs_c + 0, vc_50 ); + else vst1q_lane_f64( c_loc + 5 * rs_c + 0, vc_50, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 5 * rs_c + 2, vc_51 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 5 * rs_c + 2, vc_51, 0 ); + } + } + else + { + // Store in columns. + + // Rename some vectors. +#define VCOL0 va_0 +#define VCOL1 va_1 +#define VCOL2 va_2 +#define VCOL3 vb_1 +#define VTMP0 vc_00 +#define VTMP1 vc_01 +#define VTMP2 vc_10 +#define VTMP3 vc_11 + // if ( m0 > 0 ) + { + VCOL0 = vtrn1q_f64(vc_00, vc_10); + VCOL1 = vtrn2q_f64(vc_00, vc_10); + VCOL2 = vtrn1q_f64(vc_01, vc_11); + VCOL3 = vtrn2q_f64(vc_01, vc_11); + + if ( m0 > 1 ) + { + if ( n > 0 ) VTMP0 = vld1q_f64( c_loc + 0 * cs_c + 0 ); + if ( n > 1 ) VTMP1 = vld1q_f64( c_loc + 1 * cs_c + 0 ); + if ( n > 2 ) VTMP2 = vld1q_f64( c_loc + 2 * cs_c + 0 ); + if ( n > 3 ) VTMP3 = vld1q_f64( c_loc + 3 * cs_c + 0 ); + if ( !b_iszr ) + { + VCOL0 = vfmaq_f64( VCOL0, VTMP0, vb_0 ); + VCOL1 = vfmaq_f64( VCOL1, VTMP1, vb_0 ); + VCOL2 = vfmaq_f64( VCOL2, VTMP2, vb_0 ); + VCOL3 = vfmaq_f64( VCOL3, VTMP3, vb_0 ); + } + if ( n > 0 ) vst1q_f64( c_loc + 0 * cs_c + 0, VCOL0 ); + if ( n > 1 ) vst1q_f64( c_loc + 1 * cs_c + 0, VCOL1 ); + if ( n > 2 ) vst1q_f64( c_loc + 2 * cs_c + 0, VCOL2 ); + if ( n > 3 ) vst1q_f64( c_loc + 3 * cs_c + 0, VCOL3 ); + } + else + { + if ( n > 0 ) VTMP0 = vld1q_lane_f64( c_loc + 0 * cs_c + 0, VTMP0, 0 ); + if ( n > 1 ) VTMP1 = vld1q_lane_f64( c_loc + 1 * cs_c + 0, VTMP1, 0 ); + if ( n > 2 ) VTMP2 = vld1q_lane_f64( c_loc + 2 * cs_c + 0, VTMP2, 0 ); + if ( n > 3 ) VTMP3 = vld1q_lane_f64( c_loc + 3 * cs_c + 0, VTMP3, 0 ); + if ( !b_iszr ) + { + VCOL0 = vfmaq_f64( VCOL0, VTMP0, vb_0 ); + VCOL1 = vfmaq_f64( VCOL1, VTMP1, vb_0 ); + VCOL2 = vfmaq_f64( VCOL2, VTMP2, vb_0 ); + VCOL3 = vfmaq_f64( VCOL3, VTMP3, vb_0 ); + } + if ( n > 0 ) vst1q_lane_f64( c_loc + 0 * cs_c + 0, VCOL0, 0 ); + if ( n > 1 ) vst1q_lane_f64( c_loc + 1 * cs_c + 0, VCOL1, 0 ); + if ( n > 2 ) vst1q_lane_f64( c_loc + 2 * cs_c + 0, VCOL2, 0 ); + if ( n > 3 ) vst1q_lane_f64( c_loc + 3 * cs_c + 0, VCOL3, 0 ); + } + } + if ( m0 > 2 ) + { + VCOL0 = vtrn1q_f64(vc_20, vc_30); + VCOL1 = vtrn2q_f64(vc_20, vc_30); + VCOL2 = vtrn1q_f64(vc_21, vc_31); + VCOL3 = vtrn2q_f64(vc_21, vc_31); + + if ( m0 > 3 ) + { + if ( n > 0 ) VTMP0 = vld1q_f64( c_loc + 0 * cs_c + 2 ); + if ( n > 1 ) VTMP1 = vld1q_f64( c_loc + 1 * cs_c + 2 ); + if ( n > 2 ) VTMP2 = vld1q_f64( c_loc + 2 * cs_c + 2 ); + if ( n > 3 ) VTMP3 = vld1q_f64( c_loc + 3 * cs_c + 2 ); + if ( !b_iszr ) + { + VCOL0 = vfmaq_f64( VCOL0, VTMP0, vb_0 ); + VCOL1 = vfmaq_f64( VCOL1, VTMP1, vb_0 ); + VCOL2 = vfmaq_f64( VCOL2, VTMP2, vb_0 ); + VCOL3 = vfmaq_f64( VCOL3, VTMP3, vb_0 ); + } + if ( n > 0 ) vst1q_f64( c_loc + 0 * cs_c + 2, VCOL0 ); + if ( n > 1 ) vst1q_f64( c_loc + 1 * cs_c + 2, VCOL1 ); + if ( n > 2 ) vst1q_f64( c_loc + 2 * cs_c + 2, VCOL2 ); + if ( n > 3 ) vst1q_f64( c_loc + 3 * cs_c + 2, VCOL3 ); + } + else + { + if ( n > 0 ) VTMP0 = vld1q_lane_f64( c_loc + 0 * cs_c + 2, VTMP0, 0 ); + if ( n > 1 ) VTMP1 = vld1q_lane_f64( c_loc + 1 * cs_c + 2, VTMP1, 0 ); + if ( n > 2 ) VTMP2 = vld1q_lane_f64( c_loc + 2 * cs_c + 2, VTMP2, 0 ); + if ( n > 3 ) VTMP3 = vld1q_lane_f64( c_loc + 3 * cs_c + 2, VTMP3, 0 ); + if ( !b_iszr ) + { + VCOL0 = vfmaq_f64( VCOL0, VTMP0, vb_0 ); + VCOL1 = vfmaq_f64( VCOL1, VTMP1, vb_0 ); + VCOL2 = vfmaq_f64( VCOL2, VTMP2, vb_0 ); + VCOL3 = vfmaq_f64( VCOL3, VTMP3, vb_0 ); + } + if ( n > 0 ) vst1q_lane_f64( c_loc + 0 * cs_c + 2, VCOL0, 0 ); + if ( n > 1 ) vst1q_lane_f64( c_loc + 1 * cs_c + 2, VCOL1, 0 ); + if ( n > 2 ) vst1q_lane_f64( c_loc + 2 * cs_c + 2, VCOL2, 0 ); + if ( n > 3 ) vst1q_lane_f64( c_loc + 3 * cs_c + 2, VCOL3, 0 ); + } + } + if ( m0 > 4 ) + { + VCOL0 = vtrn1q_f64(vc_40, vc_50); + VCOL1 = vtrn2q_f64(vc_40, vc_50); + VCOL2 = vtrn1q_f64(vc_41, vc_51); + VCOL3 = vtrn2q_f64(vc_41, vc_51); + + if ( m0 > 5 ) + { + if ( n > 0 ) VTMP0 = vld1q_f64( c_loc + 0 * cs_c + 4 ); + if ( n > 1 ) VTMP1 = vld1q_f64( c_loc + 1 * cs_c + 4 ); + if ( n > 2 ) VTMP2 = vld1q_f64( c_loc + 2 * cs_c + 4 ); + if ( n > 3 ) VTMP3 = vld1q_f64( c_loc + 3 * cs_c + 4 ); + if ( !b_iszr ) + { + VCOL0 = vfmaq_f64( VCOL0, VTMP0, vb_0 ); + VCOL1 = vfmaq_f64( VCOL1, VTMP1, vb_0 ); + VCOL2 = vfmaq_f64( VCOL2, VTMP2, vb_0 ); + VCOL3 = vfmaq_f64( VCOL3, VTMP3, vb_0 ); + } + if ( n > 0 ) vst1q_f64( c_loc + 0 * cs_c + 4, VCOL0 ); + if ( n > 1 ) vst1q_f64( c_loc + 1 * cs_c + 4, VCOL1 ); + if ( n > 2 ) vst1q_f64( c_loc + 2 * cs_c + 4, VCOL2 ); + if ( n > 3 ) vst1q_f64( c_loc + 3 * cs_c + 4, VCOL3 ); + } + else + { + if ( n > 0 ) VTMP0 = vld1q_lane_f64( c_loc + 0 * cs_c + 4, VTMP0, 0 ); + if ( n > 1 ) VTMP1 = vld1q_lane_f64( c_loc + 1 * cs_c + 4, VTMP1, 0 ); + if ( n > 2 ) VTMP2 = vld1q_lane_f64( c_loc + 2 * cs_c + 4, VTMP2, 0 ); + if ( n > 3 ) VTMP3 = vld1q_lane_f64( c_loc + 3 * cs_c + 4, VTMP3, 0 ); + if ( !b_iszr ) + { + VCOL0 = vfmaq_f64( VCOL0, VTMP0, vb_0 ); + VCOL1 = vfmaq_f64( VCOL1, VTMP1, vb_0 ); + VCOL2 = vfmaq_f64( VCOL2, VTMP2, vb_0 ); + VCOL3 = vfmaq_f64( VCOL3, VTMP3, vb_0 ); + } + if ( n > 0 ) vst1q_lane_f64( c_loc + 0 * cs_c + 4, VCOL0, 0 ); + if ( n > 1 ) vst1q_lane_f64( c_loc + 1 * cs_c + 4, VCOL1, 0 ); + if ( n > 2 ) vst1q_lane_f64( c_loc + 2 * cs_c + 4, VCOL2, 0 ); + if ( n > 3 ) vst1q_lane_f64( c_loc + 3 * cs_c + 4, VCOL3, 0 ); + } + } + } + + b_in += ps_b; + c_in += 4 * cs_c; + } + + a0 += ps_a; + c0 += 6 * rs_c; + } +} + diff --git a/kernels/armv8a/bli_kernels_armv8a.h b/kernels/armv8a/bli_kernels_armv8a.h index f3c01985a9..b7ab755412 100644 --- a/kernels/armv8a/bli_kernels_armv8a.h +++ b/kernels/armv8a/bli_kernels_armv8a.h @@ -32,5 +32,30 @@ */ +PACKM_KER_PROT( float, s, packm_armv8a_int_8xk ) +PACKM_KER_PROT( float, s, packm_armv8a_int_12xk ) +PACKM_KER_PROT( double, d, packm_armv8a_int_6xk ) +PACKM_KER_PROT( double, d, packm_armv8a_int_8xk ) + GEMM_UKR_PROT( float, s, gemm_armv8a_asm_8x12 ) GEMM_UKR_PROT( double, d, gemm_armv8a_asm_6x8 ) +// GEMM_UKR_PROT( double, d, gemm_armv8a_asm_6x8r ) +// GEMM_UKR_PROT( double, d, gemm_armv8a_asm_8x4 ) +// GEMM_UKR_PROT( double, d, gemm_armv8a_asm_4x4 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_armv8a_asm_6x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_armv8a_asm_6x8m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_asm_6x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_asm_6x8m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_asm_4x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_asm_4x8m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_asm_8x4m ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_armv8a_int_2x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_armv8a_int_3x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_armv8a_asm_3x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_armv8a_asm_6x3 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_int_6x4mn ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_int_3x8mn ) + diff --git a/kernels/bgq/1/bli_axpyv_bgq_int.c b/kernels/bgq/1/bli_axpyv_bgq_int.c index aa83d6b5a1..0c4a8cbd3c 100644 --- a/kernels/bgq/1/bli_axpyv_bgq_int.c +++ b/kernels/bgq/1/bli_axpyv_bgq_int.c @@ -48,7 +48,7 @@ void bli_daxpyv_bgq_int // If there is anything that would interfere with our use of aligned // vector loads/stores, call the reference implementation. - bool_t use_ref = FALSE; + bool use_ref = FALSE; if ( incx != 1 || incy != 1 || bli_is_unaligned_to( ( siz_t )x, 32 ) || bli_is_unaligned_to( ( siz_t )y, 32 ) ) { use_ref = TRUE; } diff --git a/kernels/bgq/1/bli_dotv_bgq_int.c b/kernels/bgq/1/bli_dotv_bgq_int.c index d54f337f71..73e53c23a1 100644 --- a/kernels/bgq/1/bli_dotv_bgq_int.c +++ b/kernels/bgq/1/bli_dotv_bgq_int.c @@ -45,7 +45,7 @@ void bli_ddotv_bgq_int cntx_t* restrict cntx ) { - bool_t use_ref = FALSE; + bool use_ref = FALSE; // If the vector lengths are zero, set rho to zero and return. if ( bli_zero_dim1( n ) ) { diff --git a/kernels/bgq/1f/bli_axpyf_bgq_int.c b/kernels/bgq/1f/bli_axpyf_bgq_int.c index 056daf2bb1..4e296e0a25 100644 --- a/kernels/bgq/1f/bli_axpyf_bgq_int.c +++ b/kernels/bgq/1f/bli_axpyf_bgq_int.c @@ -52,7 +52,7 @@ void bli_daxpyf_bgq_int if ( bli_zero_dim2( m, b_n ) ) return; - bool_t use_ref = FALSE; + bool use_ref = FALSE; // printf("%d\t%d\t%d\t%d\t%d\t%d\t%d\n", b_n, fusefac, inca, incx, incy, bli_is_unaligned_to( ( siz_t )a, 32 ), bli_is_unaligned_to( ( siz_t )y, 32)); // If there is anything that would interfere with our use of aligned // vector loads/stores, call the reference implementation. diff --git a/kernels/bgq/3/bli_gemm_bgq_int_8x8.c b/kernels/bgq/3/bli_gemm_bgq_int_8x8.c index 1612e69b0d..15e3e072f3 100644 --- a/kernels/bgq/3/bli_gemm_bgq_int_8x8.c +++ b/kernels/bgq/3/bli_gemm_bgq_int_8x8.c @@ -56,6 +56,8 @@ void bli_dgemm_bgq_int_8x8 ( + dim_t m, + dim_t n, dim_t k, double* restrict alpha, double* restrict a, @@ -66,6 +68,8 @@ void bli_dgemm_bgq_int_8x8 cntx_t* restrict cntx ) { + GEMM_UKR_SETUP_CT_ANY( d, 8, 8, false ); + //Registers for storing C. //4 4x4 subblocks of C, c00, c01, c10, c11 //4 registers per subblock: a, b, c, d @@ -201,6 +205,8 @@ void bli_dgemm_bgq_int_8x8 UPDATE( AB, c, 0 ); AB = vec_perm( c11d, c11d, pattern ); UPDATE( AB, c, 4 ); + + GEMM_UKR_FLUSH_CT( d ); } void printvec(vector4double v) @@ -214,6 +220,8 @@ void printvec(vector4double v) void bli_zgemm_bgq_int_4x4 ( + dim_t m, + dim_t n, dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, @@ -224,6 +232,8 @@ void bli_zgemm_bgq_int_4x4 cntx_t* restrict cntx ) { + GEMM_UKR_SETUP_CT_ANY( z, 4, 4, false ); + double* a_d = ( double* )a; double* b_d = ( double* )b; double* c_d = ( double* )c; @@ -368,4 +378,6 @@ void bli_zgemm_bgq_int_4x4 c_d += 2*cs_c; ZUPDATE( c03a, c03b, c_d, 0 ); ZUPDATE( c13a, c13b, c_d, 4 ); + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c index 403aaaaeef..3a75d61d73 100644 --- a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c +++ b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c @@ -90,7 +90,9 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -102,25 +104,27 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( s, 8, 8, false, 32 ); + begin_asm() - + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. - + vmovaps(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovsldup(mem(rbx, 0*32), ymm2) // elements of a and b. vpermilps(imm(0x4e), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) lea(mem(rcx, rdi, 4), r10) // load address of c + 4*cs_c; - + lea(mem(rdi, rdi, 2), r14) // r14 = 3*cs_c; prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*cs_c @@ -130,7 +134,7 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 prefetch(0, mem(r10, rdi, 1, 7*8)) // prefetch c + 5*cs_c prefetch(0, mem(r10, rdi, 2, 7*8)) // prefetch c + 6*cs_c prefetch(0, mem(r10, r14, 1, 7*8)) // prefetch c + 7*cs_c - + vxorps(ymm8, ymm8, ymm8) vxorps(ymm9, ymm9, ymm9) vxorps(ymm10, ymm10, ymm10) @@ -139,15 +143,15 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vxorps(ymm13, ymm13, ymm13) vxorps(ymm14, ymm14, ymm14) vxorps(ymm15, ymm15, ymm15) - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - + label(.SLOOPKITER) // MAIN LOOP - + // iteration 0 prefetch(0, mem(rax, 16*32)) vfmaddps(ymm15, ymm0, ymm2, ymm15) @@ -155,44 +159,44 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vmovshdup(mem(rbx, 0*32), ymm2) vfmaddps(ymm13, ymm0, ymm3, ymm13) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 1*32), ymm1) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm0, ymm4, ymm11) vfmaddps(ymm9, ymm0, ymm5, ymm9) - + vfmaddps(ymm14, ymm0, ymm2, ymm14) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 1*32), ymm2) vfmaddps(ymm12, ymm0, ymm3, ymm12) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm10, ymm0, ymm4, ymm10) vfmaddps(ymm8, ymm0, ymm5, ymm8) - + // iteration 1 vfmaddps(ymm15, ymm1, ymm2, ymm15) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovshdup(mem(rbx, 1*32), ymm2) vfmaddps(ymm13, ymm1, ymm3, ymm13) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 2*32), ymm0) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm1, ymm4, ymm11) vfmaddps(ymm9, ymm1, ymm5, ymm9) - + vfmaddps(ymm14, ymm1, ymm2, ymm14) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 2*32), ymm2) vfmaddps(ymm12, ymm1, ymm3, ymm12) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm10, ymm1, ymm4, ymm10) vfmaddps(ymm8, ymm1, ymm5, ymm8) - + // iteration 2 prefetch(0, mem(rax, 18*32)) vfmaddps(ymm15, ymm0, ymm2, ymm15) @@ -200,23 +204,23 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vmovshdup(mem(rbx, 2*32), ymm2) vfmaddps(ymm13, ymm0, ymm3, ymm13) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 3*32), ymm1) add(imm(4*8*4), rax) // a += 4*8 (unroll x mr) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm0, ymm4, ymm11) vfmaddps(ymm9, ymm0, ymm5, ymm9) - + vfmaddps(ymm14, ymm0, ymm2, ymm14) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 3*32), ymm2) vfmaddps(ymm12, ymm0, ymm3, ymm12) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm10, ymm0, ymm4, ymm10) vfmaddps(ymm8, ymm0, ymm5, ymm8) - + // iteration 3 vfmaddps(ymm15, ymm1, ymm2, ymm15) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) @@ -224,134 +228,134 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 add(imm(4*8*4), rbx) // b += 4*8 (unroll x nr) vfmaddps(ymm13, ymm1, ymm3, ymm13) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 0*32), ymm0) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm1, ymm4, ymm11) vfmaddps(ymm9, ymm1, ymm5, ymm9) - + vfmaddps(ymm14, ymm1, ymm2, ymm14) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 0*32), ymm2) vfmaddps(ymm12, ymm1, ymm3, ymm12) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm10, ymm1, ymm4, ymm10) vfmaddps(ymm8, ymm1, ymm5, ymm8) - - - + + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - + label(.SLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 16*32)) vfmaddps(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmovshdup(mem(rbx, 0*32), ymm2) vfmaddps(ymm13, ymm0, ymm3, ymm13) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 1*32), ymm1) add(imm(8*1*4), rax) // a += 8 (1 x mr) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm0, ymm4, ymm11) vfmaddps(ymm9, ymm0, ymm5, ymm9) - - vfmaddps(ymm14, ymm0, ymm2, ymm14) + + vfmaddps(ymm14, ymm0, ymm2, ymm14) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 1*32), ymm2) add(imm(8*1*4), rbx) // b += 8 (1 x nr) - vfmaddps(ymm12, ymm0, ymm3, ymm12) + vfmaddps(ymm12, ymm0, ymm3, ymm12) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) - vfmaddps(ymm10, ymm0, ymm4, ymm10) - vfmaddps(ymm8, ymm0, ymm5, ymm8) + vfmaddps(ymm10, ymm0, ymm4, ymm10) + vfmaddps(ymm8, ymm0, ymm5, ymm8) vmovaps(ymm1, ymm0) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - + + label(.SPOSTACCUM) // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab22 ab20 ab26 ab24 // ab32 ab30 ab36 ab34 // ab44 ab46 ab40 ab42 - // ab54 ab56 ab50 ab52 + // ab54 ab56 ab50 ab52 // ab66 ab64 ab62 ab60 // ab76 ) ab74 ) ab72 ) ab70 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab23 ab21 ab27 ab25 // ab33 ab31 ab37 ab35 // ab45 ab47 ab41 ab43 - // ab55 ab57 ab51 ab53 + // ab55 ab57 ab51 ab53 // ab67 ab65 ab63 ab61 // ab77 ) ab75 ) ab73 ) ab71 ) GROUP_YMM_BY_4 // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab20 ab22 ab24 ab26 // ab30 ab32 ab34 ab36 // ab44 ab46 ab40 ab42 - // ab54 ab56 ab50 ab52 + // ab54 ab56 ab50 ab52 // ab64 ab66 ab60 ab62 // ab74 ) ab76 ) ab70 ) ab72 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab21 ab23 ab25 ab27 // ab31 ab33 ab35 ab37 // ab45 ab47 ab41 ab43 - // ab55 ab57 ab51 ab53 + // ab55 ab57 ab51 ab53 // ab65 ab67 ab61 ab63 // ab75 ) ab77 ) ab71 ) ab73 ) // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab20 ab22 ab24 ab26 // ab30 ab32 ab34 ab36 // ab40 ab42 ab44 ab46 - // ab50 ab52 ab54 ab56 + // ab50 ab52 ab54 ab56 // ab60 ab62 ab64 ab66 // ab70 ) ab72 ) ab74 ) ab76 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab21 ab23 ab25 ab27 // ab31 ab33 ab35 ab37 // ab41 ab43 ab45 ab47 - // ab51 ab53 ab55 ab57 + // ab51 ab53 ab55 ab57 // ab61 ab63 ab65 ab67 // ab71 ) ab73 ) ab75 ) ab77 ) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), ymm0) // load alpha and duplicate vbroadcastss(mem(rbx), ymm4) // load beta and duplicate - + vmulps(ymm0, ymm8, ymm8) // scale by alpha vmulps(ymm0, ymm9, ymm9) vmulps(ymm0, ymm10, ymm10) @@ -360,401 +364,115 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vmulps(ymm0, ymm13, ymm13) vmulps(ymm0, ymm14, ymm14) vmulps(ymm0, ymm15, ymm15) - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) - - lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - - lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; - lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - - // determine if - // c % 32 == 0, AND - // 4*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (4*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm4) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORED) // jump to column storage case - - - label(.SGENSTORED) - // update c00:c70 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vfmaddps(ymm15, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm14, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm14, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c02:c72 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm13, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm13, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c03:c73 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm12, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm12, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c04:c74 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm11, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm11, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c05:c75 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm10, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm10, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c06:c76 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm9, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm9, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c07:c77 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm8, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm8, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORED) - - - vmovaps(mem(rcx), ymm0) // load c00:c70, -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm15, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm15, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c01:c71, -// vmulps(ymm4, ymm1, ymm1) // scale by beta, -// vaddps(ymm14, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm14, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm0) // load c02:c72, -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm13, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm13, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c03:c73, -// vmulps(ymm4, ymm1, ymm1) // scale by beta, -// vaddps(ymm12, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm12, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm0) // load c04:c74, -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm11, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm11, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c05:c75, -// vmulps(ymm4, ymm1, ymm1) // scale by beta, -// vaddps(ymm10, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm10, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm0) // load c06:c76, -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm9, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm9, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c07:c77, -// vmulps(ymm4, ymm1, ymm1) // scale by beta, -// vaddps(ymm8, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm8, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - - jmp(.SDONE) // jump to end. - - + + vmovaps(mem(rcx), ymm0) // load c00:c70, + //vmulps(ymm4, ymm0, ymm0) // scale by beta, + //vaddps(ymm15, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm15, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c01:c71, + //vmulps(ymm4, ymm1, ymm1) // scale by beta, + //vaddps(ymm14, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm14, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm0) // load c02:c72, + //vmulps(ymm4, ymm0, ymm0) // scale by beta, + //vaddps(ymm13, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm13, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c03:c73, + //vmulps(ymm4, ymm1, ymm1) // scale by beta, + //vaddps(ymm12, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm12, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm0) // load c04:c74, + //vmulps(ymm4, ymm0, ymm0) // scale by beta, + //vaddps(ymm11, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm11, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c05:c75, + //vmulps(ymm4, ymm1, ymm1) // scale by beta, + //vaddps(ymm10, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm10, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm0) // load c06:c76, + //vmulps(ymm4, ymm0, ymm0) // scale by beta, + //vaddps(ymm9, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm9, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c07:c77, + //vmulps(ymm4, ymm1, ymm1) // scale by beta, + //vaddps(ymm8, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm8, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + + jmp(.SDONE) // jump to end. + label(.SBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORBZ) // jump to column storage case - - - label(.SGENSTORBZ) - // update c00:c70 - vmovapd(ymm15, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - vmovapd(ymm14, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - vmovapd(ymm13, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - vmovapd(ymm12, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c04:c74 - vmovapd(ymm11, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c05:c75 - vmovapd(ymm10, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c06:c76 - vmovapd(ymm9, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c07:c77 - vmovapd(ymm8, ymm0) - STORE_SS - - jmp(.SDONE) // jump to end. - - - label(.SCOLSTORBZ) - - vmovaps(ymm15, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm14, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm13, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm12, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm11, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm8, mem(rcx)) // and store back to memory. - + + vmovaps(ymm15, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm14, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm13, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm12, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm11, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm8, mem(rcx)) // and store back to memory. + label(.SDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -762,6 +480,8 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } #undef KERNEL4x6_1 @@ -862,7 +582,9 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 void bli_dgemm_bulldozer_asm_4x6_fma4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -874,66 +596,68 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 12; - uint64_t k_left = k0 % 12; + uint64_t k_iter = k / 12; + uint64_t k_left = k % 12; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ANY( d, 4, 6, false ); + begin_asm() - - + + vzeroall() mov(var(b), rbx) // load address of b. mov(var(a), rax) // load address of a. prefetch(0, mem(rax, 64)) - - + + vmovaps(mem(rbx, 0*8), xmm1) vmovaps(mem(rbx, 2*8), xmm2) vmovaps(mem(rbx, 4*8), xmm3) add(imm(12*8), rbx) add(imm(8*8), rax) - + mov(var(k_iter), rsi) // i = k_iter; notice var(k_iter) not $0 test(rsi, rsi) je(.CONSIDERKLEFT) - + ALIGN32 label(.LOOPKITER) // MAIN LOOP - - KERNEL4x6_1(xx) - KERNEL4x6_2(xx) - KERNEL4x6_3(xx) - KERNEL4x6_4(xx) - KERNEL4x6_1(xx) - KERNEL4x6_2(xx) - KERNEL4x6_3(xx) - KERNEL4x6_4(xx) - KERNEL4x6_1(xx) - KERNEL4x6_2(xx) - KERNEL4x6_3(xx) - KERNEL4x6_4(xx) - + + KERNEL4x6_1(xx) + KERNEL4x6_2(xx) + KERNEL4x6_3(xx) + KERNEL4x6_4(xx) + KERNEL4x6_1(xx) + KERNEL4x6_2(xx) + KERNEL4x6_3(xx) + KERNEL4x6_4(xx) + KERNEL4x6_1(xx) + KERNEL4x6_2(xx) + KERNEL4x6_3(xx) + KERNEL4x6_4(xx) + dec(rsi) jne(.LOOPKITER) - + label(.CONSIDERKLEFT) - + mov(var(k_left), rsi) - test(rsi, rsi) + test(rsi, rsi) label(.LOOPKLEFT) je(.POSTACCUM) - - KERNEL4x6_1(xx) + + KERNEL4x6_1(xx) add(imm(6*8), rbx) add(imm(4*8), rax) - + dec(rsi) jmp(.LOOPKLEFT) // iterate again if i != 0. - + label(.POSTACCUM) - - + + mov(var(rs_c), rsi) // load cs_c mov(var(cs_c), rdi) // load rs_c vmovddup(mem(var(alpha)), xmm2) //load alpha @@ -942,32 +666,32 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 sal(imm(3), rsi) // cs_c *= sizeof(double) sal(imm(3), rdi) // rs_c *= sizeof(double) lea(mem(rcx, rdi, 2), rdx) - - vmovlpd(mem(rcx), xmm0, xmm0) - vmovlpd(mem(rdx), xmm1, xmm1) + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovlpd(mem(rdx), xmm1, xmm1) vmovhpd(mem(rcx, rdi, 1), xmm0, xmm0) vmovhpd(mem(rdx, rdi, 1), xmm1, xmm1) lea(mem(rdx, rdi, 2), r8) vmulpd(xmm2, xmm4, xmm4) // scale by alpha, vmulpd(xmm2, xmm5, xmm5) // scale by alpha, vfmaddpd(xmm4, xmm0, xmm3, xmm4) // scale by beta, and add the gemm result - vmovlpd(mem(r8), xmm0, xmm0) + vmovlpd(mem(r8), xmm0, xmm0) vfmaddpd(xmm5, xmm1, xmm3, xmm5) // scale by beta, and add the gemm result vmovhpd(mem(r8, rdi, 1), xmm0, xmm0) vmovlpd(xmm4, mem(rcx)) // and store back to memory. vmovlpd(xmm5, mem(rdx)) // and store back to memory. vmovhpd(xmm4, mem(rcx, rdi, 1)) - add(rsi, rcx) + add(rsi, rcx) vmovhpd(xmm5, mem(rdx, rdi, 1)) - add(rsi, rdx) - + add(rsi, rdx) + vmulpd(xmm2, xmm6, xmm6) // scale by alpha, vfmaddpd(xmm6, xmm0, xmm3, xmm6) // scale by beta, and add the gemm result vmovlpd(xmm6, mem(r8)) // and store back to memory. vmovhpd(xmm6, mem(r8, rdi, 1)) - add(rsi, r8) - - + add(rsi, r8) + + vmovlpd(mem(rcx), xmm0, xmm0) vmovlpd(mem(rdx), xmm1, xmm1) vmovlpd(mem(r8), xmm4, xmm4) @@ -984,13 +708,13 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 vmovlpd(xmm8, mem(rdx)) // and store back to memory. vmovlpd(xmm9, mem(r8)) // and store back to memory. vmovhpd(xmm7, mem(rcx, rdi, 1)) - add(rsi, rcx) + add(rsi, rcx) vmovhpd(xmm8, mem(rdx, rdi, 1)) - add(rsi, rdx) + add(rsi, rdx) vmovhpd(xmm9, mem(r8, rdi, 1)) - add(rsi, r8) - - + add(rsi, r8) + + vmovlpd(mem(rcx), xmm0, xmm0) vmovlpd(mem(rdx), xmm1, xmm1) vmovlpd(mem(r8), xmm4, xmm4) @@ -1007,13 +731,13 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 vmovlpd(xmm11, mem(rdx)) // and store back to memory. vmovlpd(xmm12, mem(r8)) // and store back to memory. vmovhpd(xmm10, mem(rcx, rdi, 1)) - add(rsi, rcx) + add(rsi, rcx) vmovhpd(xmm11, mem(rdx, rdi, 1)) - add(rsi, rdx) + add(rsi, rdx) vmovhpd(xmm12, mem(r8, rdi, 1)) - add(rsi, r8) - - + add(rsi, r8) + + vmovlpd(mem(rcx), xmm0, xmm0) vmovlpd(mem(rdx), xmm1, xmm1) vmovlpd(mem(r8), xmm4, xmm4) @@ -1031,30 +755,32 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 vmovlpd(xmm15, mem(r8)) // and store back to memory. vmovhpd(xmm13, mem(rcx, rdi, 1)) vmovhpd(xmm14, mem(rdx, rdi, 1)) - vmovhpd(xmm15, mem(r8, rdi, 1)) - - end_asm( - : // output operands (none) - : // input operands - [k_iter] "r" (k_iter), // 0 - [k_left] "r" (k_left), // 1 - [a] "r" (a), // 2 - [b] "r" (b), // 3 - [alpha] "r" (alpha), // 4 - [beta] "r" (beta), // 5 - [c] "r" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 - : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", - "xmm0", "xmm1", "xmm2", "xmm3", - "xmm4", "xmm5", "xmm6", "xmm7", - "xmm8", "xmm9", "xmm10", "xmm11", - "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + vmovhpd(xmm15, mem(r8, rdi, 1)) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "r" (k_iter), // 0 + [k_left] "r" (k_left), // 1 + [a] "r" (a), // 2 + [b] "r" (b), // 3 + [alpha] "r" (alpha), // 4 + [beta] "r" (beta), // 5 + [c] "r" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } //The parameter "i" is the iteration number, i.e. the B values to read #define MADD_TO_YMM(i) \ @@ -1076,7 +802,9 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 void bli_cgemm_bulldozer_asm_8x4_fma4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1091,33 +819,35 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( c, 8, 4, false, 32 ); + begin_asm() - + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. sub(imm(4*64), r15) - + vmovaps(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovsldup(mem(rbx, 0*32), ymm2) vpermilps(imm(0x4e), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(scomplex) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorps(ymm8, ymm8, ymm8) vxorps(ymm9, ymm9, ymm9) vxorps(ymm10, ymm10, ymm10) @@ -1126,343 +856,312 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 vxorps(ymm13, ymm13, ymm13) vxorps(ymm14, ymm14, ymm14) vxorps(ymm15, ymm15, ymm15) - + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - + label(.CLOOPKITER) // MAIN LOOP - + add(imm(4*4*8), r15) // b_next += 4*4 (unroll x nr) - + // iteration 0 prefetch(0, mem(rax, 8*32)) vmovaps(mem(rax, 1*32), ymm1) MADD_TO_YMM(0) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vaddsubps(ymm6, ymm15, ymm15) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 2*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 1 prefetch(0, mem(rax, 10*32)) vmovaps(mem(rax, 3*32), ymm1) MADD_TO_YMM(1) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 2*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 4*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - + // iteration 2 prefetch(0, mem(rax, 12*32)) vmovaps(mem(rax, 5*32), ymm1) MADD_TO_YMM(2) prefetch(0, mem(r15, 2*32)) // prefetch b_next[2*4] - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 3*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 6*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 3 prefetch(0, mem(rax, 14*32)) vmovaps(mem(rax, 7*32), ymm1) MADD_TO_YMM(3) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 4*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 8*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + add(imm(8*4*8), rax) // a += 8*4 (unroll x mr) add(imm(4*4*8), rbx) // b += 4*4 (unroll x nr) - - + + dec(rsi) // i -= 1; jne(.CLOOPKITER) // iterate again if i != 0. - - - + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.CLOOPKLEFT) // EDGE LOOP - + // iteration 0 prefetch(0, mem(rax, 8*32)) vmovaps(mem(rax, 1*32), ymm1) MADD_TO_YMM(0) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 2*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + add(imm(8*1*8), rax) // a += 8 (1 x mr) add(imm(4*1*8), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.CLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.CPOSTACCUM) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab21 ab20 ab23 ab22 - // ab31 ab30 ab33 ab32 - // ab42 ab43 ab40 ab41 - // ab52 ab53 ab50 ab51 - // ab63 ab62 ab61 ab60 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab21 ab20 ab23 ab22 + // ab31 ab30 ab33 ab32 + // ab42 ab43 ab40 ab41 + // ab52 ab53 ab50 ab51 + // ab63 ab62 ab61 ab60 // ab73 ) ab72 ) ab71 ) ab70 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba1 aba0 aba3 aba2 - // abb1 abb0 abb3 abb2 - // abc2 abc3 abc0 abc1 - // abd2 abd3 abd0 abd1 - // abe3 abe2 abe1 abe0 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba1 aba0 aba3 aba2 + // abb1 abb0 abb3 abb2 + // abc2 abc3 abc0 abc1 + // abd2 abd3 abd0 abd1 + // abe3 abe2 abe1 abe0 // abf3 abf2 abf1 abf0 ) GROUP_YMM_BY_4 // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab20 ab21 ab22 ab23 - // ab30 ab31 ab32 ab33 - // ab42 ab43 ab40 ab41 - // ab52 ab53 ab50 ab51 - // ab62 ab63 ab60 ab61 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab20 ab21 ab22 ab23 + // ab30 ab31 ab32 ab33 + // ab42 ab43 ab40 ab41 + // ab52 ab53 ab50 ab51 + // ab62 ab63 ab60 ab61 // ab72 ) ab73 ) ab70 ) ab71 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba0 aba1 aba2 aba3 - // abb0 abb1 abb2 abb3 - // abc2 abc3 abc0 abc1 - // abd2 abd3 abd0 abd1 - // abe2 abe3 abe0 abe1 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba0 aba1 aba2 aba3 + // abb0 abb1 abb2 abb3 + // abc2 abc3 abc0 abc1 + // abd2 abd3 abd0 abd1 + // abe2 abe3 abe0 abe1 // abf2 ) abf3 ) abf0 ) abf1 ) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab20 ab21 ab22 ab23 - // ab30 ab31 ab32 ab33 - // ab40 ab41 ab42 ab43 - // ab50 ab51 ab52 ab53 - // ab60 ab61 ab62 ab63 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab20 ab21 ab22 ab23 + // ab30 ab31 ab32 ab33 + // ab40 ab41 ab42 ab43 + // ab50 ab51 ab52 ab53 + // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba0 aba1 aba2 aba3 - // abb0 abb1 abb2 abb3 - // abc0 abc1 abc2 abc3 - // abd0 abd1 abd2 abd3 - // abe0 abe1 abe2 abe3 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba0 aba1 aba2 aba3 + // abb0 abb1 abb2 abb3 + // abc0 abc1 abc2 abc3 + // abd0 abd1 abd2 abd3 + // abe0 abe1 abe2 abe3 // abf0 ) abf1 ) abf2 ) abf3 ) - + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vbroadcastss(mem(rax), ymm7) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), ymm6) // load alpha_i and duplicate - + vpermilps(imm(0xb1), ymm15, ymm3) vmulps(ymm7, ymm15, ymm15) vmulps(ymm6, ymm3, ymm3) vaddsubps(ymm3, ymm15, ymm15) - + vpermilps(imm(0xb1), ymm14, ymm2) vmulps(ymm7, ymm14, ymm14) vmulps(ymm6, ymm2, ymm2) vaddsubps(ymm2, ymm14, ymm14) - + vpermilps(imm(0xb1), ymm13, ymm1) vmulps(ymm7, ymm13, ymm13) vmulps(ymm6, ymm1, ymm1) vaddsubps(ymm1, ymm13, ymm13) - + vpermilps(imm(0xb1), ymm12, ymm0) vmulps(ymm7, ymm12, ymm12) vmulps(ymm6, ymm0, ymm0) vaddsubps(ymm0, ymm12, ymm12) - + vpermilps(imm(0xb1), ymm11, ymm3) vmulps(ymm7, ymm11, ymm11) vmulps(ymm6, ymm3, ymm3) vaddsubps(ymm3, ymm11, ymm11) - + vpermilps(imm(0xb1), ymm10, ymm2) vmulps(ymm7, ymm10, ymm10) vmulps(ymm6, ymm2, ymm2) vaddsubps(ymm2, ymm10, ymm10) - + vpermilps(imm(0xb1), ymm9, ymm1) vmulps(ymm7, ymm9, ymm9) vmulps(ymm6, ymm1, ymm1) vaddsubps(ymm1, ymm9, ymm9) - + vpermilps(imm(0xb1), ymm8, ymm0) vmulps(ymm7, ymm8, ymm8) vmulps(ymm6, ymm0, ymm0) vaddsubps(ymm0, ymm8, ymm8) - - - - + + + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), ymm7) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), ymm6) // load beta_i and duplicate - - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(scomplex) - - lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - - lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; - lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - - - // determine if - // c % 32 == 0, AND - // 8*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (8*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm7) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -1470,388 +1169,126 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 0, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.CCOLSTORED) // jump to column storage case - - - - label(.CGENSTORED) - - // update c00:c70 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c00,10) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c20,30) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c40,50) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c60,70) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c00,c10) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c20,c30) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c40,c50) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c60,c70) - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c80,90) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca0,b0) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc0,d0) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce0,f0) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c80,c90) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca0,cb0) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc0,cd0) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce0,cf0) - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c01,11) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c21,31) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c41,51) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c61,71) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c01,c11) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c21,c31) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c41,c51) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c61,c71) - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c81,91) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca1,b1) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc1,d1) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce1,f1) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c81,c91) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca1,cb1) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc1,cd1) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce1,cf1) - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c02,12) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c22,32) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c42,52) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c62,72) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c02,c12) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c22,c32) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c42,c52) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c62,c72) - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c82,92) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca2,b2) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc2,d2) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce2,f2) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c82,c92) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca2,cb2) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc2,cd2) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce2,cf2) - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c03,13) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c23,33) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c43,53) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c63,73) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c03,c13) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c23,c33) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c43,c53) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c63,c73) - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c83,93) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca3,b3) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc3,d3) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce3,f3) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c83,c93) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca3,cb3) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc3,cd3) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce3,cf3) - add(rdi, rdx) // c += cs_c; - - - - jmp(.CDONE) // jump to end. - - - - label(.CCOLSTORED) - - // update c00:c70 - - vmovaps(mem(rcx), ymm0) // load c00:c70 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c00:c70 - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vmovaps(mem(rdx), ymm0) // load c80:f0 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rdx)) // store c80:cf0 - add(rdi, rdx) // c += cs_c; - - // update c00:c70 - - vmovaps(mem(rcx), ymm0) // load c01:c71 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c01:c71 - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vmovaps(mem(rdx), ymm0) // load c81:f1 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rdx)) // store c81:cf1 - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - vmovaps(mem(rcx), ymm0) // load c02:c72 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c02:c72 - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - vmovaps(mem(rdx), ymm0) // load c82:f2 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rdx)) // store c82:cf2 - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - vmovaps(mem(rcx), ymm0) // load c03:c73 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c03:c73 - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - vmovaps(mem(rdx), ymm0) // load c83:f3 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rdx)) // store c83:cf3 - add(rdi, rdx) // c += cs_c; - - jmp(.CDONE) // jump to end. - - + + // update c00:c70 + + vmovaps(mem(rcx), ymm0) // load c00:c70 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx)) // store c00:c70 + + // update c80:cf0 + + vmovaps(mem(rcx,32), ymm0) // load c80:f0 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx,32)) // store c80:cf0 + add(rdi, rcx) // c += cs_c; + + // update c00:c70 + + vmovaps(mem(rcx), ymm0) // load c01:c71 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx)) // store c01:c71 + + // update c81:cf1 + + vmovaps(mem(rcx,32), ymm0) // load c81:f1 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx,32)) // store c81:cf1 + add(rdi, rcx) // c += cs_c; + + // update c02:c72 + vmovaps(mem(rcx), ymm0) // load c02:c72 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx)) // store c02:c72 + + // update c82:cf2 + vmovaps(mem(rcx,32), ymm0) // load c82:f2 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx,32)) // store c82:cf2 + add(rdi, rcx) // c += cs_c; + + // update c03:c73 + vmovaps(mem(rcx), ymm0) // load c03:c73 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx)) // store c03:c73 + + // update c83:cf3 + vmovaps(mem(rcx,32), ymm0) // load c83:f3 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx,32)) // store c83:cf3 + //add(rdi, rcx) // c += cs_c; + + jmp(.CDONE) // jump to end. + label(.CBETAZERO) - // check if aligned/column-stored - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.CCOLSTORBZ) // jump to column storage case - - - label(.CGENSTORBZ) - // update c00:c70 - vextractf128(imm(1), ymm15, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm15, mem(rcx)) // store (c00,c10) - vmovhpd(xmm15, mem(rcx, rsi, 1)) // store (c20,c30) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c40,c50) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c60,c70) - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - vextractf128(imm(1), ymm14, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm14, mem(rdx)) // store (c80,c90) - vmovhpd(xmm14, mem(rdx, rsi, 1)) // store (ca0,cb0) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc0,cd0) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce0,cf0) - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - vextractf128(imm(1), ymm13, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm13, mem(rcx)) // store (c01,c11) - vmovhpd(xmm13, mem(rcx, rsi, 1)) // store (c21,c31) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c41,c51) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c61,c71) - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - vextractf128(imm(1), ymm12, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm12, mem(rdx)) // store (c81,c91) - vmovhpd(xmm12, mem(rdx, rsi, 1)) // store (ca1,cb1) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc1,cd1) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce1,cf1) - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - vextractf128(imm(1), ymm11, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm11, mem(rcx)) // store (c02,c12) - vmovhpd(xmm11, mem(rcx, rsi, 1)) // store (c22,c32) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c42,c52) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c62,c72) - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - vextractf128(imm(1), ymm10, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm10, mem(rdx)) // store (c82,c92) - vmovhpd(xmm10, mem(rdx, rsi, 1)) // store (ca2,cb2) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc2,cd2) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce2,cf2) - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - vextractf128(imm(1), ymm9, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm9, mem(rcx)) // store (c03,c13) - vmovhpd(xmm9, mem(rcx, rsi, 1)) // store (c23,c33) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c43,c53) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c63,c73) - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - vextractf128(imm(1), ymm8, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm8, mem(rdx)) // store (c83,c93) - vmovhpd(xmm8, mem(rdx, rsi, 1)) // store (ca3,cb3) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc3,cd3) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce3,cf3) - add(rdi, rdx) // c += cs_c; - - - jmp(.CDONE) // jump to end. - - - label(.CCOLSTORBZ) - - vmovaps(ymm15, mem(rcx)) // store c00:c70 - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm14, mem(rdx)) // store c80:cf0 - add(rdi, rdx) // c += cs_c; - - vmovaps(ymm13, mem(rcx)) // store c01:c71 - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm12, mem(rdx)) // store c81:cf1 - add(rdi, rdx) // c += cs_c; - - vmovaps(ymm11, mem(rcx)) // store c02:c72 - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm10, mem(rdx)) // store c82:cf2 - add(rdi, rdx) // c += cs_c; - - vmovaps(ymm9, mem(rcx)) // store c03:c73 - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm8, mem(rdx)) // store c83:cf3 - add(rdi, rdx) // c += cs_c; - - - + + vmovaps(ymm15, mem(rcx)) // store c00:c70 + vmovaps(ymm14, mem(rcx,32)) // store c80:cf0 + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm13, mem(rcx)) // store c01:c71 + vmovaps(ymm12, mem(rcx,32)) // store c81:cf1 + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm11, mem(rcx)) // store c02:c72 + vmovaps(ymm10, mem(rcx,32)) // store c82:cf2 + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm9, mem(rcx)) // store c03:c73 + vmovaps(ymm8, mem(rcx,32)) // store c83:cf3 + add(rdi, rcx) // c += cs_c; + label(.CDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next)/*, // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next)/*, // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", @@ -1859,6 +1296,8 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } #define MADDSUBPD_TO_YMM \ @@ -1883,11 +1322,13 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 vmulpd(ymm7, ymm(i), ymm(i))\ vmulpd(ymm6, ymm(j), ymm(j))\ vaddsubpd(ymm(j), ymm(i), ymm(i))\ - + void bli_zgemm_bulldozer_asm_4x4_fma4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -1902,34 +1343,36 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( z, 4, 4, false, 32 ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. - + vmovapd(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovddup(mem(rbx, 0+0*32), ymm2) vmovddup(mem(rbx, 0+1*32), ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorpd(ymm8, ymm8, ymm8) vxorpd(ymm9, ymm9, ymm9) vxorpd(ymm10, ymm10, ymm10) @@ -1938,28 +1381,28 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vxorpd(ymm13, ymm13, ymm13) vxorpd(ymm14, ymm14, ymm14) vxorpd(ymm15, ymm15, ymm15) - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - + label(.ZLOOPKITER) // MAIN LOOP - + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 16*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+0*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+1*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+2*32), ymm2) @@ -1967,31 +1410,31 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+3*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - + // iteration 1 vmovapd(mem(rax, 3*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 18*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+2*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+3*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+4*32), ymm2) @@ -1999,31 +1442,31 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+5*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 4*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - + // iteration 2 vmovapd(mem(rax, 5*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 20*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+4*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+5*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+6*32), ymm2) @@ -2031,31 +1474,31 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+7*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 6*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - + // iteration 3 vmovapd(mem(rax, 7*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 22*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+6*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+7*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+8*32), ymm2) @@ -2063,48 +1506,48 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+9*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 8*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - + add(imm(4*4*16), rbx) // b += 4*4 (unroll x nr) add(imm(4*4*16), rax) // a += 4*4 (unroll x mr) - + dec(rsi) // i -= 1; jne(.ZLOOPKITER) // iterate again if i != 0. - - + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.ZLOOPKLEFT) // EDGE LOOP - + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 16*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+0*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+1*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+2*32), ymm2) @@ -2112,75 +1555,75 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+3*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + add(imm(4*1*16), rax) // a += 4 (1 x mr) add(imm(4*1*16), rbx) // b += 4 (1 x nr) - + dec(rsi) // i -= 1; jne(.ZLOOPKLEFT) // iterate again if i != 0. - - + + label(.ZPOSTACCUM) // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab21 ab20 ab23 ab22 // ab31 ) ab30 ) ab33 ) ab32 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab61 ab60 ab63 ab62 // ab71 ) ab70 ) ab73 ) ab72 ) - + vmovapd(ymm15, ymm7) vperm2f128(imm(0x12), ymm15, ymm13, ymm15) vperm2f128(imm(0x30), ymm7, ymm13, ymm13) - + vmovapd(ymm11, ymm7) vperm2f128(imm(0x12), ymm11, ymm9, ymm11) vperm2f128(imm(0x30), ymm7, ymm9, ymm9) - + vmovapd(ymm14, ymm7) vperm2f128(imm(0x12), ymm14, ymm12, ymm14) vperm2f128(imm(0x30), ymm7, ymm12, ymm12) - + vmovapd(ymm10, ymm7) vperm2f128(imm(0x12), ymm10, ymm8, ymm10) vperm2f128(imm(0x30), ymm7, ymm8, ymm8) - - + + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab20 ab21 ab22 ab23 // ab30 ) ab31 ) ab32 ) ab33 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - - + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vbroadcastsd(mem(rax), ymm7) // load alpha_r and duplicate vbroadcastsd(mem(rax, 8), ymm6) // load alpha_i and duplicate - + Z_ALPHA(15, 3) Z_ALPHA(14, 2) Z_ALPHA(13, 1) @@ -2190,38 +1633,14 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 Z_ALPHA(10, 2) Z_ALPHA(9, 1) Z_ALPHA(8, 0) - + mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rbx), ymm7) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm6) // load beta_i and duplicate - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(dcomplex) - lea(mem(, rsi, 2), rsi) - lea(mem(rcx, rsi, 2), rdx) // load address of c + 2*rs_c; - - - - // determine if - // c % 32 == 0, AND - // 16*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (16*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm7) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2229,287 +1648,91 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 0, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.ZCOLSTORED) // jump to column storage case - - - - label(.ZGENSTORED) - // update c00:c30 - - vmovupd(mem(rcx), xmm0) // load (c00,c10) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c20,c30) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c00,c10) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c20,c30) - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vmovupd(mem(rdx), xmm0) // load (c40,c50) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c60,c70) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c40,c50) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c60,c70) - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vmovupd(mem(rcx), xmm0) // load (c01,c11) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c21,c31) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c01,c11) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c21,c31) - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vmovupd(mem(rdx), xmm0) // load (c41,c51) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c61,c71) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c41,c51) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c61,c71) - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vmovupd(mem(rcx), xmm0) // load (c02,c12) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c22,c32) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c02,c12) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c22,c32) - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vmovupd(mem(rdx), xmm0) // load (c42,c52) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c62,c72) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c42,c52) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c62,c72) - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vmovupd(mem(rcx), xmm0) // load (c03,c13) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c23,c33) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c03,c13) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c23,c33) - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vmovupd(mem(rdx), xmm0) // load (c43,c53) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c63,c73) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c43,c53) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c63,c73) - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORED) - // update c00:c30 - - vmovapd(mem(rcx), ymm0) // load c00:c30 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vmovapd(mem(rdx), ymm0) // load c40:c70 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vmovapd(mem(rcx), ymm0) // load c01:c31 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vmovapd(mem(rdx), ymm0) // load c41:c71 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vmovapd(mem(rcx), ymm0) // load c02:c32 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vmovapd(mem(rdx), ymm0) // load c42:c72 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vmovapd(mem(rcx), ymm0) // load c03:c33 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c03:c33 - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vmovapd(mem(rdx), ymm0) // load c43:c73 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rdx)) // store c43:c73 - - - - jmp(.ZDONE) // jump to end. - - - + + // update c00:c30 + + vmovapd(mem(rcx), ymm0) // load c00:c30 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c00:c30 + + // update c40:c70 + + vmovapd(mem(rcx,32), ymm0) // load c40:c70 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c40:c70 + add(rdi, rcx) // c += cs_c; + + // update c01:c31 + + vmovapd(mem(rcx), ymm0) // load c01:c31 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c01:c31 + + // update c41:c71 + + vmovapd(mem(rcx,32), ymm0) // load c41:c71 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c41:c71 + add(rdi, rcx) // c += cs_c; + + // update c02:c32 + + vmovapd(mem(rcx), ymm0) // load c02:c32 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c02:c32 + + // update c42:c72 + + vmovapd(mem(rcx,32), ymm0) // load c42:c72 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c42:c72 + add(rdi, rcx) // c += cs_c; + + // update c03:c33 + + vmovapd(mem(rcx), ymm0) // load c03:c33 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c03:c33 + + // update c43:c73 + + vmovapd(mem(rcx,32), ymm0) // load c43:c73 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c43:c73 + add(rdi, rcx) // c += cs_c; + + jmp(.ZDONE) // jump to end. + label(.ZBETAZERO) - // check if aligned/column-stored - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.ZCOLSTORBZ) // jump to column storage case - - - - label(.ZGENSTORBZ) - // update c00:c30 - - vextractf128(imm(1), ymm15, xmm2) - vmovupd(xmm15, mem(rcx)) // store (c00,c10) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c20,c30) - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vextractf128(imm(1), ymm14, xmm2) - vmovupd(xmm14, mem(rdx)) // store (c40,c50) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c60,c70) - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vextractf128(imm(1), ymm13, xmm2) - vmovupd(xmm13, mem(rcx)) // store (c01,c11) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c21,c31) - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vextractf128(imm(1), ymm12, xmm2) - vmovupd(xmm12, mem(rdx)) // store (c41,c51) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c61,c71) - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vextractf128(imm(1), ymm11, xmm2) - vmovupd(xmm11, mem(rcx)) // store (c02,c12) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c22,c32) - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vextractf128(imm(1), ymm10, xmm2) - vmovupd(xmm10, mem(rdx)) // store (c42,c52) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c62,c72) - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vextractf128(imm(1), ymm9, xmm2) - vmovupd(xmm9, mem(rcx)) // store (c03,c13) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c23,c33) - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vextractf128(imm(1), ymm8, xmm2) - vmovupd(xmm8, mem(rdx)) // store (c43,c53) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c63,c73) - - - jmp(.ZDONE) // jump to end. - - - label(.ZCOLSTORBZ) - - - vmovapd(ymm15, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - vmovapd(ymm14, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - vmovapd(ymm13, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - vmovapd(ymm12, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - vmovapd(ymm11, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - vmovapd(ymm10, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - vmovapd(ymm9, mem(rcx)) // store c03:c33 - add(rdi, rcx) // c += cs_c; - - vmovapd(ymm8, mem(rdx)) // store c43:c73 - - + + vmovapd(ymm15, mem(rcx)) // store c00:c30 + vmovapd(ymm14, mem(rcx,32)) // store c40:c70 + add(rdi, rcx) // c += cs_c; + + vmovapd(ymm13, mem(rcx)) // store c01:c31 + vmovapd(ymm12, mem(rcx,32)) // store c41:c71 + add(rdi, rcx) // c += cs_c; + + vmovapd(ymm11, mem(rcx)) // store c02:c32 + vmovapd(ymm10, mem(rcx,32)) // store c42:c72 + add(rdi, rcx) // c += cs_c; + + vmovapd(ymm9, mem(rcx)) // store c03:c33 + vmovapd(ymm8, mem(rcx,32)) // store c43:c73 + //add(rdi, rcx) // c += cs_c; + label(.ZDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands [k_iter] "m" (k_iter), // 0 @@ -2524,7 +1747,7 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", @@ -2532,5 +1755,7 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c new file mode 100644 index 0000000000..843335ad5d --- /dev/null +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c @@ -0,0 +1,397 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +// Prototype reference packm kernels. +PACKM_KER_PROT( scomplex, c, packm_3xk_haswell_ref ) + +void bli_cpackm_haswell_asm_3xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + scomplex* restrict kappa, + scomplex* restrict a, inc_t inca0, inc_t lda0, + scomplex* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_cpackm_3xk_haswell_ref + ( + conja, schema, cdim0, k0, k0_max, + kappa, a, inca0, lda0, p, ldp0, cntx + ); + return; +#endif + + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 3; + + // This is the "packing" dimension assumed by the packm kernel. + // This should be equal to ldp. + //const dim_t packmnr = 6; + + // Define a local copy of 1.0 so we can test for unit kappa. + float one_l = 1.0; + float* restrict one = &one_l; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + const uint64_t k_iter = k0 / 4; +#if 1 + const uint64_t k_left = k0 % 4; +#else + const uint64_t k_left = k0; +#endif + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_ceq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs && !conja && unitk ) + { + begin_asm() + + mov(var(a), rax) // load address of a. + + mov(var(inca), r8) // load inca + mov(var(lda), r10) // load lda + lea(mem(, r8, 8), r8) // inca *= sizeof(scomplex) + lea(mem(, r10, 8), r10) // lda *= sizeof(scomplex) + + mov(var(p), rbx) // load address of p. + + lea(mem( , r10, 4), r14) // r14 = 4*lda + + mov(var(one), rdx) // load address of 1.0 constant + vbroadcastss(mem(rdx, 0), ymm1) // load 1.0 and duplicate + vxorps(ymm0, ymm0, ymm0) // set ymm0 to 0.0. + + mov(var(kappa), rcx) // load address of kappa + vbroadcastss(mem(rcx, 0), ymm10) // load kappa_r and duplicate + vbroadcastss(mem(rcx, 4), ymm11) // load kappa_i and duplicate + + + // now branch on kappa == 1.0 + + vucomiss(xmm1, xmm10) // set ZF if kappa_r == 1.0. + sete(r12b) // r12b = ( ZF == 1 ? 1 : 0 ); + vucomiss(xmm0, xmm11) // set ZF if kappa_i == 0.0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + and(r12b, r13b) // set ZF if r12b & r13b == 1. + jne(.CKAPPAUNIT) // if ZF = 1, jump to beta == 0 case + + + + label(.CKAPPANONU) + + cmp(imm(8), r8) // set ZF if (8*inca) == 8. + jz(.CCOLNONU) // jump to column storage case + + // -- kappa non-unit, row storage on A ------------------------------------- + + label(.CROWNONU) + + jmp(.CDONE) // jump to end. + + + // -- kappa non-unit, column storage on A ---------------------------------- + + label(.CCOLNONU) + + jmp(.CDONE) // jump to end. + + + + + label(.CKAPPAUNIT) + + cmp(imm(8), r8) // set ZF if (8*inca) == 8. + jz(.CCOLUNIT) // jump to column storage case + + + // -- kappa unit, row storage on A ----------------------------------------- + + label(.CROWUNIT) + + //lea(mem(r8, r8, 2), r12) // r12 = 3*inca + //lea(mem(r12, r8, 2), rcx) // rcx = 5*inca + //lea(mem(r12, r8, 4), rdx) // rdx = 7*inca + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.CCONKLEFTROWU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.CKITERROWU) // MAIN LOOP (k_iter) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, r8, 1, 0), ymm2) + vmovupd(mem(rax, r8, 2, 0), ymm4) + + add(r14, rax) // a += 4*lda; + + vunpcklpd(ymm2, ymm0, ymm10) + vunpckhpd(ymm2, ymm0, ymm11) + vunpcklpd(ymm6, ymm4, ymm12) + vunpckhpd(ymm6, ymm4, ymm13) + vinsertf128(imm(0x1), xmm12, ymm10, ymm0) + vinsertf128(imm(0x1), xmm13, ymm11, ymm2) + vperm2f128(imm(0x31), ymm12, ymm10, ymm4) + vperm2f128(imm(0x31), ymm13, ymm11, ymm6) + + vextractf128(imm(0x1), ymm0, xmm1) + vextractf128(imm(0x1), ymm2, xmm3) + vextractf128(imm(0x1), ymm4, xmm5) + vextractf128(imm(0x1), ymm6, xmm7) + + vmovupd(xmm0, mem(rbx, 0*24)) + vmovupd(xmm2, mem(rbx, 1*24)) + vmovupd(xmm4, mem(rbx, 2*24)) + vmovupd(xmm6, mem(rbx, 3*24)) + + vmovsd(xmm1, mem(rbx, 0*24+16)) + vmovsd(xmm3, mem(rbx, 1*24+16)) + vmovsd(xmm5, mem(rbx, 2*24+16)) + vmovsd(xmm7, mem(rbx, 3*24+16)) + + add(imm(4*3*8), rbx) // p += 4*ldp = 4*3; + + dec(rsi) // i -= 1; + jne(.CKITERROWU) // iterate again if i != 0. + + + + label(.CCONKLEFTROWU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.CDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.CKLEFTROWU) // EDGE LOOP (k_left) + + vmovsd(mem(rax, 0), xmm0) + vmovsd(mem(rax, r8, 1, 0), xmm2) + vmovsd(mem(rax, r8, 2, 0), xmm4) + + add(r10, rax) // a += lda; + + vmovsd(xmm0, mem(rbx, 0*8)) + vmovsd(xmm2, mem(rbx, 1*8)) + vmovsd(xmm4, mem(rbx, 2*8)) + + add(imm(3*8), rbx) // p += ldp = 3; + + dec(rsi) // i -= 1; + jne(.CKLEFTROWU) // iterate again if i != 0. + + + jmp(.CDONE) // jump to end. + + + // -- kappa unit, column storage on A -------------------------------------- + + label(.CCOLUNIT) + + lea(mem(r10, r10, 2), r13) // r13 = 3*lda + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.CCONKLEFTCOLU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.CKITERCOLU) // MAIN LOOP (k_iter) + + vmovupd(mem(rax, 0), xmm0) + vmovsd( mem(rax, 16), xmm1) + vmovupd(xmm0, mem(rbx, 0*24+ 0)) + vmovsd( xmm1, mem(rbx, 0*24+16)) + + vmovupd(mem(rax, r10, 1, 0), xmm2) + vmovsd( mem(rax, r10, 1, 16), xmm3) + vmovupd(xmm2, mem(rbx, 1*24+ 0)) + vmovsd( xmm3, mem(rbx, 1*24+16)) + + vmovupd(mem(rax, r10, 2, 0), xmm4) + vmovsd( mem(rax, r10, 2, 16), xmm5) + vmovupd(xmm4, mem(rbx, 2*24+ 0)) + vmovsd( xmm5, mem(rbx, 2*24+16)) + + vmovupd(mem(rax, r13, 1, 0), xmm6) + vmovsd( mem(rax, r13, 1, 16), xmm7) + add(r14, rax) // a += 4*lda; + vmovupd(xmm6, mem(rbx, 3*24+ 0)) + vmovsd( xmm7, mem(rbx, 3*24+16)) + add(imm(4*3*8), rbx) // p += 4*ldp = 4*3; + + dec(rsi) // i -= 1; + jne(.CKITERCOLU) // iterate again if i != 0. + + + + label(.CCONKLEFTCOLU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.CDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.CKLEFTCOLU) // EDGE LOOP (k_left) + + vmovupd(mem(rax, 0), xmm0) + vmovsd( mem(rax, 16), xmm1) + add(r10, rax) // a += lda; + vmovupd(xmm0, mem(rbx, 0*24+ 0)) + vmovsd( xmm1, mem(rbx, 0*24+16)) + add(imm(3*8), rbx) // p += ldp = 3; + + dec(rsi) // i -= 1; + jne(.CKLEFTCOLU) // iterate again if i != 0. + + + //jmp(.CDONE) // jump to end. + + + + label(.CDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [inca] "m" (inca), + [lda] "m" (lda), + [p] "m" (p), + [ldp] "m" (ldp), + [kappa] "m" (kappa), + [one] "m" (one) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", /*"r9",*/ "r10", /*"r11",*/ "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + } + else // if ( cdim0 < mnr || gs || bli_does_conj( conja ) || !unitk ) + { + PASTEMAC(cscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + scomplex* restrict p_edge = p + (i )*1; + + bli_cset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + +//bli_dfprintm( stdout, "packm 6xk ker: a_packed", cdim0, k0_max, p, 1, ldp0, "%5.2f", "" ); + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + scomplex* restrict p_edge = p + (j )*ldp; + + bli_cset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c new file mode 100644 index 0000000000..862a33b86a --- /dev/null +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c @@ -0,0 +1,415 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +// Prototype reference packm kernels. +PACKM_KER_PROT( scomplex, c, packm_8xk_haswell_ref ) + +void bli_cpackm_haswell_asm_8xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + scomplex* restrict kappa, + scomplex* restrict a, inc_t inca0, inc_t lda0, + scomplex* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_cpackm_8xk_haswell_ref + ( + conja, schema, cdim0, k0, k0_max, + kappa, a, inca0, lda0, p, ldp0, cntx + ); + return; +#endif + + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 8; + + // This is the "packing" dimension assumed by the packm kernel. + // This should be equal to ldp. + //const dim_t packmnr = 8; + + // Define a local copy of 1.0 so we can test for unit kappa. + float one_l = 1.0; + float* restrict one = &one_l; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + const uint64_t k_iter = k0 / 4; +#if 1 + const uint64_t k_left = k0 % 4; +#else + const uint64_t k_left = k0; +#endif + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_ceq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs && !conja && unitk ) + { + begin_asm() + + mov(var(a), rax) // load address of a. + + mov(var(inca), r8) // load inca + mov(var(lda), r10) // load lda + lea(mem(, r8, 8), r8) // inca *= sizeof(scomplex) + lea(mem(, r10, 8), r10) // lda *= sizeof(scomplex) + + mov(var(p), rbx) // load address of p. + + lea(mem( , r10, 4), r14) // r14 = 4*lda + + mov(var(one), rdx) // load address of 1.0 constant + vbroadcastss(mem(rdx, 0), ymm1) // load 1.0 and duplicate + vxorps(ymm0, ymm0, ymm0) // set ymm0 to 0.0. + + mov(var(kappa), rcx) // load address of kappa + vbroadcastss(mem(rcx, 0), ymm10) // load kappa_r and duplicate + vbroadcastss(mem(rcx, 4), ymm11) // load kappa_i and duplicate + + + // now branch on kappa == 1.0 + + vucomiss(xmm1, xmm10) // set ZF if kappa_r == 1.0. + sete(r12b) // r12b = ( ZF == 1 ? 1 : 0 ); + vucomiss(xmm0, xmm11) // set ZF if kappa_i == 0.0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + and(r12b, r13b) // set ZF if r12b & r13b == 1. + jne(.CKAPPAUNIT) // if ZF = 1, jump to beta == 0 case + + + + label(.CKAPPANONU) + + cmp(imm(8), r8) // set ZF if (8*inca) == 8. + jz(.CCOLNONU) // jump to column storage case + + // -- kappa non-unit, row storage on A ------------------------------------- + + label(.CROWNONU) + + jmp(.CDONE) // jump to end. + + + // -- kappa non-unit, column storage on A ---------------------------------- + + label(.CCOLNONU) + + jmp(.CDONE) // jump to end. + + + + + label(.CKAPPAUNIT) + + cmp(imm(8), r8) // set ZF if (8*inca) == 8. + jz(.CCOLUNIT) // jump to column storage case + + + // -- kappa unit, row storage on A ----------------------------------------- + + label(.CROWUNIT) + + lea(mem(r8, r8, 2), r12) // r12 = 3*inca + lea(mem(r12, r8, 2), rcx) // rcx = 5*inca + lea(mem(r12, r8, 4), rdx) // rdx = 7*inca + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.CCONKLEFTROWU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.CKITERROWU) // MAIN LOOP (k_iter) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, r8, 1, 0), ymm2) + vmovupd(mem(rax, r8, 2, 0), ymm4) + vmovupd(mem(rax, r12, 1, 0), ymm6) + + vunpcklpd(ymm2, ymm0, ymm10) + vunpckhpd(ymm2, ymm0, ymm11) + vunpcklpd(ymm6, ymm4, ymm12) + vunpckhpd(ymm6, ymm4, ymm13) + vinsertf128(imm(0x1), xmm12, ymm10, ymm0) + vinsertf128(imm(0x1), xmm13, ymm11, ymm2) + vperm2f128(imm(0x31), ymm12, ymm10, ymm4) + vperm2f128(imm(0x31), ymm13, ymm11, ymm6) + + vmovupd(ymm0, mem(rbx, 0*64)) + vmovupd(ymm2, mem(rbx, 1*64)) + vmovupd(ymm4, mem(rbx, 2*64)) + vmovupd(ymm6, mem(rbx, 3*64)) + + vmovupd(mem(rax, r8, 4, 0), ymm1) + vmovupd(mem(rax, rcx, 1, 0), ymm3) + vmovupd(mem(rax, r12, 2, 0), ymm5) + vmovupd(mem(rax, rdx, 1, 0), ymm7) + + add(r14, rax) // a += 4*lda; + + vunpcklpd(ymm3, ymm1, ymm10) + vunpckhpd(ymm3, ymm1, ymm11) + vunpcklpd(ymm7, ymm5, ymm12) + vunpckhpd(ymm7, ymm5, ymm13) + vinsertf128(imm(0x1), xmm12, ymm10, ymm1) + vinsertf128(imm(0x1), xmm13, ymm11, ymm3) + vperm2f128(imm(0x31), ymm12, ymm10, ymm5) + vperm2f128(imm(0x31), ymm13, ymm11, ymm7) + + vmovupd(ymm1, mem(rbx, 0*64+32)) + vmovupd(ymm3, mem(rbx, 1*64+32)) + vmovupd(ymm5, mem(rbx, 2*64+32)) + vmovupd(ymm7, mem(rbx, 3*64+32)) + + add(imm(4*8*8), rbx) // p += 4*ldp = 4*8; + + dec(rsi) // i -= 1; + jne(.CKITERROWU) // iterate again if i != 0. + + + + label(.CCONKLEFTROWU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.CDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.CKLEFTROWU) // EDGE LOOP (k_left) + + vmovsd(mem(rax, 0), xmm0) + vmovsd(mem(rax, r8, 1, 0), xmm2) + vmovsd(mem(rax, r8, 2, 0), xmm4) + vmovsd(mem(rax, r12, 1, 0), xmm6) + vmovsd(mem(rax, r8, 4, 0), xmm1) + vmovsd(mem(rax, rcx, 1, 0), xmm3) + vmovsd(mem(rax, r12, 2, 0), xmm5) + vmovsd(mem(rax, rdx, 1, 0), xmm7) + + add(r10, rax) // a += lda; + + vmovsd(xmm0, mem(rbx, 0*8)) + vmovsd(xmm2, mem(rbx, 1*8)) + vmovsd(xmm4, mem(rbx, 2*8)) + vmovsd(xmm6, mem(rbx, 3*8)) + vmovsd(xmm1, mem(rbx, 4*8)) + vmovsd(xmm3, mem(rbx, 5*8)) + vmovsd(xmm5, mem(rbx, 6*8)) + vmovsd(xmm7, mem(rbx, 7*8)) + + add(imm(8*8), rbx) // p += ldp = 8; + + dec(rsi) // i -= 1; + jne(.CKLEFTROWU) // iterate again if i != 0. + + + jmp(.CDONE) // jump to end. + + + // -- kappa unit, column storage on A -------------------------------------- + + label(.CCOLUNIT) + + lea(mem(r10, r10, 2), r13) // r13 = 3*lda + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.CCONKLEFTCOLU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.CKITERCOLU) // MAIN LOOP (k_iter) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, 32), ymm1) + vmovupd(ymm0, mem(rbx, 0*64+ 0)) + vmovupd(ymm1, mem(rbx, 0*64+32)) + + vmovupd(mem(rax, r10, 1, 0), ymm2) + vmovupd(mem(rax, r10, 1, 32), ymm3) + vmovupd(ymm2, mem(rbx, 1*64+ 0)) + vmovupd(ymm3, mem(rbx, 1*64+32)) + + vmovupd(mem(rax, r10, 2, 0), ymm4) + vmovupd(mem(rax, r10, 2, 32), ymm5) + vmovupd(ymm4, mem(rbx, 2*64+ 0)) + vmovupd(ymm5, mem(rbx, 2*64+32)) + + vmovupd(mem(rax, r13, 1, 0), ymm6) + vmovupd(mem(rax, r13, 1, 32), ymm7) + add(r14, rax) // a += 4*lda; + vmovupd(ymm6, mem(rbx, 3*64+ 0)) + vmovupd(ymm7, mem(rbx, 3*64+32)) + add(imm(4*8*8), rbx) // p += 4*ldp = 4*8; + + dec(rsi) // i -= 1; + jne(.CKITERCOLU) // iterate again if i != 0. + + + + label(.CCONKLEFTCOLU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.CDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.CKLEFTCOLU) // EDGE LOOP (k_left) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, 32), ymm1) + add(r10, rax) // a += lda; + vmovupd(ymm0, mem(rbx, 0*64+ 0)) + vmovupd(ymm1, mem(rbx, 0*64+32)) + add(imm(8*8), rbx) // p += ldp = 8; + + dec(rsi) // i -= 1; + jne(.CKLEFTCOLU) // iterate again if i != 0. + + + //jmp(.CDONE) // jump to end. + + + + label(.CDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [inca] "m" (inca), + [lda] "m" (lda), + [p] "m" (p), + [ldp] "m" (ldp), + [kappa] "m" (kappa), + [one] "m" (one) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", /*"r9",*/ "r10", /*"r11",*/ "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + } + else // if ( cdim0 < mnr || gs || bli_does_conj( conja ) || !unitk ) + { + PASTEMAC(cscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + scomplex* restrict p_edge = p + (i )*1; + + bli_cset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + scomplex* restrict p_edge = p + (j )*ldp; + + bli_cset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_d6xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_d6xk.c new file mode 100644 index 0000000000..b64f26591d --- /dev/null +++ b/kernels/haswell/1m/bli_packm_haswell_asm_d6xk.c @@ -0,0 +1,401 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +// Prototype reference packm kernels. +PACKM_KER_PROT( double, d, packm_6xk_haswell_ref ) + +void bli_dpackm_haswell_asm_6xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + double* restrict kappa, + double* restrict a, inc_t inca0, inc_t lda0, + double* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dpackm_6xk_haswell_ref + ( + conja, schema, cdim0, k0, k0_max, + kappa, a, inca0, lda0, p, ldp0, cntx + ); + return; +#endif + + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 6; + + // This is the "packing" dimension assumed by the packm kernel. + // This should be equal to ldp. + //const dim_t packmnr = 6; + + // Define a local copy of 1.0 so we can test for unit kappa. + double one_l = 1.0; + double* restrict one = &one_l; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + const uint64_t k_iter = k0 / 4; +#if 1 + const uint64_t k_left = k0 % 4; +#else + const uint64_t k_left = k0; +#endif + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_deq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs && unitk ) + { + begin_asm() + + mov(var(a), rax) // load address of a. + + mov(var(inca), r8) // load inca + mov(var(lda), r10) // load lda + lea(mem(, r8, 8), r8) // inca *= sizeof(double) + lea(mem(, r10, 8), r10) // lda *= sizeof(double) + + mov(var(p), rbx) // load address of p. + + lea(mem( , r10, 4), r14) // r14 = 4*lda + + mov(var(one), rdx) // load address of 1.0 constant + vmovsd(mem(rdx), xmm1) // load 1.0 + + mov(var(kappa), rcx) // load address of kappa + vmovsd(mem(rcx), xmm0) // load kappa + + + // now branch on kappa == 1.0 + + vucomisd(xmm0, xmm1) // set ZF if kappa == 1.0 + je(.DKAPPAUNIT) // if ZF = 1, jump to beta == 0 case + + + + label(.DKAPPANONU) + + cmp(imm(8), r8) // set ZF if (8*inca) == 8. + jz(.DCOLNONU) // jump to column storage case + + // -- kappa non-unit, row storage on A ------------------------------------- + + label(.DROWNONU) + + jmp(.DDONE) // jump to end. + + + // -- kappa non-unit, column storage on A ---------------------------------- + + label(.DCOLNONU) + + jmp(.DDONE) // jump to end. + + + + + label(.DKAPPAUNIT) + + cmp(imm(8), r8) // set ZF if (8*inca) == 8. + jz(.DCOLUNIT) // jump to column storage case + + + // -- kappa unit, row storage on A ----------------------------------------- + + label(.DROWUNIT) + + lea(mem(r8, r8, 2), r12) // r12 = 3*inca + lea(mem(r12, r8, 2), rcx) // rcx = 5*inca + //lea(mem(r12, r8, 4), rdx) // rdx = 7*inca + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONKLEFTROWU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DKITERROWU) // MAIN LOOP (k_iter) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, r8, 1, 0), ymm2) + vmovupd(mem(rax, r8, 2, 0), ymm4) + vmovupd(mem(rax, r12, 1, 0), ymm6) + + vunpcklpd(ymm2, ymm0, ymm10) + vunpckhpd(ymm2, ymm0, ymm11) + vunpcklpd(ymm6, ymm4, ymm12) + vunpckhpd(ymm6, ymm4, ymm13) + vinsertf128(imm(0x1), xmm12, ymm10, ymm0) + vinsertf128(imm(0x1), xmm13, ymm11, ymm2) + vperm2f128(imm(0x31), ymm12, ymm10, ymm4) + vperm2f128(imm(0x31), ymm13, ymm11, ymm6) + + vmovupd(ymm0, mem(rbx, 0*48)) + vmovupd(ymm2, mem(rbx, 1*48)) + vmovupd(ymm4, mem(rbx, 2*48)) + vmovupd(ymm6, mem(rbx, 3*48)) + + vmovupd(mem(rax, r8, 4, 0), ymm1) + vmovupd(mem(rax, rcx, 1, 0), ymm3) + + add(r14, rax) // a += 4*lda; + + vunpcklpd(ymm3, ymm1, ymm10) + vunpckhpd(ymm3, ymm1, ymm11) + vextractf128(imm(0x1), ymm10, xmm12) + vextractf128(imm(0x1), ymm11, xmm13) + + vmovupd(xmm10, mem(rbx, 0*48+32)) + vmovupd(xmm11, mem(rbx, 1*48+32)) + vmovupd(xmm12, mem(rbx, 2*48+32)) + vmovupd(xmm13, mem(rbx, 3*48+32)) + + add(imm(4*6*8), rbx) // p += 4*ldp = 4*6; + + dec(rsi) // i -= 1; + jne(.DKITERROWU) // iterate again if i != 0. + + + + label(.DCONKLEFTROWU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DKLEFTROWU) // EDGE LOOP (k_left) + + vmovsd(mem(rax, 0), xmm0) + vmovsd(mem(rax, r8, 1, 0), xmm2) + vmovsd(mem(rax, r8, 2, 0), xmm4) + vmovsd(mem(rax, r12, 1, 0), xmm6) + vmovsd(mem(rax, r8, 4, 0), xmm1) + vmovsd(mem(rax, rcx, 1, 0), xmm3) + + add(r10, rax) // a += lda; + + vmovsd(xmm0, mem(rbx, 0*8)) + vmovsd(xmm2, mem(rbx, 1*8)) + vmovsd(xmm4, mem(rbx, 2*8)) + vmovsd(xmm6, mem(rbx, 3*8)) + vmovsd(xmm1, mem(rbx, 4*8)) + vmovsd(xmm3, mem(rbx, 5*8)) + + add(imm(6*8), rbx) // p += ldp = 6; + + dec(rsi) // i -= 1; + jne(.DKLEFTROWU) // iterate again if i != 0. + + + jmp(.DDONE) // jump to end. + + + // -- kappa unit, column storage on A -------------------------------------- + + label(.DCOLUNIT) + + lea(mem(r10, r10, 2), r13) // r13 = 3*lda + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONKLEFTCOLU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DKITERCOLU) // MAIN LOOP (k_iter) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, 32), xmm1) + vmovupd(ymm0, mem(rbx, 0*48+ 0)) + vmovupd(xmm1, mem(rbx, 0*48+32)) + + vmovupd(mem(rax, r10, 1, 0), ymm2) + vmovupd(mem(rax, r10, 1, 32), xmm3) + vmovupd(ymm2, mem(rbx, 1*48+ 0)) + vmovupd(xmm3, mem(rbx, 1*48+32)) + + vmovupd(mem(rax, r10, 2, 0), ymm4) + vmovupd(mem(rax, r10, 2, 32), xmm5) + vmovupd(ymm4, mem(rbx, 2*48+ 0)) + vmovupd(xmm5, mem(rbx, 2*48+32)) + + vmovupd(mem(rax, r13, 1, 0), ymm6) + vmovupd(mem(rax, r13, 1, 32), xmm7) + add(r14, rax) // a += 4*lda; + vmovupd(ymm6, mem(rbx, 3*48+ 0)) + vmovupd(xmm7, mem(rbx, 3*48+32)) + add(imm(4*6*8), rbx) // p += 4*ldp = 4*6; + + dec(rsi) // i -= 1; + jne(.DKITERCOLU) // iterate again if i != 0. + + + + label(.DCONKLEFTCOLU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DKLEFTCOLU) // EDGE LOOP (k_left) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, 32), xmm1) + add(r10, rax) // a += lda; + vmovupd(ymm0, mem(rbx, 0*48+ 0)) + vmovupd(xmm1, mem(rbx, 0*48+32)) + add(imm(6*8), rbx) // p += ldp = 6; + + dec(rsi) // i -= 1; + jne(.DKLEFTCOLU) // iterate again if i != 0. + + + //jmp(.DDONE) // jump to end. + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [inca] "m" (inca), + [lda] "m" (lda), + [p] "m" (p), + [ldp] "m" (ldp), + [kappa] "m" (kappa), + [one] "m" (one) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", /*"r9",*/ "r10", /*"r11",*/ "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + } + else // if ( cdim0 < mnr || gs || !unitk ) + { + PASTEMAC(dscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + double* restrict p_edge = p + (i )*1; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + +//bli_dfprintm( stdout, "packm 6xk ker: a_packed", cdim0, k0_max, p, 1, ldp0, "%5.2f", "" ); + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + double* restrict p_edge = p + (j )*ldp; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c new file mode 100644 index 0000000000..9deb564ce4 --- /dev/null +++ b/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c @@ -0,0 +1,409 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +// Prototype reference packm kernels. +PACKM_KER_PROT( double, d, packm_8xk_haswell_ref ) + +void bli_dpackm_haswell_asm_8xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + double* restrict kappa, + double* restrict a, inc_t inca0, inc_t lda0, + double* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dpackm_8xk_haswell_ref + ( + conja, schema, cdim0, k0, k0_max, + kappa, a, inca0, lda0, p, ldp0, cntx + ); + return; +#endif + + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 8; + + // This is the "packing" dimension assumed by the packm kernel. + // This should be equal to ldp. + //const dim_t packmnr = 8; + + // Define a local copy of 1.0 so we can test for unit kappa. + double one_l = 1.0; + double* restrict one = &one_l; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + const uint64_t k_iter = k0 / 4; +#if 1 + const uint64_t k_left = k0 % 4; +#else + const uint64_t k_left = k0; +#endif + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_deq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs && unitk ) + { + begin_asm() + + mov(var(a), rax) // load address of a. + + mov(var(inca), r8) // load inca + mov(var(lda), r10) // load lda + lea(mem(, r8, 8), r8) // inca *= sizeof(double) + lea(mem(, r10, 8), r10) // lda *= sizeof(double) + + mov(var(p), rbx) // load address of p. + + lea(mem( , r10, 4), r14) // r14 = 4*lda + + mov(var(one), rdx) // load address of 1.0 constant + vmovsd(mem(rdx), xmm1) // load 1.0 + + mov(var(kappa), rcx) // load address of kappa + vmovsd(mem(rcx), xmm0) // load kappa + + + // now branch on kappa == 1.0 + + vucomisd(xmm0, xmm1) // set ZF if kappa == 1.0 + je(.DKAPPAUNIT) // if ZF = 1, jump to beta == 0 case + + + + label(.DKAPPANONU) + + cmp(imm(8), r8) // set ZF if (8*inca) == 8. + jz(.DCOLNONU) // jump to column storage case + + // -- kappa non-unit, row storage on A ------------------------------------- + + label(.DROWNONU) + + jmp(.DDONE) // jump to end. + + + // -- kappa non-unit, column storage on A ---------------------------------- + + label(.DCOLNONU) + + jmp(.DDONE) // jump to end. + + + + + label(.DKAPPAUNIT) + + cmp(imm(8), r8) // set ZF if (8*inca) == 8. + jz(.DCOLUNIT) // jump to column storage case + + + // -- kappa unit, row storage on A ----------------------------------------- + + label(.DROWUNIT) + + lea(mem(r8, r8, 2), r12) // r12 = 3*inca + lea(mem(r12, r8, 2), rcx) // rcx = 5*inca + lea(mem(r12, r8, 4), rdx) // rdx = 7*inca + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONKLEFTROWU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DKITERROWU) // MAIN LOOP (k_iter) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, r8, 1, 0), ymm2) + vmovupd(mem(rax, r8, 2, 0), ymm4) + vmovupd(mem(rax, r12, 1, 0), ymm6) + + vunpcklpd(ymm2, ymm0, ymm10) + vunpckhpd(ymm2, ymm0, ymm11) + vunpcklpd(ymm6, ymm4, ymm12) + vunpckhpd(ymm6, ymm4, ymm13) + vinsertf128(imm(0x1), xmm12, ymm10, ymm0) + vinsertf128(imm(0x1), xmm13, ymm11, ymm2) + vperm2f128(imm(0x31), ymm12, ymm10, ymm4) + vperm2f128(imm(0x31), ymm13, ymm11, ymm6) + + vmovupd(ymm0, mem(rbx, 0*64)) + vmovupd(ymm2, mem(rbx, 1*64)) + vmovupd(ymm4, mem(rbx, 2*64)) + vmovupd(ymm6, mem(rbx, 3*64)) + + vmovupd(mem(rax, r8, 4, 0), ymm1) + vmovupd(mem(rax, rcx, 1, 0), ymm3) + vmovupd(mem(rax, r12, 2, 0), ymm5) + vmovupd(mem(rax, rdx, 1, 0), ymm7) + + add(r14, rax) // a += 4*lda; + + vunpcklpd(ymm3, ymm1, ymm10) + vunpckhpd(ymm3, ymm1, ymm11) + vunpcklpd(ymm7, ymm5, ymm12) + vunpckhpd(ymm7, ymm5, ymm13) + vinsertf128(imm(0x1), xmm12, ymm10, ymm1) + vinsertf128(imm(0x1), xmm13, ymm11, ymm3) + vperm2f128(imm(0x31), ymm12, ymm10, ymm5) + vperm2f128(imm(0x31), ymm13, ymm11, ymm7) + + vmovupd(ymm1, mem(rbx, 0*64+32)) + vmovupd(ymm3, mem(rbx, 1*64+32)) + vmovupd(ymm5, mem(rbx, 2*64+32)) + vmovupd(ymm7, mem(rbx, 3*64+32)) + + add(imm(4*8*8), rbx) // p += 4*ldp = 4*8; + + dec(rsi) // i -= 1; + jne(.DKITERROWU) // iterate again if i != 0. + + + + label(.DCONKLEFTROWU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DKLEFTROWU) // EDGE LOOP (k_left) + + vmovsd(mem(rax, 0), xmm0) + vmovsd(mem(rax, r8, 1, 0), xmm2) + vmovsd(mem(rax, r8, 2, 0), xmm4) + vmovsd(mem(rax, r12, 1, 0), xmm6) + vmovsd(mem(rax, r8, 4, 0), xmm1) + vmovsd(mem(rax, rcx, 1, 0), xmm3) + vmovsd(mem(rax, r12, 2, 0), xmm5) + vmovsd(mem(rax, rdx, 1, 0), xmm7) + + add(r10, rax) // a += lda; + + vmovsd(xmm0, mem(rbx, 0*8)) + vmovsd(xmm2, mem(rbx, 1*8)) + vmovsd(xmm4, mem(rbx, 2*8)) + vmovsd(xmm6, mem(rbx, 3*8)) + vmovsd(xmm1, mem(rbx, 4*8)) + vmovsd(xmm3, mem(rbx, 5*8)) + vmovsd(xmm5, mem(rbx, 6*8)) + vmovsd(xmm7, mem(rbx, 7*8)) + + add(imm(8*8), rbx) // p += ldp = 8; + + dec(rsi) // i -= 1; + jne(.DKLEFTROWU) // iterate again if i != 0. + + + jmp(.DDONE) // jump to end. + + + // -- kappa unit, column storage on A -------------------------------------- + + label(.DCOLUNIT) + + lea(mem(r10, r10, 2), r13) // r13 = 3*lda + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONKLEFTCOLU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DKITERCOLU) // MAIN LOOP (k_iter) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, 32), ymm1) + vmovupd(ymm0, mem(rbx, 0*64+ 0)) + vmovupd(ymm1, mem(rbx, 0*64+32)) + + vmovupd(mem(rax, r10, 1, 0), ymm2) + vmovupd(mem(rax, r10, 1, 32), ymm3) + vmovupd(ymm2, mem(rbx, 1*64+ 0)) + vmovupd(ymm3, mem(rbx, 1*64+32)) + + vmovupd(mem(rax, r10, 2, 0), ymm4) + vmovupd(mem(rax, r10, 2, 32), ymm5) + vmovupd(ymm4, mem(rbx, 2*64+ 0)) + vmovupd(ymm5, mem(rbx, 2*64+32)) + + vmovupd(mem(rax, r13, 1, 0), ymm6) + vmovupd(mem(rax, r13, 1, 32), ymm7) + add(r14, rax) // a += 4*lda; + vmovupd(ymm6, mem(rbx, 3*64+ 0)) + vmovupd(ymm7, mem(rbx, 3*64+32)) + add(imm(4*8*8), rbx) // p += 4*ldp = 4*8; + + dec(rsi) // i -= 1; + jne(.DKITERCOLU) // iterate again if i != 0. + + + + label(.DCONKLEFTCOLU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DKLEFTCOLU) // EDGE LOOP (k_left) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, 32), ymm1) + add(r10, rax) // a += lda; + vmovupd(ymm0, mem(rbx, 0*64+ 0)) + vmovupd(ymm1, mem(rbx, 0*64+32)) + add(imm(8*8), rbx) // p += ldp = 8; + + dec(rsi) // i -= 1; + jne(.DKLEFTCOLU) // iterate again if i != 0. + + + //jmp(.DDONE) // jump to end. + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [inca] "m" (inca), + [lda] "m" (lda), + [p] "m" (p), + [ldp] "m" (ldp), + [kappa] "m" (kappa), + [one] "m" (one) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", /*"r9",*/ "r10", /*"r11",*/ "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + } + else // if ( cdim0 < mnr || gs || !unitk ) + { + PASTEMAC(dscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + double* restrict p_edge = p + (i )*1; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + double* restrict p_edge = p + (j )*ldp; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_s16xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_s16xk.c new file mode 100644 index 0000000000..40ac22bc55 --- /dev/null +++ b/kernels/haswell/1m/bli_packm_haswell_asm_s16xk.c @@ -0,0 +1,568 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +// Prototype reference packm kernels. +PACKM_KER_PROT( double, d, packm_16xk_haswell_ref ) + +void bli_spackm_haswell_asm_16xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + float* restrict kappa, + float* restrict a, inc_t inca0, inc_t lda0, + float* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_spackm_16xk_haswell_ref + ( + conja, schema, cdim0, k0, k0_max, + kappa, a, inca0, lda0, p, ldp0, cntx + ); + return; +#endif + + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 16; + + // This is the "packing" dimension assumed by the packm kernel. + // This should be equal to ldp. + //const dim_t packmnr = 8; + + // Define a local copy of 1.0 so we can test for unit kappa. + float one_l = 1.0; + float* restrict one = &one_l; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + const uint64_t k_iter = k0 / 8; +#if 1 + const uint64_t k_left = k0 % 8; +#else + const uint64_t k_left = k0; +#endif + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_seq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs && unitk ) + { + begin_asm() + + mov(var(a), rax) // load address of a. + + mov(var(inca), r8) // load inca + mov(var(lda), r10) // load lda + lea(mem(, r8, 4), r8) // inca *= sizeof(float) + lea(mem(, r10, 4), r10) // lda *= sizeof(float) + + mov(var(p), rbx) // load address of p. + + lea(mem( , r10, 8), r14) // r14 = 8*lda + + mov(var(one), rdx) // load address of 1.0 constant + vmovss(mem(rdx), xmm1) // load 1.0 + + mov(var(kappa), rcx) // load address of kappa + vmovss(mem(rcx), xmm0) // load kappa + + + // now branch on kappa == 1.0 + + vucomiss(xmm0, xmm1) // set ZF if kappa == 1.0 + je(.SKAPPAUNIT) // if ZF = 1, jump to beta == 0 case + + + + label(.SKAPPANONU) + + cmp(imm(4), r8) // set ZF if (4*inca) == 4. + jz(.SCOLNONU) // jump to column storage case + + // -- kappa non-unit, row storage on A ------------------------------------- + + label(.SROWNONU) + + jmp(.SDONE) // jump to end. + + + // -- kappa non-unit, column storage on A ---------------------------------- + + label(.SCOLNONU) + + jmp(.SDONE) // jump to end. + + + + + label(.SKAPPAUNIT) + + cmp(imm(4), r8) // set ZF if (4*inca) == 4. + jz(.SCOLUNIT) // jump to column storage case + + + // -- kappa unit, row storage on A ----------------------------------------- + + label(.SROWUNIT) + + lea(mem(r8, r8, 2), r13) // r13 = 3*inca + lea(mem(r13, r8, 2), r15) // r15 = 5*inca + lea(mem(r13, r8, 4), rdx) // rdx = 7*inca + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONKLEFTROWU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SKITERROWU) // MAIN LOOP (k_iter) + + mov(rax, r12) // r12 = rax + mov(rbx, rcx) // rcx = rbx + + + // begin IO on rows 0-3 + vmovups(mem(r12, 0), ymm4) + vmovups(mem(r12, r8, 1, 0), ymm6) + vmovups(mem(r12, r8, 2, 0), ymm8) + vmovups(mem(r12, r13, 1, 0), ymm10) + + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, 0*64)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, 4*64)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, 1*64)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, 5*64)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, 2*64)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, 6*64)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, 3*64)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, 7*64)) // store ( gamma07..gamma37 ) + + + lea(mem(r12, r8, 4), r12) // r12 += 4*inca + add(imm(4*4), rcx) // rcx += 4; + + + // begin IO on rows 4-7 + vmovups(mem(r12, 0), ymm4) + vmovups(mem(r12, r8, 1, 0), ymm6) + vmovups(mem(r12, r8, 2, 0), ymm8) + vmovups(mem(r12, r13, 1, 0), ymm10) + + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, 0*64)) // store ( gamma40..gamma70 ) + vmovups(xmm2, mem(rcx, 4*64)) // store ( gamma44..gamma74 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, 1*64)) // store ( gamma41..gamma71 ) + vmovups(xmm2, mem(rcx, 5*64)) // store ( gamma45..gamma75 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, 2*64)) // store ( gamma42..gamma72 ) + vmovups(xmm2, mem(rcx, 6*64)) // store ( gamma46..gamma76 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, 3*64)) // store ( gamma43..gamma73 ) + vmovups(xmm2, mem(rcx, 7*64)) // store ( gamma47..gamma77 ) + + + lea(mem(r12, r8, 4), r12) // r12 += 4*inca + add(imm(4*4), rcx) // rcx += 4; + + + // begin IO on rows 8-11 + vmovups(mem(r12, 0), ymm4) + vmovups(mem(r12, r8, 1, 0), ymm6) + vmovups(mem(r12, r8, 2, 0), ymm8) + vmovups(mem(r12, r13, 1, 0), ymm10) + + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, 0*64)) // store ( gamma80..gammaB0 ) + vmovups(xmm2, mem(rcx, 4*64)) // store ( gamma84..gammaB4 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, 1*64)) // store ( gamma81..gammaB1 ) + vmovups(xmm2, mem(rcx, 5*64)) // store ( gamma85..gammaB5 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, 2*64)) // store ( gamma82..gammaB2 ) + vmovups(xmm2, mem(rcx, 6*64)) // store ( gamma86..gammaB6 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, 3*64)) // store ( gamma83..gammaB3 ) + vmovups(xmm2, mem(rcx, 7*64)) // store ( gamma87..gammaB7 ) + + + lea(mem(r12, r8, 4), r12) // r12 += 4*inca + add(imm(4*4), rcx) // rcx += 4; + + + // begin IO on rows 12-15 + vmovups(mem(r12, 0), ymm4) + vmovups(mem(r12, r8, 1, 0), ymm6) + vmovups(mem(r12, r8, 2, 0), ymm8) + vmovups(mem(r12, r13, 1, 0), ymm10) + + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, 0*64)) // store ( gammaC0..gammaF0 ) + vmovups(xmm2, mem(rcx, 4*64)) // store ( gammaC4..gammaF4 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, 1*64)) // store ( gammaC1..gammaF1 ) + vmovups(xmm2, mem(rcx, 5*64)) // store ( gammaC5..gammaF5 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, 2*64)) // store ( gammaC2..gammaF2 ) + vmovups(xmm2, mem(rcx, 6*64)) // store ( gammaC6..gammaF6 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, 3*64)) // store ( gammaC3..gammaF3 ) + vmovups(xmm2, mem(rcx, 7*64)) // store ( gammaC7..gammaF7 ) + + + add(r14, rax) // a += 8*lda; + add(imm(8*16*4), rbx) // p += 8*ldp = 8*16; + + dec(rsi) // i -= 1; + jne(.SKITERROWU) // iterate again if i != 0. + + + + label(.SCONKLEFTROWU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SKLEFTROWU) // EDGE LOOP (k_left) + + vmovss(mem(rax, 0), xmm0) + vmovss(mem(rax, r8, 1, 0), xmm2) + vmovss(mem(rax, r8, 2, 0), xmm4) + vmovss(mem(rax, r13, 1, 0), xmm6) + vmovss(mem(rax, r8, 4, 0), xmm1) + vmovss(mem(rax, r15, 1, 0), xmm3) + vmovss(mem(rax, r13, 2, 0), xmm5) + vmovss(mem(rax, rdx, 1, 0), xmm7) + + vmovss(xmm0, mem(rbx, 0*4)) + vmovss(xmm2, mem(rbx, 1*4)) + vmovss(xmm4, mem(rbx, 2*4)) + vmovss(xmm6, mem(rbx, 3*4)) + vmovss(xmm1, mem(rbx, 4*4)) + vmovss(xmm3, mem(rbx, 5*4)) + vmovss(xmm5, mem(rbx, 6*4)) + vmovss(xmm7, mem(rbx, 7*4)) + + lea(mem(rax, r8, 8), r12) // r12 = a + 8*inca + + vmovss(mem(r12, 0), xmm0) + vmovss(mem(r12, r8, 1, 0), xmm2) + vmovss(mem(r12, r8, 2, 0), xmm4) + vmovss(mem(r12, r13, 1, 0), xmm6) + vmovss(mem(r12, r8, 4, 0), xmm1) + vmovss(mem(r12, r15, 1, 0), xmm3) + vmovss(mem(r12, r13, 2, 0), xmm5) + vmovss(mem(r12, rdx, 1, 0), xmm7) + + add(r10, rax) // a += lda; + + vmovss(xmm0, mem(rbx, 8*4)) + vmovss(xmm2, mem(rbx, 9*4)) + vmovss(xmm4, mem(rbx, 10*4)) + vmovss(xmm6, mem(rbx, 11*4)) + vmovss(xmm1, mem(rbx, 12*4)) + vmovss(xmm3, mem(rbx, 13*4)) + vmovss(xmm5, mem(rbx, 14*4)) + vmovss(xmm7, mem(rbx, 15*4)) + + add(imm(16*4), rbx) // p += ldp = 16; + + dec(rsi) // i -= 1; + jne(.SKLEFTROWU) // iterate again if i != 0. + + + jmp(.SDONE) // jump to end. + + + // -- kappa unit, column storage on A -------------------------------------- + + label(.SCOLUNIT) + + lea(mem(r10, r10, 2), r13) // r13 = 3*lda + lea(mem(r13, r10, 2), r15) // r15 = 5*lda + lea(mem(r13, r10, 4), rdx) // rdx = 7*lda + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONKLEFTCOLU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SKITERCOLU) // MAIN LOOP (k_iter) + + vmovups(mem(rax, 0), ymm0) + vmovups(mem(rax, 32), ymm1) + vmovups(ymm0, mem(rbx, 0*64+ 0)) + vmovups(ymm1, mem(rbx, 0*64+32)) + + vmovups(mem(rax, r10, 1, 0), ymm2) + vmovups(mem(rax, r10, 1, 32), ymm3) + vmovups(ymm2, mem(rbx, 1*64+ 0)) + vmovups(ymm3, mem(rbx, 1*64+32)) + + vmovups(mem(rax, r10, 2, 0), ymm4) + vmovups(mem(rax, r10, 2, 32), ymm5) + vmovups(ymm4, mem(rbx, 2*64+ 0)) + vmovups(ymm5, mem(rbx, 2*64+32)) + + vmovups(mem(rax, r13, 1, 0), ymm6) + vmovups(mem(rax, r13, 1, 32), ymm7) + vmovups(ymm6, mem(rbx, 3*64+ 0)) + vmovups(ymm7, mem(rbx, 3*64+32)) + + vmovups(mem(rax, r10, 4, 0), ymm8) + vmovups(mem(rax, r10, 4, 32), ymm9) + vmovups(ymm8, mem(rbx, 4*64+ 0)) + vmovups(ymm9, mem(rbx, 4*64+32)) + + vmovups(mem(rax, r15, 1, 0), ymm10) + vmovups(mem(rax, r15, 1, 32), ymm11) + vmovups(ymm10, mem(rbx, 5*64+ 0)) + vmovups(ymm11, mem(rbx, 5*64+32)) + + vmovups(mem(rax, r13, 2, 0), ymm12) + vmovups(mem(rax, r13, 2, 32), ymm13) + vmovups(ymm12, mem(rbx, 6*64+ 0)) + vmovups(ymm13, mem(rbx, 6*64+32)) + + vmovups(mem(rax, rdx, 1, 0), ymm14) + vmovups(mem(rax, rdx, 1, 32), ymm15) + add(r14, rax) // a += 8*lda; + vmovups(ymm14, mem(rbx, 7*64+ 0)) + vmovups(ymm15, mem(rbx, 7*64+32)) + add(imm(8*16*4), rbx) // p += 8*ldp = 8*16; + + dec(rsi) // i -= 1; + jne(.SKITERCOLU) // iterate again if i != 0. + + + + label(.SCONKLEFTCOLU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SKLEFTCOLU) // EDGE LOOP (k_left) + + vmovups(mem(rax, 0), ymm0) + vmovups(mem(rax, 32), ymm1) + add(r10, rax) // a += lda; + vmovups(ymm0, mem(rbx, 0*64+ 0)) + vmovups(ymm1, mem(rbx, 0*64+32)) + add(imm(16*4), rbx) // p += ldp = 16; + + dec(rsi) // i -= 1; + jne(.SKLEFTCOLU) // iterate again if i != 0. + + + //jmp(.SDONE) // jump to end. + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [inca] "m" (inca), + [lda] "m" (lda), + [p] "m" (p), + [ldp] "m" (ldp), + [kappa] "m" (kappa), + [one] "m" (one) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", /*"r9",*/ "r10", /*"r11",*/ "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + } + else // if ( cdim0 < mnr || gs || !unitk ) + { + PASTEMAC(sscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + float* restrict p_edge = p + (i )*1; + + bli_sset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + float* restrict p_edge = p + (j )*ldp; + + bli_sset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_s6xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_s6xk.c new file mode 100644 index 0000000000..3a134bed8f --- /dev/null +++ b/kernels/haswell/1m/bli_packm_haswell_asm_s6xk.c @@ -0,0 +1,441 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +// Prototype reference packm kernels. +PACKM_KER_PROT( double, d, packm_6xk_haswell_ref ) + +void bli_spackm_haswell_asm_6xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + float* restrict kappa, + float* restrict a, inc_t inca0, inc_t lda0, + float* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_spackm_6xk_haswell_ref + ( + conja, schema, cdim0, k0, k0_max, + kappa, a, inca0, lda0, p, ldp0, cntx + ); + return; +#endif + + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 6; + + // This is the "packing" dimension assumed by the packm kernel. + // This should be equal to ldp. + //const dim_t packmnr = 8; + + // Define a local copy of 1.0 so we can test for unit kappa. + float one_l = 1.0; + float* restrict one = &one_l; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + const uint64_t k_iter = k0 / 8; +#if 1 + const uint64_t k_left = k0 % 8; +#else + const uint64_t k_left = k0; +#endif + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_seq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs && unitk ) + { + begin_asm() + + mov(var(a), rax) // load address of a. + + mov(var(inca), r8) // load inca + mov(var(lda), r10) // load lda + lea(mem(, r8, 4), r8) // inca *= sizeof(float) + lea(mem(, r10, 4), r10) // lda *= sizeof(float) + + mov(var(p), rbx) // load address of p. + + lea(mem( , r10, 8), r14) // r14 = 8*lda + + mov(var(one), rdx) // load address of 1.0 constant + vmovss(mem(rdx), xmm1) // load 1.0 + + mov(var(kappa), rcx) // load address of kappa + vmovss(mem(rcx), xmm0) // load kappa + + + // now branch on kappa == 1.0 + + vucomiss(xmm0, xmm1) // set ZF if kappa == 1.0 + je(.SKAPPAUNIT) // if ZF = 1, jump to beta == 0 case + + + + label(.SKAPPANONU) + + cmp(imm(4), r8) // set ZF if (4*inca) == 4. + jz(.SCOLNONU) // jump to column storage case + + // -- kappa non-unit, row storage on A ------------------------------------- + + label(.SROWNONU) + + jmp(.SDONE) // jump to end. + + + // -- kappa non-unit, column storage on A ---------------------------------- + + label(.SCOLNONU) + + jmp(.SDONE) // jump to end. + + + + + label(.SKAPPAUNIT) + + cmp(imm(4), r8) // set ZF if (4*inca) == 4. + jz(.SCOLUNIT) // jump to column storage case + + + // -- kappa unit, row storage on A ----------------------------------------- + + label(.SROWUNIT) + + lea(mem(r8, r8, 2), r13) // r13 = 3*inca + lea(mem(r13, r8, 2), r15) // r15 = 5*inca + //lea(mem(r13, r8, 4), rdx) // rdx = 7*inca + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONKLEFTROWU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SKITERROWU) // MAIN LOOP (k_iter) + + // begin IO on rows 0-3 + vmovups(mem(rax, 0), ymm4) + vmovups(mem(rax, r8, 1, 0), ymm6) + vmovups(mem(rax, r8, 2, 0), ymm8) + vmovups(mem(rax, r13, 1, 0), ymm10) + + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rbx, 0*24)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rbx, 4*24)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rbx, 1*24)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rbx, 5*24)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rbx, 2*24)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rbx, 6*24)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rbx, 3*24)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rbx, 7*24)) // store ( gamma07..gamma37 ) + + // begin IO on rows 4-5 + vmovups(mem(rax, r8, 4, 0), ymm12) + vmovups(mem(rax, r15, 1, 0), ymm14) + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rbx, 0*24+16)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rbx, 1*24+16)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(rbx, 4*24+16)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rbx, 5*24+16)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rbx, 2*24+16)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rbx, 3*24+16)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(rbx, 6*24+16)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rbx, 7*24+16)) // store ( gamma47..gamma57 ) + + + add(r14, rax) // a += 8*lda; + add(imm(8*6*4), rbx) // p += 8*ldp = 8*6; + + dec(rsi) // i -= 1; + jne(.SKITERROWU) // iterate again if i != 0. + + + + label(.SCONKLEFTROWU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SKLEFTROWU) // EDGE LOOP (k_left) + + vmovss(mem(rax, 0), xmm0) + vmovss(mem(rax, r8, 1, 0), xmm2) + vmovss(mem(rax, r8, 2, 0), xmm4) + vmovss(mem(rax, r13, 1, 0), xmm6) + vmovss(mem(rax, r8, 4, 0), xmm1) + vmovss(mem(rax, r15, 1, 0), xmm3) + + vmovss(xmm0, mem(rbx, 0*4)) + vmovss(xmm2, mem(rbx, 1*4)) + vmovss(xmm4, mem(rbx, 2*4)) + vmovss(xmm6, mem(rbx, 3*4)) + vmovss(xmm1, mem(rbx, 4*4)) + vmovss(xmm3, mem(rbx, 5*4)) + + add(r10, rax) // a += lda; + add(imm(6*4), rbx) // p += ldp = 6; + + dec(rsi) // i -= 1; + jne(.SKLEFTROWU) // iterate again if i != 0. + + + jmp(.SDONE) // jump to end. + + + // -- kappa unit, column storage on A -------------------------------------- + + label(.SCOLUNIT) + + lea(mem(r10, r10, 2), r13) // r13 = 3*lda + lea(mem(r13, r10, 2), r15) // r15 = 5*lda + lea(mem(r13, r10, 4), rdx) // rdx = 7*lda + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONKLEFTCOLU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SKITERCOLU) // MAIN LOOP (k_iter) + + vmovups(mem(rax, 0), xmm0) + vmovsd( mem(rax, 16), xmm1) + vmovups(xmm0, mem(rbx, 0*24+ 0)) + vmovsd( xmm1, mem(rbx, 0*24+16)) + + vmovups(mem(rax, r10, 1, 0), xmm2) + vmovsd( mem(rax, r10, 1, 16), xmm3) + vmovups(xmm2, mem(rbx, 1*24+ 0)) + vmovsd( xmm3, mem(rbx, 1*24+16)) + + vmovups(mem(rax, r10, 2, 0), xmm4) + vmovsd( mem(rax, r10, 2, 16), xmm5) + vmovups(xmm4, mem(rbx, 2*24+ 0)) + vmovsd( xmm5, mem(rbx, 2*24+16)) + + vmovups(mem(rax, r13, 1, 0), xmm6) + vmovsd( mem(rax, r13, 1, 16), xmm7) + vmovups(xmm6, mem(rbx, 3*24+ 0)) + vmovsd( xmm7, mem(rbx, 3*24+16)) + + vmovups(mem(rax, r10, 4, 0), xmm8) + vmovsd( mem(rax, r10, 4, 16), xmm9) + vmovups(xmm8, mem(rbx, 4*24+ 0)) + vmovsd( xmm9, mem(rbx, 4*24+16)) + + vmovups(mem(rax, r15, 1, 0), xmm10) + vmovsd( mem(rax, r15, 1, 16), xmm11) + vmovups(xmm10, mem(rbx, 5*24+ 0)) + vmovsd( xmm11, mem(rbx, 5*24+16)) + + vmovups(mem(rax, r13, 2, 0), xmm12) + vmovsd( mem(rax, r13, 2, 16), xmm13) + vmovups(xmm12, mem(rbx, 6*24+ 0)) + vmovsd( xmm13, mem(rbx, 6*24+16)) + + vmovups(mem(rax, rdx, 1, 0), xmm14) + vmovsd( mem(rax, rdx, 1, 16), xmm15) + vmovups(xmm14, mem(rbx, 7*24+ 0)) + vmovsd( xmm15, mem(rbx, 7*24+16)) + + add(r14, rax) // a += 8*lda; + add(imm(8*6*4), rbx) // p += 8*ldp = 8*6; + + dec(rsi) // i -= 1; + jne(.SKITERCOLU) // iterate again if i != 0. + + + + label(.SCONKLEFTCOLU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SKLEFTCOLU) // EDGE LOOP (k_left) + + vmovups(mem(rax, 0), xmm0) + vmovsd( mem(rax, 16), xmm1) + add(r10, rax) // a += lda; + vmovups(xmm0, mem(rbx, 0*24+ 0)) + vmovsd( xmm1, mem(rbx, 0*24+16)) + add(imm(6*4), rbx) // p += ldp = 6; + + dec(rsi) // i -= 1; + jne(.SKLEFTCOLU) // iterate again if i != 0. + + + //jmp(.SDONE) // jump to end. + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [inca] "m" (inca), + [lda] "m" (lda), + [p] "m" (p), + [ldp] "m" (ldp), + [kappa] "m" (kappa), + [one] "m" (one) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", /*"r9",*/ "r10", /*"r11",*/ "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + } + else // if ( cdim0 < mnr || gs || !unitk ) + { + PASTEMAC(sscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + float* restrict p_edge = p + (i )*1; + + bli_sset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + float* restrict p_edge = p + (j )*ldp; + + bli_sset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c new file mode 100644 index 0000000000..1a714abe26 --- /dev/null +++ b/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c @@ -0,0 +1,401 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +// Prototype reference packm kernels. +PACKM_KER_PROT( dcomplex, z, packm_3xk_haswell_ref ) + +void bli_zpackm_haswell_asm_3xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + dcomplex* restrict kappa, + dcomplex* restrict a, inc_t inca0, inc_t lda0, + dcomplex* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_zpackm_3xk_haswell_ref + ( + conja, schema, cdim0, k0, k0_max, + kappa, a, inca0, lda0, p, ldp0, cntx + ); + return; +#endif + + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 3; + + // This is the "packing" dimension assumed by the packm kernel. + // This should be equal to ldp. + //const dim_t packmnr = 8; + + // Define a local copy of 1.0 so we can test for unit kappa. + double one_l = 1.0; + double* restrict one = &one_l; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + const uint64_t k_iter = k0 / 4; +#if 1 + const uint64_t k_left = k0 % 4; +#else + const uint64_t k_left = k0; +#endif + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_zeq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs && !conja && unitk ) + { + begin_asm() + + mov(var(a), rax) // load address of a. + + mov(var(inca), r8) // load inca + mov(var(lda), r10) // load lda + lea(mem( , r8, 2), r8) + lea(mem( , r8, 8), r8) // inca *= sizeof(dcomplex) + lea(mem( , r10, 2), r10) + lea(mem( , r10, 8), r10) // lda *= sizeof(dcomplex) + + mov(var(p), rbx) // load address of p. + + lea(mem( , r10, 4), r14) // r14 = 4*lda + + mov(var(one), rdx) // load address of 1.0 constant + vbroadcastsd(mem(rdx, 0), ymm1) // load 1.0 and duplicate + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to 0.0. + + mov(var(kappa), rcx) // load address of kappa + vbroadcastsd(mem(rcx, 0), ymm10) // load kappa_r and duplicate + vbroadcastsd(mem(rcx, 8), ymm11) // load kappa_i and duplicate + + + // now branch on kappa == 1.0 + + vucomisd(xmm1, xmm10) // set ZF if kappa_r == 1.0. + sete(r12b) // r12b = ( ZF == 1 ? 1 : 0 ); + vucomisd(xmm0, xmm11) // set ZF if kappa_i == 0.0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + and(r12b, r13b) // set ZF if r12b & r13b == 1. + jne(.ZKAPPAUNIT) // if ZF = 1, jump to beta == 0 case + + + + label(.ZKAPPANONU) + + cmp(imm(16), r8) // set ZF if (16*inca) == 16. + jz(.ZCOLNONU) // jump to column storage case + + // -- kappa non-unit, row storage on A ------------------------------------- + + label(.ZROWNONU) + + jmp(.ZDONE) // jump to end. + + + // -- kappa non-unit, column storage on A ---------------------------------- + + label(.ZCOLNONU) + + jmp(.ZDONE) // jump to end. + + + + + label(.ZKAPPAUNIT) + + cmp(imm(16), r8) // set ZF if (16*inca) == 16. + jz(.ZCOLUNIT) // jump to column storage case + + + // -- kappa unit, row storage on A ----------------------------------------- + + label(.ZROWUNIT) + + //lea(mem(r8, r8, 2), r12) // r12 = 3*inca + //lea(mem(r12, r8, 2), rcx) // rcx = 5*inca + //lea(mem(r12, r8, 4), rdx) // rdx = 7*inca + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.ZCONKLEFTROWU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.ZKITERROWU) // MAIN LOOP (k_iter) + + vmovupd(mem(rax, 0), ymm8) + vmovupd(mem(rax, r8, 1, 0), ymm10) + vmovupd(mem(rax, r8, 2, 0), ymm12) + + vextractf128(imm(0x1), ymm8, xmm9) + vextractf128(imm(0x1), ymm10, xmm11) + vextractf128(imm(0x1), ymm12, xmm13) + + vmovupd(xmm8, mem(rbx, 0*16+0*48)) + vmovupd(xmm10, mem(rbx, 1*16+0*48)) + vmovupd(xmm12, mem(rbx, 2*16+0*48)) + + vmovupd(xmm9, mem(rbx, 0*16+1*48)) + vmovupd(xmm11, mem(rbx, 1*16+1*48)) + vmovupd(xmm13, mem(rbx, 2*16+1*48)) + + vmovupd(mem(rax, 32), ymm8) + vmovupd(mem(rax, r8, 1, 32), ymm10) + vmovupd(mem(rax, r8, 2, 32), ymm12) + + add(r14, rax) // a += 4*lda; + + vextractf128(imm(0x1), ymm8, xmm9) + vextractf128(imm(0x1), ymm10, xmm11) + vextractf128(imm(0x1), ymm12, xmm13) + + vmovupd(xmm8, mem(rbx, 0*16+2*48)) + vmovupd(xmm10, mem(rbx, 1*16+2*48)) + vmovupd(xmm12, mem(rbx, 2*16+2*48)) + + vmovupd(xmm9, mem(rbx, 0*16+3*48)) + vmovupd(xmm11, mem(rbx, 1*16+3*48)) + vmovupd(xmm13, mem(rbx, 2*16+3*48)) + + add(imm(4*3*16), rbx) // p += 4*ldp = 4*3; + + dec(rsi) // i -= 1; + jne(.ZKITERROWU) // iterate again if i != 0. + + + + label(.ZCONKLEFTROWU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.ZDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.ZKLEFTROWU) // EDGE LOOP (k_left) + + vmovups(mem(rax, 0), xmm0) + vmovups(mem(rax, r8, 1, 0), xmm2) + vmovups(mem(rax, r8, 2, 0), xmm4) + + add(r10, rax) // a += lda; + + vmovups(xmm0, mem(rbx, 0*16+0*48)) + vmovups(xmm2, mem(rbx, 1*16+0*48)) + vmovups(xmm4, mem(rbx, 2*16+0*48)) + + add(imm(3*16), rbx) // p += ldp = 4; + + dec(rsi) // i -= 1; + jne(.ZKLEFTROWU) // iterate again if i != 0. + + + jmp(.ZDONE) // jump to end. + + + // -- kappa unit, column storage on A -------------------------------------- + + label(.ZCOLUNIT) + + lea(mem(r10, r10, 2), r13) // r13 = 3*lda + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.ZCONKLEFTCOLU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.ZKITERCOLU) // MAIN LOOP (k_iter) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, 32), xmm1) + vmovupd(ymm0, mem(rbx, 0*48+ 0)) + vmovupd(xmm1, mem(rbx, 0*48+32)) + + vmovupd(mem(rax, r10, 1, 0), ymm2) + vmovupd(mem(rax, r10, 1, 32), xmm3) + vmovupd(ymm2, mem(rbx, 1*48+ 0)) + vmovupd(xmm3, mem(rbx, 1*48+32)) + + vmovupd(mem(rax, r10, 2, 0), ymm4) + vmovupd(mem(rax, r10, 2, 32), xmm5) + vmovupd(ymm4, mem(rbx, 2*48+ 0)) + vmovupd(xmm5, mem(rbx, 2*48+32)) + + vmovupd(mem(rax, r13, 1, 0), ymm6) + vmovupd(mem(rax, r13, 1, 32), xmm7) + add(r14, rax) // a += 4*lda; + vmovupd(ymm6, mem(rbx, 3*48+ 0)) + vmovupd(xmm7, mem(rbx, 3*48+32)) + add(imm(4*3*16), rbx) // p += 4*ldp = 4*3; + + dec(rsi) // i -= 1; + jne(.ZKITERCOLU) // iterate again if i != 0. + + + + label(.ZCONKLEFTCOLU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.ZDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.ZKLEFTCOLU) // EDGE LOOP (k_left) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, 32), xmm1) + add(r10, rax) // a += lda; + vmovupd(ymm0, mem(rbx, 0*48+ 0)) + vmovupd(xmm1, mem(rbx, 0*48+32)) + add(imm(3*16), rbx) // p += ldp = 3; + + dec(rsi) // i -= 1; + jne(.ZKLEFTCOLU) // iterate again if i != 0. + + + //jmp(.ZDONE) // jump to end. + + + + label(.ZDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [inca] "m" (inca), + [lda] "m" (lda), + [p] "m" (p), + [ldp] "m" (ldp), + [kappa] "m" (kappa), + [one] "m" (one) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", /*"r9",*/ "r10", /*"r11",*/ "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + } + else // if ( cdim0 < mnr || gs || bli_does_conj( conja ) || !unitk ) + { + PASTEMAC(zscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + dcomplex* restrict p_edge = p + (i )*1; + + bli_zset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + dcomplex* restrict p_edge = p + (j )*ldp; + + bli_zset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c new file mode 100644 index 0000000000..4e11872afb --- /dev/null +++ b/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c @@ -0,0 +1,411 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +// Prototype reference packm kernels. +PACKM_KER_PROT( dcomplex, z, packm_4xk_haswell_ref ) + +void bli_zpackm_haswell_asm_4xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + dcomplex* restrict kappa, + dcomplex* restrict a, inc_t inca0, inc_t lda0, + dcomplex* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_zpackm_4xk_haswell_ref + ( + conja, schema, cdim0, k0, k0_max, + kappa, a, inca0, lda0, p, ldp0, cntx + ); + return; +#endif + + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 4; + + // This is the "packing" dimension assumed by the packm kernel. + // This should be equal to ldp. + //const dim_t packmnr = 8; + + // Define a local copy of 1.0 so we can test for unit kappa. + double one_l = 1.0; + double* restrict one = &one_l; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + const uint64_t k_iter = k0 / 4; +#if 1 + const uint64_t k_left = k0 % 4; +#else + const uint64_t k_left = k0; +#endif + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_zeq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs && !conja && unitk ) + { + begin_asm() + + mov(var(a), rax) // load address of a. + + mov(var(inca), r8) // load inca + mov(var(lda), r10) // load lda + lea(mem( , r8, 2), r8) + lea(mem( , r8, 8), r8) // inca *= sizeof(dcomplex) + lea(mem( , r10, 2), r10) + lea(mem( , r10, 8), r10) // lda *= sizeof(dcomplex) + + mov(var(p), rbx) // load address of p. + + lea(mem( , r10, 4), r14) // r14 = 4*lda + + mov(var(one), rdx) // load address of 1.0 constant + vbroadcastsd(mem(rdx, 0), ymm1) // load 1.0 and duplicate + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to 0.0. + + mov(var(kappa), rcx) // load address of kappa + vbroadcastsd(mem(rcx, 0), ymm10) // load kappa_r and duplicate + vbroadcastsd(mem(rcx, 8), ymm11) // load kappa_i and duplicate + + + // now branch on kappa == 1.0 + + vucomisd(xmm1, xmm10) // set ZF if kappa_r == 1.0. + sete(r12b) // r12b = ( ZF == 1 ? 1 : 0 ); + vucomisd(xmm0, xmm11) // set ZF if kappa_i == 0.0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + and(r12b, r13b) // set ZF if r12b & r13b == 1. + jne(.ZKAPPAUNIT) // if ZF = 1, jump to kappa == 1.0 case + + + + label(.ZKAPPANONU) + + cmp(imm(16), r8) // set ZF if (16*inca) == 16. + jz(.ZCOLNONU) // jump to column storage case + + // -- kappa non-unit, row storage on A ------------------------------------- + + label(.ZROWNONU) + + jmp(.ZDONE) // jump to end. + + + // -- kappa non-unit, column storage on A ---------------------------------- + + label(.ZCOLNONU) + + jmp(.ZDONE) // jump to end. + + + + + label(.ZKAPPAUNIT) + + cmp(imm(16), r8) // set ZF if (16*inca) == 16. + jz(.ZCOLUNIT) // jump to column storage case + + + // -- kappa unit, row storage on A ----------------------------------------- + + label(.ZROWUNIT) + + lea(mem(r8, r8, 2), r12) // r12 = 3*inca + //lea(mem(r12, r8, 2), rcx) // rcx = 5*inca + //lea(mem(r12, r8, 4), rdx) // rdx = 7*inca + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.ZCONKLEFTROWU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.ZKITERROWU) // MAIN LOOP (k_iter) + + vmovupd(mem(rax, 0), ymm8) + vmovupd(mem(rax, r8, 1, 0), ymm10) + vmovupd(mem(rax, r8, 2, 0), ymm12) + vmovupd(mem(rax, r12, 1, 0), ymm14) + + vextractf128(imm(0x1), ymm8, xmm9) + vextractf128(imm(0x1), ymm10, xmm11) + vextractf128(imm(0x1), ymm12, xmm13) + vextractf128(imm(0x1), ymm14, xmm15) + + vmovupd(xmm8, mem(rbx, 0*16+0*64)) + vmovupd(xmm10, mem(rbx, 1*16+0*64)) + vmovupd(xmm12, mem(rbx, 2*16+0*64)) + vmovupd(xmm14, mem(rbx, 3*16+0*64)) + + vmovupd(xmm9, mem(rbx, 0*16+1*64)) + vmovupd(xmm11, mem(rbx, 1*16+1*64)) + vmovupd(xmm13, mem(rbx, 2*16+1*64)) + vmovupd(xmm15, mem(rbx, 3*16+1*64)) + + vmovupd(mem(rax, 32), ymm8) + vmovupd(mem(rax, r8, 1, 32), ymm10) + vmovupd(mem(rax, r8, 2, 32), ymm12) + vmovupd(mem(rax, r12, 1, 32), ymm14) + + add(r14, rax) // a += 4*lda; + + vextractf128(imm(0x1), ymm8, xmm9) + vextractf128(imm(0x1), ymm10, xmm11) + vextractf128(imm(0x1), ymm12, xmm13) + vextractf128(imm(0x1), ymm14, xmm15) + + vmovupd(xmm8, mem(rbx, 0*16+2*64)) + vmovupd(xmm10, mem(rbx, 1*16+2*64)) + vmovupd(xmm12, mem(rbx, 2*16+2*64)) + vmovupd(xmm14, mem(rbx, 3*16+2*64)) + + vmovupd(xmm9, mem(rbx, 0*16+3*64)) + vmovupd(xmm11, mem(rbx, 1*16+3*64)) + vmovupd(xmm13, mem(rbx, 2*16+3*64)) + vmovupd(xmm15, mem(rbx, 3*16+3*64)) + + add(imm(4*4*16), rbx) // p += 4*ldp = 4*4; + + dec(rsi) // i -= 1; + jne(.ZKITERROWU) // iterate again if i != 0. + + + + label(.ZCONKLEFTROWU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.ZDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.ZKLEFTROWU) // EDGE LOOP (k_left) + + vmovups(mem(rax, 0), xmm0) + vmovups(mem(rax, r8, 1, 0), xmm2) + vmovups(mem(rax, r8, 2, 0), xmm4) + vmovups(mem(rax, r12, 1, 0), xmm6) + + add(r10, rax) // a += lda; + + vmovups(xmm0, mem(rbx, 0*16+0*64)) + vmovups(xmm2, mem(rbx, 1*16+0*64)) + vmovups(xmm4, mem(rbx, 2*16+0*64)) + vmovups(xmm6, mem(rbx, 3*16+0*64)) + + add(imm(4*16), rbx) // p += ldp = 4; + + dec(rsi) // i -= 1; + jne(.ZKLEFTROWU) // iterate again if i != 0. + + + jmp(.ZDONE) // jump to end. + + + // -- kappa unit, column storage on A -------------------------------------- + + label(.ZCOLUNIT) + + lea(mem(r10, r10, 2), r13) // r13 = 3*lda + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.ZCONKLEFTCOLU) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.ZKITERCOLU) // MAIN LOOP (k_iter) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, 32), ymm1) + vmovupd(ymm0, mem(rbx, 0*64+ 0)) + vmovupd(ymm1, mem(rbx, 0*64+32)) + + vmovupd(mem(rax, r10, 1, 0), ymm2) + vmovupd(mem(rax, r10, 1, 32), ymm3) + vmovupd(ymm2, mem(rbx, 1*64+ 0)) + vmovupd(ymm3, mem(rbx, 1*64+32)) + + vmovupd(mem(rax, r10, 2, 0), ymm4) + vmovupd(mem(rax, r10, 2, 32), ymm5) + vmovupd(ymm4, mem(rbx, 2*64+ 0)) + vmovupd(ymm5, mem(rbx, 2*64+32)) + + vmovupd(mem(rax, r13, 1, 0), ymm6) + vmovupd(mem(rax, r13, 1, 32), ymm7) + add(r14, rax) // a += 4*lda; + vmovupd(ymm6, mem(rbx, 3*64+ 0)) + vmovupd(ymm7, mem(rbx, 3*64+32)) + add(imm(4*4*16), rbx) // p += 4*ldp = 4*4; + + dec(rsi) // i -= 1; + jne(.ZKITERCOLU) // iterate again if i != 0. + + + + label(.ZCONKLEFTCOLU) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.ZDONE) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.ZKLEFTCOLU) // EDGE LOOP (k_left) + + vmovupd(mem(rax, 0), ymm0) + vmovupd(mem(rax, 32), ymm1) + add(r10, rax) // a += lda; + vmovupd(ymm0, mem(rbx, 0*64+ 0)) + vmovupd(ymm1, mem(rbx, 0*64+32)) + add(imm(4*16), rbx) // p += ldp = 4; + + dec(rsi) // i -= 1; + jne(.ZKLEFTCOLU) // iterate again if i != 0. + + + //jmp(.ZDONE) // jump to end. + + + + label(.ZDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [inca] "m" (inca), + [lda] "m" (lda), + [p] "m" (p), + [ldp] "m" (ldp), + [kappa] "m" (kappa), + [one] "m" (one) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", /*"r9",*/ "r10", /*"r11",*/ "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + } + else // if ( cdim0 < mnr || gs || bli_does_conj( conja ) || !unitk ) + { + PASTEMAC(zscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + dcomplex* restrict p_edge = p + (i )*1; + + bli_zset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + dcomplex* restrict p_edge = p + (j )*ldp; + + bli_zset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index 48b0394bde..70ea4ccd7b 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -79,7 +79,9 @@ void bli_sgemm_haswell_asm_6x16 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -94,221 +96,252 @@ void bli_sgemm_haswell_asm_6x16 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_AMBI( s, 6, 16, true ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rbx) - // initialize loop by pre-loading + // initialize loop by pre-loading vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) - - lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; - lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c - - - - + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPREFETCH) // jump to column prefetch case + + lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.SPREFETCHDONE) + + label(.SCOLPREFETCH) + + lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 7*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 7*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 7*cs_c + lea(mem(rcx, rsi, 8), r14) // r14 = c + 8*cs_c; + lea(mem(r14, r13, 1), rdx) // rdx = c + 11*cs_c; + prefetch(0, mem(r14, 7*8)) // prefetch c + 8*cs_c + prefetch(0, mem(r14, rsi, 1, 7*8)) // prefetch c + 9*cs_c + prefetch(0, mem(r14, rsi, 2, 7*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 12*cs_c + prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 15*cs_c + + label(.SPREFETCHDONE) + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. - - + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP - - - // iteration 0 + + + // iteration 0 prefetch(0, mem(rax, 64*4)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, -2*32), ymm0) vmovaps(mem(rbx, -1*32), ymm1) - - // iteration 1 + + // iteration 1 vbroadcastss(mem(rax, 6*4), ymm2) vbroadcastss(mem(rax, 7*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 8*4), ymm2) vbroadcastss(mem(rax, 9*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 10*4), ymm2) vbroadcastss(mem(rax, 11*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 0*32), ymm0) vmovaps(mem(rbx, 1*32), ymm1) - - // iteration 2 + + // iteration 2 prefetch(0, mem(rax, 76*4)) - + vbroadcastss(mem(rax, 12*4), ymm2) vbroadcastss(mem(rax, 13*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 14*4), ymm2) vbroadcastss(mem(rax, 15*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 16*4), ymm2) vbroadcastss(mem(rax, 17*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 2*32), ymm0) vmovaps(mem(rbx, 3*32), ymm1) - - // iteration 3 + + // iteration 3 vbroadcastss(mem(rax, 18*4), ymm2) vbroadcastss(mem(rax, 19*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 20*4), ymm2) vbroadcastss(mem(rax, 21*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 22*4), ymm2) vbroadcastss(mem(rax, 23*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(4*6*4), rax) // a += 4*6 (unroll x mr) add(imm(4*16*4), rbx) // b += 4*16 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. - - + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*4)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(1*6*4), rax) // a += 1*6 (unroll x mr) add(imm(1*16*4), rbx) // b += 1*16 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.SPOSTACCUM) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), ymm0) // load alpha and duplicate vbroadcastss(mem(rbx), ymm3) // load beta and duplicate - + vmulps(ymm0, ymm4, ymm4) // scale by alpha vmulps(ymm0, ymm5, ymm5) vmulps(ymm0, ymm6, ymm6) @@ -321,572 +354,371 @@ void bli_sgemm_haswell_asm_6x16 vmulps(ymm0, ymm13, ymm13) vmulps(ymm0, ymm14, ymm14) vmulps(ymm0, ymm15, ymm15) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) - + lea(mem(rcx, rsi, 8), rdx) // load address of c + 8*cs_c; lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; - + lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c; lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - - - // now avoid loading C if beta == 0 - + + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm3) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. - jz(.SROWSTORED) // jump to row storage case - - - cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. - jz(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm4, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm6, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm8, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm10, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm12, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm14, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += rs_c; - - - mov(rdx, rcx) // rcx = c + 8*cs_c - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm5, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm7, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm9, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm11, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm13, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm15, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += rs_c; - - - - jmp(.SDONE) // jump to end. - - - - label(.SROWSTORED) - - - vfmadd231ps(mem(rcx), ymm3, ymm4) - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm5) - vmovups(ymm5, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm6) - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm7) - vmovups(ymm7, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm8) - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm9) - vmovups(ymm9, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm10) - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm11) - vmovups(ymm11, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm12) - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm13) - vmovups(ymm13, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm14) - vmovups(ymm14, mem(rcx)) - //add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm15) - vmovups(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORED) - - - vbroadcastss(mem(rbx), ymm3) - - vunpcklps(ymm6, ymm4, ymm0) - vunpcklps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - - vunpckhps(ymm6, ymm4, ymm0) - vunpckhps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) - vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - - lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - - vunpcklps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14), xmm1, xmm1) - vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) - vmovhpd(mem(r14, r15, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - - vunpckhps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) - vmovhpd(mem(r14, r13, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(mem(r14, r13, 2), xmm1, xmm1) - vmovhpd(mem(r14, r10, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - - lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - - vunpcklps(ymm7, ymm5, ymm0) - vunpcklps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - - vunpckhps(ymm7, ymm5, ymm0) - vunpckhps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) - vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - - //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - - vunpcklps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14), xmm1, xmm1) - vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) - vmovhpd(mem(r14, r15, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - - vunpckhps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) - vmovhpd(mem(r14, r13, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(mem(r14, r13, 2), xmm1, xmm1) - vmovhpd(mem(r14, r10, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - - //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - - jmp(.SDONE) // jump to end. - - - + + cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm5) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm7) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm9) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm11) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm13) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm15) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14), xmm1, xmm1) + vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) + vmovhpd(mem(r14, r15, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) + vmovhpd(mem(r14, r13, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(r14, r13, 2), xmm1, xmm1) + vmovhpd(mem(r14, r10, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + + lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c + + + + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14), xmm1, xmm1) + vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) + vmovhpd(mem(r14, r15, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) + vmovhpd(mem(r14, r13, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(r14, r13, 2), xmm1, xmm1) + vmovhpd(mem(r14, r10, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c + + jmp(.SDONE) // jump to end. + label(.SBETAZERO) - - cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. - jz(.SROWSTORBZ) // jump to row storage case - - cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. - jz(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - - vmovaps(ymm4, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm6, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm8, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm10, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm12, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm14, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += rs_c; - - - mov(rdx, rcx) // rcx = c + 8*cs_c - - - vmovaps(ymm5, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm7, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm9, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm11, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm13, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm15, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += rs_c; - - - - jmp(.SDONE) // jump to end. - - - - label(.SROWSTORBZ) - - - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vmovups(ymm5, mem(rdx)) - add(rdi, rdx) - - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vmovups(ymm7, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vmovups(ymm9, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vmovups(ymm11, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vmovups(ymm13, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm14, mem(rcx)) - //add(rdi, rcx) - vmovups(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORBZ) - - - vunpcklps(ymm6, ymm4, ymm0) - vunpcklps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - - vunpckhps(ymm6, ymm4, ymm0) - vunpckhps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - - lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - - vunpcklps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - - vunpckhps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - - lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - - vunpcklps(ymm7, ymm5, ymm0) - vunpcklps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - - vunpckhps(ymm7, ymm5, ymm0) - vunpckhps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - - //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - - vunpcklps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - - vunpckhps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - - //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - - - + + cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm14, mem(rcx)) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + + lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c + + + + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_ + label(.SDONE) - - - end_asm( + + vzeroupper() + + + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -894,6 +726,8 @@ void bli_sgemm_haswell_asm_6x16 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } @@ -925,7 +759,9 @@ void bli_sgemm_haswell_asm_6x16 void bli_dgemm_haswell_asm_6x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -940,221 +776,245 @@ void bli_dgemm_haswell_asm_6x8 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_AMBI( d, 6, 8, true ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rbx) - // initialize loop by pre-loading + // initialize loop by pre-loading vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) - - lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; - lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c - - - - + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPREFETCH) // jump to column prefetch case + + lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.DPREFETCHDONE) + + label(.DCOLPREFETCH) + + lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 7*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 7*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 7*cs_c + + label(.DPREFETCHDONE) + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. - - + // contains the k_left loop. + + label(.DLOOPKITER) // MAIN LOOP - - - // iteration 0 + + + // iteration 0 prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, -2*32), ymm0) vmovapd(mem(rbx, -1*32), ymm1) - - // iteration 1 + + // iteration 1 + prefetch(0, mem(rax, 72*8)) + vbroadcastsd(mem(rax, 6*8), ymm2) vbroadcastsd(mem(rax, 7*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 8*8), ymm2) vbroadcastsd(mem(rax, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 10*8), ymm2) vbroadcastsd(mem(rax, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 0*32), ymm0) vmovapd(mem(rbx, 1*32), ymm1) - - // iteration 2 - prefetch(0, mem(rax, 76*8)) - + + // iteration 2 + prefetch(0, mem(rax, 80*8)) + vbroadcastsd(mem(rax, 12*8), ymm2) vbroadcastsd(mem(rax, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 14*8), ymm2) vbroadcastsd(mem(rax, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 16*8), ymm2) vbroadcastsd(mem(rax, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 2*32), ymm0) vmovapd(mem(rbx, 3*32), ymm1) - - // iteration 3 + + // iteration 3 vbroadcastsd(mem(rax, 18*8), ymm2) vbroadcastsd(mem(rax, 19*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 20*8), ymm2) vbroadcastsd(mem(rax, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 22*8), ymm2) vbroadcastsd(mem(rax, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*6*8), rax) // a += 4*6 (unroll x mr) add(imm(4*8*8), rbx) // b += 4*8 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. - - + // else, we prepare to enter k_left loop. + + label(.DLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*6*8), rax) // a += 1*6 (unroll x mr) add(imm(1*8*8), rbx) // b += 1*8 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) @@ -1167,465 +1027,273 @@ void bli_dgemm_haswell_asm_6x8 vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(ymm0, ymm15, ymm15) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; - + lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; //lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c; //lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - - - // now avoid loading C if beta == 0 - + + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.DROWSTORED) // jump to row storage case - - - cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. - jz(.DCOLSTORED) // jump to column storage case - - - - label(.DGENSTORED) - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm4, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm6, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm8, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm10, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm12, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm14, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - - - mov(rdx, rcx) // rcx = c + 4*cs_c - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm5, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm7, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm9, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm11, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm13, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm15, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - - - - jmp(.DDONE) // jump to end. - - - - label(.DROWSTORED) - - - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm5) - vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm7) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm9) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm11) - vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm13) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm14) - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm15) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORED) - - - vunpcklpd(ymm6, ymm4, ymm0) - vunpckhpd(ymm6, ymm4, ymm1) - vunpcklpd(ymm10, ymm8, ymm2) - vunpckhpd(ymm10, ymm8, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm4) - vinsertf128(imm(0x1), xmm3, ymm1, ymm6) - vperm2f128(imm(0x31), ymm2, ymm0, ymm8) - vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - - vbroadcastsd(mem(rbx), ymm3) - - vfmadd231pd(mem(rcx), ymm3, ymm4) - vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) - vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) - vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm10) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm6, mem(rcx, rsi, 1)) - vmovupd(ymm8, mem(rcx, rsi, 2)) - vmovupd(ymm10, mem(rcx, r13, 1)) - - lea(mem(rcx, rsi, 4), rcx) - - vunpcklpd(ymm14, ymm12, ymm0) - vunpckhpd(ymm14, ymm12, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vfmadd231pd(mem(r14), xmm3, xmm0) - vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - lea(mem(r14, rsi, 4), r14) - - - vunpcklpd(ymm7, ymm5, ymm0) - vunpckhpd(ymm7, ymm5, ymm1) - vunpcklpd(ymm11, ymm9, ymm2) - vunpckhpd(ymm11, ymm9, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm5) - vinsertf128(imm(0x1), xmm3, ymm1, ymm7) - vperm2f128(imm(0x31), ymm2, ymm0, ymm9) - vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - - vbroadcastsd(mem(rbx), ymm3) - - vfmadd231pd(mem(rcx), ymm3, ymm5) - vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) - vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) - vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm11) - vmovupd(ymm5, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 1)) - vmovupd(ymm9, mem(rcx, rsi, 2)) - vmovupd(ymm11, mem(rcx, r13, 1)) - - //lea(mem(rcx, rsi, 4), rcx) - - vunpcklpd(ymm15, ymm13, ymm0) - vunpckhpd(ymm15, ymm13, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vfmadd231pd(mem(r14), xmm3, xmm0) - vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - //lea(mem(r14, rsi, 4), r14) - - - - jmp(.DDONE) // jump to end. - - - + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) + + jmp(.DDONE) // jump to end. + + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, r13, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(r14), xmm3, xmm0) + vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) + + lea(mem(r14, rsi, 4), r14) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, r13, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(r14), xmm3, xmm0) + vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) + + //lea(mem(r14, rsi, 4), r14) + + jmp(.DDONE) // jump to end. + label(.DBETAZERO) - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.DROWSTORBZ) // jump to row storage case - - cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. - jz(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - - - vmovapd(ymm4, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm6, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm8, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm10, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm12, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm14, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - - - mov(rdx, rcx) // rcx = c + 4*cs_c - - - vmovapd(ymm5, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm7, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm9, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm11, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm13, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm15, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - - - - jmp(.DDONE) // jump to end. - - - - label(.DROWSTORBZ) - - - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) - - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) - - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORBZ) - - - vunpcklpd(ymm6, ymm4, ymm0) - vunpckhpd(ymm6, ymm4, ymm1) - vunpcklpd(ymm10, ymm8, ymm2) - vunpckhpd(ymm10, ymm8, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm4) - vinsertf128(imm(0x1), xmm3, ymm1, ymm6) - vperm2f128(imm(0x31), ymm2, ymm0, ymm8) - vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm6, mem(rcx, rsi, 1)) - vmovupd(ymm8, mem(rcx, rsi, 2)) - vmovupd(ymm10, mem(rcx, r13, 1)) - - lea(mem(rcx, rsi, 4), rcx) - - vunpcklpd(ymm14, ymm12, ymm0) - vunpckhpd(ymm14, ymm12, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - lea(mem(r14, rsi, 4), r14) - - - vunpcklpd(ymm7, ymm5, ymm0) - vunpckhpd(ymm7, ymm5, ymm1) - vunpcklpd(ymm11, ymm9, ymm2) - vunpckhpd(ymm11, ymm9, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm5) - vinsertf128(imm(0x1), xmm3, ymm1, ymm7) - vperm2f128(imm(0x31), ymm2, ymm0, ymm9) - vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - - vmovupd(ymm5, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 1)) - vmovupd(ymm9, mem(rcx, rsi, 2)) - vmovupd(ymm11, mem(rcx, r13, 1)) - - //lea(mem(rcx, rsi, 4), rcx) - - vunpcklpd(ymm15, ymm13, ymm0) - vunpckhpd(ymm15, ymm13, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - //lea(mem(r14, rsi, 4), r14) - - - + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx)) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) + + jmp(.DDONE) // jump to end. + + label(.DCOLSTORBZ) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, r13, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) + + lea(mem(r14, rsi, 4), r14) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, r13, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) + + //lea(mem(r14, rsi, 4), r14) + label(.DDONE) - - - end_asm( + + vzeroupper() + + + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1633,45 +1301,26 @@ void bli_dgemm_haswell_asm_6x8 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) -} + GEMM_UKR_FLUSH_CT( d ); +} -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. -// outputs to ymm0 -#define CGEMM_INPUT_SCALE_GS_BETA_NZ \ - vmovlpd(mem(rcx), xmm0, xmm0) \ - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) \ - vmovlpd(mem(rcx, rsi, 2), xmm3, xmm3) \ - vmovhpd(mem(rcx, r13, 1), xmm3, xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ - vpermilps(imm(0xb1), ymm0, ymm3) \ - vmulps(ymm1, ymm0, ymm0) \ - vmulps(ymm2, ymm3, ymm3) \ - vaddsubps(ymm3, ymm0, ymm0) -// assumes values to output are in ymm0 -#define CGEMM_OUTPUT_GS \ - vextractf128(imm(1), ymm0, xmm3) \ - vmovlpd(xmm0, mem(rcx)) \ - vmovhpd(xmm0, mem(rcx, rsi, 1)) \ - vmovlpd(xmm3, mem(rcx, rsi, 2)) \ - vmovhpd(xmm3, mem(rcx, r13, 1)) -#define CGEMM_INPUT_SCALE_RS_BETA_NZ \ - vmovups(mem(rcx), ymm0) \ +#define CGEMM_INPUT_SCALE_RS_BETA_NZ(where) \ + vmovups(where, ymm0) \ vpermilps(imm(0xb1), ymm0, ymm3) \ vmulps(ymm1, ymm0, ymm0) \ vmulps(ymm2, ymm3, ymm3) \ vaddsubps(ymm3, ymm0, ymm0) - -#define CGEMM_OUTPUT_RS \ - vmovups(ymm0, mem(rcx)) \ void bli_cgemm_haswell_asm_3x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1686,291 +1335,284 @@ void bli_cgemm_haswell_asm_3x8 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( c, 3, 8, true ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rbx) - // initialize loop by pre-loading + // initialize loop by pre-loading vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(scomplex) - + lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*rs_c; lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*rs_c; - + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c prefetch(0, mem(r11, 7*8)) // prefetch c + 1*rs_c prefetch(0, mem(r12, 7*8)) // prefetch c + 2*rs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. - - + // contains the k_left loop. + + label(.CLOOPKITER) // MAIN LOOP - - - // iteration 0 + + + // iteration 0 prefetch(0, mem(rax, 32*8)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, -2*32), ymm0) vmovaps(mem(rbx, -1*32), ymm1) - - // iteration 1 + + // iteration 1 vbroadcastss(mem(rax, 6*4), ymm2) vbroadcastss(mem(rax, 7*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 8*4), ymm2) vbroadcastss(mem(rax, 9*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 10*4), ymm2) vbroadcastss(mem(rax, 11*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 0*32), ymm0) vmovaps(mem(rbx, 1*32), ymm1) - - // iteration 2 + + // iteration 2 prefetch(0, mem(rax, 38*8)) - + vbroadcastss(mem(rax, 12*4), ymm2) vbroadcastss(mem(rax, 13*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 14*4), ymm2) vbroadcastss(mem(rax, 15*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 16*4), ymm2) vbroadcastss(mem(rax, 17*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 2*32), ymm0) vmovaps(mem(rbx, 3*32), ymm1) - - // iteration 3 + + // iteration 3 vbroadcastss(mem(rax, 18*4), ymm2) vbroadcastss(mem(rax, 19*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 20*4), ymm2) vbroadcastss(mem(rax, 21*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 22*4), ymm2) vbroadcastss(mem(rax, 23*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(4*3*8), rax) // a += 4*3 (unroll x mr) add(imm(4*8*8), rbx) // b += 4*8 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.CLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. - - + // else, we prepare to enter k_left loop. + + label(.CLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 32*8)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(1*3*8), rax) // a += 1*3 (unroll x mr) add(imm(1*8*8), rbx) // b += 1*8 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.CLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.CPOSTACCUM) - - - // permute even and odd elements - // of ymm6/7, ymm10/11, ymm/14/15 + + + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 vpermilps(imm(0xb1), ymm6, ymm6) vpermilps(imm(0xb1), ymm7, ymm7) vpermilps(imm(0xb1), ymm10, ymm10) vpermilps(imm(0xb1), ymm11, ymm11) vpermilps(imm(0xb1), ymm14, ymm14) vpermilps(imm(0xb1), ymm15, ymm15) - - - // subtract/add even/odd elements + + + // subtract/add even/odd elements vaddsubps(ymm6, ymm4, ymm4) vaddsubps(ymm7, ymm5, ymm5) - + vaddsubps(ymm10, ymm8, ymm8) vaddsubps(ymm11, ymm9, ymm9) - + vaddsubps(ymm14, ymm12, ymm12) vaddsubps(ymm15, ymm13, ymm13) - - - - + + + + mov(var(alpha), rax) // load address of alpha vbroadcastss(mem(rax), ymm0) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), ymm1) // load alpha_i and duplicate - - + + vpermilps(imm(0xb1), ymm4, ymm3) vmulps(ymm0, ymm4, ymm4) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm4, ymm4) - + vpermilps(imm(0xb1), ymm5, ymm3) vmulps(ymm0, ymm5, ymm5) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm5, ymm5) - - + + vpermilps(imm(0xb1), ymm8, ymm3) vmulps(ymm0, ymm8, ymm8) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm9, ymm3) vmulps(ymm0, ymm9, ymm9) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm9, ymm9) - - + + vpermilps(imm(0xb1), ymm12, ymm3) vmulps(ymm0, ymm12, ymm12) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm12, ymm12) - + vpermilps(imm(0xb1), ymm13, ymm3) vmulps(ymm0, ymm13, ymm13) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm13, ymm13) - - - - - + + + + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate - - - - - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(scomplex) - lea(mem(, rsi, 4), rdx) // rdx = 4*cs_c; - lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; - - - - // now avoid loading C if beta == 0 + + + // now avoid loading C if beta == 0 vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -1978,186 +1620,74 @@ void bli_cgemm_haswell_asm_3x8 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.CROWSTORED) // jump to row storage case - - - - label(.CGENSTORED) - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm4, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm5, ymm0, ymm0) - CGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm8, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm9, ymm0, ymm0) - CGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*rs_c - - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm12, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm13, ymm0, ymm0) - CGEMM_OUTPUT_GS - - - - jmp(.CDONE) // jump to end. - - - - label(.CROWSTORED) - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm4, ymm0, ymm0) - CGEMM_OUTPUT_RS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm5, ymm0, ymm0) - CGEMM_OUTPUT_RS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm8, ymm0, ymm0) - CGEMM_OUTPUT_RS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm9, ymm0, ymm0) - CGEMM_OUTPUT_RS - mov(r12, rcx) // rcx = c + 2*rs_c - - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm12, ymm0, ymm0) - CGEMM_OUTPUT_RS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm13, ymm0, ymm0) - CGEMM_OUTPUT_RS - - - - jmp(.CDONE) // jump to end. - - - + + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx)) + vaddps(ymm4, ymm0, ymm0) + vmovups(ymm0, mem(rcx)) + + + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx,32)) + vaddps(ymm5, ymm0, ymm0) + vmovups(ymm0, mem(rcx,32)) + + + + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11)) + vaddps(ymm8, ymm0, ymm0) + vmovups(ymm0, mem(r11)) + + + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11,32)) + vaddps(ymm9, ymm0, ymm0) + vmovups(ymm0, mem(r11,32)) + + + + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12)) + vaddps(ymm12, ymm0, ymm0) + vmovups(ymm0, mem(r12)) + + + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12,32)) + vaddps(ymm13, ymm0, ymm0) + vmovups(ymm0, mem(r12,32)) + + jmp(.CDONE) // jump to end. + label(.CBETAZERO) - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.CROWSTORBZ) // jump to row storage case - - - - label(.CGENSTORBZ) - - - vmovaps(ymm4, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovaps(ymm5, ymm0) - CGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - vmovaps(ymm8, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovaps(ymm9, ymm0) - CGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*rs_c - - - - vmovaps(ymm12, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovaps(ymm13, ymm0) - CGEMM_OUTPUT_GS - - - - jmp(.CDONE) // jump to end. - - - - label(.CROWSTORBZ) - - - vmovups(ymm4, mem(rcx)) - vmovups(ymm5, mem(rcx, rdx, 1)) - - vmovups(ymm8, mem(r11)) - vmovups(ymm9, mem(r11, rdx, 1)) - - vmovups(ymm12, mem(r12)) - vmovups(ymm13, mem(r12, rdx, 1)) - - - - - - + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + + vmovups(ymm8, mem(r11)) + vmovups(ymm9, mem(r11,32)) + + vmovups(ymm12, mem(r12)) + vmovups(ymm13, mem(r12,32)) + label(.CDONE) - - - end_asm( + + vzeroupper() + + + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2165,41 +1695,25 @@ void bli_cgemm_haswell_asm_3x8 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. -// outputs to ymm0 -#define ZGEMM_INPUT_SCALE_GS_BETA_NZ \ - vmovupd(mem(rcx), xmm0) \ - vmovupd(mem(rcx, rsi, 1), xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ +#define ZGEMM_INPUT_SCALE_RS_BETA_NZ(where) \ + vmovupd(where, ymm0) \ vpermilpd(imm(0x5), ymm0, ymm3) \ vmulpd(ymm1, ymm0, ymm0) \ vmulpd(ymm2, ymm3, ymm3) \ vaddsubpd(ymm3, ymm0, ymm0) - -// assumes values to output are in ymm0 -#define ZGEMM_OUTPUT_GS \ - vextractf128(imm(1), ymm0, xmm3) \ - vmovupd(xmm0, mem(rcx)) \ - vmovupd(xmm3, mem(rcx, rsi, 1)) \ - -#define ZGEMM_INPUT_SCALE_RS_BETA_NZ \ - vmovupd(mem(rcx), ymm0) \ - vpermilpd(imm(0x5), ymm0, ymm3) \ - vmulpd(ymm1, ymm0, ymm0) \ - vmulpd(ymm2, ymm3, ymm3) \ - vaddsubpd(ymm3, ymm0, ymm0) - -#define ZGEMM_OUTPUT_RS \ - vmovupd(ymm0, mem(rcx)) \ void bli_zgemm_haswell_asm_3x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -2214,291 +1728,287 @@ void bli_zgemm_haswell_asm_3x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( z, 3, 4, true ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rbx) - // initialize loop by pre-loading + // initialize loop by pre-loading vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) - + lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*rs_c; lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*rs_c; - + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c prefetch(0, mem(r11, 7*8)) // prefetch c + 1*rs_c prefetch(0, mem(r12, 7*8)) // prefetch c + 2*rs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. - - + // contains the k_left loop. + + label(.ZLOOPKITER) // MAIN LOOP - - - // iteration 0 + + + // iteration 0 prefetch(0, mem(rax, 32*16)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, -2*32), ymm0) vmovapd(mem(rbx, -1*32), ymm1) - - // iteration 1 + + // iteration 1 + prefetch(0, mem(rax, 36*16)) + vbroadcastsd(mem(rax, 6*8), ymm2) vbroadcastsd(mem(rax, 7*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 8*8), ymm2) vbroadcastsd(mem(rax, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 10*8), ymm2) vbroadcastsd(mem(rax, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 0*32), ymm0) vmovapd(mem(rbx, 1*32), ymm1) - - // iteration 2 - prefetch(0, mem(rax, 38*16)) - + + // iteration 2 + prefetch(0, mem(rax, 40*16)) + vbroadcastsd(mem(rax, 12*8), ymm2) vbroadcastsd(mem(rax, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 14*8), ymm2) vbroadcastsd(mem(rax, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 16*8), ymm2) vbroadcastsd(mem(rax, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 2*32), ymm0) vmovapd(mem(rbx, 3*32), ymm1) - - // iteration 3 + + // iteration 3 vbroadcastsd(mem(rax, 18*8), ymm2) vbroadcastsd(mem(rax, 19*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 20*8), ymm2) vbroadcastsd(mem(rax, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 22*8), ymm2) vbroadcastsd(mem(rax, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*3*16), rax) // a += 4*3 (unroll x mr) add(imm(4*4*16), rbx) // b += 4*4 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. - - + // else, we prepare to enter k_left loop. + + label(.ZLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 32*16)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*3*16), rax) // a += 1*3 (unroll x mr) add(imm(1*4*16), rbx) // b += 1*4 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.ZPOSTACCUM) - - // permute even and odd elements - // of ymm6/7, ymm10/11, ymm/14/15 + + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 vpermilpd(imm(0x5), ymm6, ymm6) vpermilpd(imm(0x5), ymm7, ymm7) vpermilpd(imm(0x5), ymm10, ymm10) vpermilpd(imm(0x5), ymm11, ymm11) vpermilpd(imm(0x5), ymm14, ymm14) vpermilpd(imm(0x5), ymm15, ymm15) - - - // subtract/add even/odd elements + + + // subtract/add even/odd elements vaddsubpd(ymm6, ymm4, ymm4) vaddsubpd(ymm7, ymm5, ymm5) - + vaddsubpd(ymm10, ymm8, ymm8) vaddsubpd(ymm11, ymm9, ymm9) - + vaddsubpd(ymm14, ymm12, ymm12) vaddsubpd(ymm15, ymm13, ymm13) - - - - + + + + mov(var(alpha), rax) // load address of alpha vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate - - + + vpermilpd(imm(0x5), ymm4, ymm3) vmulpd(ymm0, ymm4, ymm4) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm4, ymm4) - + vpermilpd(imm(0x5), ymm5, ymm3) vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm5, ymm5) - - + + vpermilpd(imm(0x5), ymm8, ymm3) vmulpd(ymm0, ymm8, ymm8) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm9, ymm3) vmulpd(ymm0, ymm9, ymm9) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm9, ymm9) - - + + vpermilpd(imm(0x5), ymm12, ymm3) vmulpd(ymm0, ymm12, ymm12) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm12, ymm12) - + vpermilpd(imm(0x5), ymm13, ymm3) vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm13, ymm13) - - - - - + + + + + mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - - - - - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dcomplex) - lea(mem(, rsi, 2), rsi) - lea(mem(, rsi, 2), rdx) // rdx = 2*cs_c; - - - - // now avoid loading C if beta == 0 + + + + // now avoid loading C if beta == 0 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2506,186 +2016,74 @@ void bli_zgemm_haswell_asm_3x4 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16. - jz(.ZROWSTORED) // jump to row storage case - - - - label(.ZGENSTORED) - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*rs_c - - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_GS - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZROWSTORED) - - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_RS - add(rdx, rcx) // c += 2*cs_c; - - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_RS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_RS - add(rdx, rcx) // c += 2*cs_c; - - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_RS - mov(r12, rcx) // rcx = c + 2*rs_c - - - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_RS - add(rdx, rcx) // c += 2*cs_c; - - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - - - jmp(.ZDONE) // jump to end. - - - + + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx)) + vaddpd(ymm4, ymm0, ymm0) + vmovupd(ymm0, mem(rcx)) + + + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx,32)) + vaddpd(ymm5, ymm0, ymm0) + vmovupd(ymm0, mem(rcx,32)) + + + + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11)) + vaddpd(ymm8, ymm0, ymm0) + vmovupd(ymm0, mem(r11)) + + + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11,32)) + vaddpd(ymm9, ymm0, ymm0) + vmovupd(ymm0, mem(r11,32)) + + + + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12)) + vaddpd(ymm12, ymm0, ymm0) + vmovupd(ymm0, mem(r12)) + + + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12,32)) + vaddpd(ymm13, ymm0, ymm0) + vmovupd(ymm0, mem(r12,32)) + + jmp(.ZDONE) // jump to end. + label(.ZBETAZERO) - - cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16. - jz(.ZROWSTORBZ) // jump to row storage case - - - - label(.ZGENSTORBZ) - - - vmovapd(ymm4, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovapd(ymm5, ymm0) - ZGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - vmovapd(ymm8, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovapd(ymm9, ymm0) - ZGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*rs_c - - - - vmovapd(ymm12, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovapd(ymm13, ymm0) - ZGEMM_OUTPUT_GS - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZROWSTORBZ) - - - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rdx, 1)) - - vmovupd(ymm8, mem(r11)) - vmovupd(ymm9, mem(r11, rdx, 1)) - - vmovupd(ymm12, mem(r12)) - vmovupd(ymm13, mem(r12, rdx, 1)) - - - - - - + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + + vmovupd(ymm8, mem(r11)) + vmovupd(ymm9, mem(r11,32)) + + vmovupd(ymm12, mem(r12)) + vmovupd(ymm13, mem(r12,32)) + label(.ZDONE) - - - end_asm( + + vzeroupper() + + + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2693,6 +2091,8 @@ void bli_zgemm_haswell_asm_3x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c b/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c index b074da965c..dd9526d566 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c @@ -78,7 +78,9 @@ void bli_sgemm_haswell_asm_16x6 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -93,29 +95,31 @@ void bli_sgemm_haswell_asm_16x6 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( s, 16, 6, false ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rax) // initialize loop by pre-loading vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) - + lea(mem(rdi, rdi, 2), r13) // r13 = 3*cs_c; lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c @@ -124,46 +128,46 @@ void bli_sgemm_haswell_asm_16x6 prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*cs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.SLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 128*4)) - + vbroadcastss(mem(rbx, 0*4), ymm2) vbroadcastss(mem(rbx, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 2*4), ymm2) vbroadcastss(mem(rbx, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 4*4), ymm2) vbroadcastss(mem(rbx, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, -2*32), ymm0) vmovaps(mem(rax, -1*32), ymm1) - + // iteration 1 vbroadcastss(mem(rbx, 6*4), ymm2) vbroadcastss(mem(rbx, 7*4), ymm3) @@ -171,51 +175,51 @@ void bli_sgemm_haswell_asm_16x6 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 8*4), ymm2) vbroadcastss(mem(rbx, 9*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 10*4), ymm2) vbroadcastss(mem(rbx, 11*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, 0*32), ymm0) vmovaps(mem(rax, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 152*4)) - + vbroadcastss(mem(rbx, 12*4), ymm2) vbroadcastss(mem(rbx, 13*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 14*4), ymm2) vbroadcastss(mem(rbx, 15*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 16*4), ymm2) vbroadcastss(mem(rbx, 17*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, 2*32), ymm0) vmovaps(mem(rax, 3*32), ymm1) - + // iteration 3 vbroadcastss(mem(rbx, 18*4), ymm2) vbroadcastss(mem(rbx, 19*4), ymm3) @@ -223,91 +227,91 @@ void bli_sgemm_haswell_asm_16x6 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 20*4), ymm2) vbroadcastss(mem(rbx, 21*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 22*4), ymm2) vbroadcastss(mem(rbx, 23*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(4*16*4), rax) // a += 4*16 (unroll x mr) add(imm(4*6*4), rbx) // b += 4*6 (unroll x nr) - + vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.SLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 128*4)) - + vbroadcastss(mem(rbx, 0*4), ymm2) vbroadcastss(mem(rbx, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 2*4), ymm2) vbroadcastss(mem(rbx, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 4*4), ymm2) vbroadcastss(mem(rbx, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(1*16*4), rax) // a += 1*16 (unroll x mr) add(imm(1*6*4), rbx) // b += 1*6 (unroll x nr) - + vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.SPOSTACCUM) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), ymm0) // load alpha and duplicate vbroadcastss(mem(rbx), ymm3) // load beta and duplicate - + vmulps(ymm0, ymm4, ymm4) // scale by alpha vmulps(ymm0, ymm5, ymm5) vmulps(ymm0, ymm6, ymm6) @@ -320,315 +324,107 @@ void bli_sgemm_haswell_asm_16x6 vmulps(ymm0, ymm13, ymm13) vmulps(ymm0, ymm14, ymm14) vmulps(ymm0, ymm15, ymm15) - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) - - lea(mem(rcx, rsi, 8), rdx) // load address of c + 8*rs_c; - - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - lea(mem(rsi, rsi, 4), r15) // r15 = 5*rs_c; - lea(mem(r13, rsi, 4), r10) // r10 = 7*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm3) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4. - jz(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm4, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm6, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm8, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm10, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm12, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm14, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - mov(rdx, rcx) // rcx = c + 8*rs_c - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm5, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm7, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm9, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm11, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm13, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm15, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORED) - - - vfmadd231ps(mem(rcx), ymm3, ymm4) - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm5) - vmovups(ymm5, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm6) - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm7) - vmovups(ymm7, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm8) - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm9) - vmovups(ymm9, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm10) - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm11) - vmovups(ymm11, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm12) - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm13) - vmovups(ymm13, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm14) - vmovups(ymm14, mem(rcx)) - //add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm15) - vmovups(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm5) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm7) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm9) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm11) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm13) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm15) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) + jmp(.SDONE) // jump to end. - - - + label(.SBETAZERO) - - cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4. - jz(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - - vmovaps(ymm4, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm6, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm8, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm10, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm12, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm14, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - mov(rdx, rcx) // rcx = c + 8*rs_c - - - vmovaps(ymm5, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm7, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm9, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm11, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm13, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm15, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORBZ) - - - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vmovups(ymm5, mem(rdx)) - add(rdi, rdx) - - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vmovups(ymm7, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vmovups(ymm9, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vmovups(ymm11, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vmovups(ymm13, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm14, mem(rcx)) - //add(rdi, rcx) - vmovups(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - - - - + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm14, mem(rcx)) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) + label(.SDONE) - - + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -636,6 +432,8 @@ void bli_sgemm_haswell_asm_16x6 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } #define DGEMM_INPUT_GS_BETA_NZ \ @@ -664,7 +462,9 @@ void bli_sgemm_haswell_asm_16x6 void bli_dgemm_haswell_asm_8x6 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -679,29 +479,31 @@ void bli_dgemm_haswell_asm_8x6 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( d, 8, 6, false ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rax) // initialize loop by pre-loading vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) - + lea(mem(rdi, rdi, 2), r13) // r13 = 3*cs_c; lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c @@ -710,46 +512,46 @@ void bli_dgemm_haswell_asm_8x6 prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*cs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rbx, 0*8), ymm2) vbroadcastsd(mem(rbx, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 2*8), ymm2) vbroadcastsd(mem(rbx, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 4*8), ymm2) vbroadcastsd(mem(rbx, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, -2*32), ymm0) vmovapd(mem(rax, -1*32), ymm1) - + // iteration 1 vbroadcastsd(mem(rbx, 6*8), ymm2) vbroadcastsd(mem(rbx, 7*8), ymm3) @@ -757,51 +559,51 @@ void bli_dgemm_haswell_asm_8x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 8*8), ymm2) vbroadcastsd(mem(rbx, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 10*8), ymm2) vbroadcastsd(mem(rbx, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, 0*32), ymm0) vmovapd(mem(rax, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 76*8)) - + vbroadcastsd(mem(rbx, 12*8), ymm2) vbroadcastsd(mem(rbx, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 14*8), ymm2) vbroadcastsd(mem(rbx, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 16*8), ymm2) vbroadcastsd(mem(rbx, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, 2*32), ymm0) vmovapd(mem(rax, 3*32), ymm1) - + // iteration 3 vbroadcastsd(mem(rbx, 18*8), ymm2) vbroadcastsd(mem(rbx, 19*8), ymm3) @@ -809,91 +611,91 @@ void bli_dgemm_haswell_asm_8x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 20*8), ymm2) vbroadcastsd(mem(rbx, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 22*8), ymm2) vbroadcastsd(mem(rbx, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*8*8), rax) // a += 4*8 (unroll x mr) add(imm(4*6*8), rbx) // b += 4*6 (unroll x nr) - + vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rbx, 0*8), ymm2) vbroadcastsd(mem(rbx, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 2*8), ymm2) vbroadcastsd(mem(rbx, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 4*8), ymm2) vbroadcastsd(mem(rbx, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*8*8), rax) // a += 1*8 (unroll x mr) add(imm(1*6*8), rbx) // b += 1*6 (unroll x nr) - + vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) @@ -906,314 +708,107 @@ void bli_dgemm_haswell_asm_8x6 vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(ymm0, ymm15, ymm15) - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) - - lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - //lea(mem(rsi, rsi, 4), r15) // r15 = 5*rs_c; - //lea(mem(r13, rsi, 4), r10) // r10 = 7*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - jz(.DCOLSTORED) // jump to column storage case - - - - label(.DGENSTORED) - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm4, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm6, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm8, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm10, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm12, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm14, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - mov(rdx, rcx) // rcx = c + 4*rs_c - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm5, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm7, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm9, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm11, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm13, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm15, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORED) - - - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm5) - vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm7) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm9) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm11) - vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm13) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm14) - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm15) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - jmp(.DDONE) // jump to end. - - - + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) + + jmp(.DDONE) // jump to end. + label(.DBETAZERO) - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - jz(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - - - vmovapd(ymm4, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm6, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm8, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm10, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm12, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm14, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - mov(rdx, rcx) // rcx = c + 4*rs_c - - - vmovapd(ymm5, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm7, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm9, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm11, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm13, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm15, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORBZ) - - - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) - - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - - - - + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx)) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) + label(.DDONE) - - - end_asm( + + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1221,45 +816,25 @@ void bli_dgemm_haswell_asm_8x6 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. -// outputs to ymm0 -#define CGEMM_INPUT_SCALE_GS_BETA_NZ \ - vmovlpd(mem(rcx), xmm0, xmm0) \ - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) \ - vmovlpd(mem(rcx, rsi, 2), xmm3, xmm3) \ - vmovhpd(mem(rcx, r13, 1), xmm3, xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ +#define CGEMM_INPUT_SCALE_CS_BETA_NZ(where) \ + vmovups(where, ymm0) \ vpermilps(imm(0xb1), ymm0, ymm3) \ vmulps(ymm1, ymm0, ymm0) \ vmulps(ymm2, ymm3, ymm3) \ vaddsubps(ymm3, ymm0, ymm0) -// assumes values to output are in ymm0 -#define CGEMM_OUTPUT_GS \ - vextractf128(imm(1), ymm0, xmm3) \ - vmovlpd(xmm0, mem(rcx)) \ - vmovhpd(xmm0, mem(rcx, rsi, 1)) \ - vmovlpd(xmm3, mem(rcx, rsi, 2)) \ - vmovhpd(xmm3, mem(rcx, r13, 1)) - -#define CGEMM_INPUT_SCALE_CS_BETA_NZ \ - vmovups(mem(rcx), ymm0) \ - vpermilps(imm(0xb1), ymm0, ymm3) \ - vmulps(ymm1, ymm0, ymm0) \ - vmulps(ymm2, ymm3, ymm3) \ - vaddsubps(ymm3, ymm0, ymm0) - -#define CGEMM_OUTPUT_CS \ - vmovups(ymm0, mem(rcx)) \ - void bli_cgemm_haswell_asm_8x3 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1274,75 +849,77 @@ void bli_cgemm_haswell_asm_8x3 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( c, 8, 3, false ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rax) // initialize loop by pre-loading vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(scomplex) - + lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*cs_c; lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*cs_c; - + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c prefetch(0, mem(r11, 7*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, 7*8)) // prefetch c + 2*cs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.CLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 32*8)) - + vbroadcastss(mem(rbx, 0*4), ymm2) vbroadcastss(mem(rbx, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 2*4), ymm2) vbroadcastss(mem(rbx, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 4*4), ymm2) vbroadcastss(mem(rbx, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, -2*32), ymm0) vmovaps(mem(rax, -1*32), ymm1) - + // iteration 1 vbroadcastss(mem(rbx, 6*4), ymm2) vbroadcastss(mem(rbx, 7*4), ymm3) @@ -1350,51 +927,51 @@ void bli_cgemm_haswell_asm_8x3 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 8*4), ymm2) vbroadcastss(mem(rbx, 9*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 10*4), ymm2) vbroadcastss(mem(rbx, 11*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, 0*32), ymm0) vmovaps(mem(rax, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 38*8)) - + vbroadcastss(mem(rbx, 12*4), ymm2) vbroadcastss(mem(rbx, 13*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 14*4), ymm2) vbroadcastss(mem(rbx, 15*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 16*4), ymm2) vbroadcastss(mem(rbx, 17*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, 2*32), ymm0) vmovaps(mem(rax, 3*32), ymm1) - + // iteration 3 vbroadcastss(mem(rbx, 18*4), ymm2) vbroadcastss(mem(rbx, 19*4), ymm3) @@ -1402,84 +979,84 @@ void bli_cgemm_haswell_asm_8x3 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 20*4), ymm2) vbroadcastss(mem(rbx, 21*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 22*4), ymm2) vbroadcastss(mem(rbx, 23*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(4*8*8), rax) // a += 4*8 (unroll x mr) add(imm(4*3*8), rbx) // b += 4*3 (unroll x nr) - + vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.CLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.CLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 32*8)) - + vbroadcastss(mem(rbx, 0*4), ymm2) vbroadcastss(mem(rbx, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 2*4), ymm2) vbroadcastss(mem(rbx, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 4*4), ymm2) vbroadcastss(mem(rbx, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(1*8*8), rax) // a += 1*8 (unroll x mr) add(imm(1*3*8), rbx) // b += 1*3 (unroll x nr) - + vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.CLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.CPOSTACCUM) - - + + // permute even and odd elements // of ymm6/7, ymm10/11, ymm/14/15 vpermilps(imm(0xb1), ymm6, ymm6) @@ -1488,76 +1065,68 @@ void bli_cgemm_haswell_asm_8x3 vpermilps(imm(0xb1), ymm11, ymm11) vpermilps(imm(0xb1), ymm14, ymm14) vpermilps(imm(0xb1), ymm15, ymm15) - - + + // subtract/add even/odd elements vaddsubps(ymm6, ymm4, ymm4) vaddsubps(ymm7, ymm5, ymm5) - + vaddsubps(ymm10, ymm8, ymm8) vaddsubps(ymm11, ymm9, ymm9) - + vaddsubps(ymm14, ymm12, ymm12) vaddsubps(ymm15, ymm13, ymm13) - - - - + + + + mov(var(alpha), rax) // load address of alpha vbroadcastss(mem(rax), ymm0) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), ymm1) // load alpha_i and duplicate - - + + vpermilps(imm(0xb1), ymm4, ymm3) vmulps(ymm0, ymm4, ymm4) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm4, ymm4) - + vpermilps(imm(0xb1), ymm5, ymm3) vmulps(ymm0, ymm5, ymm5) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm5, ymm5) - - + + vpermilps(imm(0xb1), ymm8, ymm3) vmulps(ymm0, ymm8, ymm8) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm9, ymm3) vmulps(ymm0, ymm9, ymm9) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm9, ymm9) - - + + vpermilps(imm(0xb1), ymm12, ymm3) vmulps(ymm0, ymm12, ymm12) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm12, ymm12) - + vpermilps(imm(0xb1), ymm13, ymm3) vmulps(ymm0, ymm13, ymm13) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm13, ymm13) - - - - - + + + + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(scomplex) - lea(mem(, rsi, 4), rdx) // rdx = 4*rs_c; - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - - - + + + // now avoid loading C if beta == 0 vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. @@ -1566,186 +1135,71 @@ void bli_cgemm_haswell_asm_8x3 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.CCOLSTORED) // jump to row storage case - - - - label(.CGENSTORED) - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm4, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm5, ymm0, ymm0) - CGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm8, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm9, ymm0, ymm0) - CGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm12, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm13, ymm0, ymm0) - CGEMM_OUTPUT_GS - - - - jmp(.CDONE) // jump to end. - - - - label(.CCOLSTORED) - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm4, ymm0, ymm0) - CGEMM_OUTPUT_CS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm5, ymm0, ymm0) - CGEMM_OUTPUT_CS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm8, ymm0, ymm0) - CGEMM_OUTPUT_CS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm9, ymm0, ymm0) - CGEMM_OUTPUT_CS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm12, ymm0, ymm0) - CGEMM_OUTPUT_CS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm13, ymm0, ymm0) - CGEMM_OUTPUT_CS - - - - jmp(.CDONE) // jump to end. - - - + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx)) + vaddps(ymm4, ymm0, ymm0) + vmovups(ymm0, mem(rcx)) + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx,32)) + vaddps(ymm5, ymm0, ymm0) + vmovups(ymm0, mem(rcx,32)) + + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11)) + vaddps(ymm8, ymm0, ymm0) + vmovups(ymm0, mem(r11)) + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11,32)) + vaddps(ymm9, ymm0, ymm0) + vmovups(ymm0, mem(r11,32)) + + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12)) + vaddps(ymm12, ymm0, ymm0) + vmovups(ymm0, mem(r12)) + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12,32)) + vaddps(ymm13, ymm0, ymm0) + vmovups(ymm0, mem(r12,32)) + + jmp(.CDONE) // jump to end. + label(.CBETAZERO) - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - jz(.CCOLSTORBZ) // jump to row storage case - - - - label(.CGENSTORBZ) - - - vmovaps(ymm4, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - vmovaps(ymm5, ymm0) - CGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - vmovaps(ymm8, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - vmovaps(ymm9, ymm0) - CGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - vmovaps(ymm12, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - vmovaps(ymm13, ymm0) - CGEMM_OUTPUT_GS - - - - jmp(.CDONE) // jump to end. - - - - label(.CCOLSTORBZ) - - - vmovups(ymm4, mem(rcx)) - vmovups(ymm5, mem(rcx, rdx, 1)) - - vmovups(ymm8, mem(r11)) - vmovups(ymm9, mem(r11, rdx, 1)) - - vmovups(ymm12, mem(r12)) - vmovups(ymm13, mem(r12, rdx, 1)) - - - - - - + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + + vmovups(ymm8, mem(r11)) + vmovups(ymm9, mem(r11,32)) + + vmovups(ymm12, mem(r12)) + vmovups(ymm13, mem(r12,32)) + label(.CDONE) - - - end_asm( + + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1753,41 +1207,25 @@ void bli_cgemm_haswell_asm_8x3 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. -// outputs to ymm0 -#define ZGEMM_INPUT_SCALE_GS_BETA_NZ \ - vmovupd(mem(rcx), xmm0) \ - vmovupd(mem(rcx, rsi, 1), xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ +#define ZGEMM_INPUT_SCALE_CS_BETA_NZ(where) \ + vmovups(where, ymm0) \ vpermilpd(imm(0x5), ymm0, ymm3) \ vmulpd(ymm1, ymm0, ymm0) \ vmulpd(ymm2, ymm3, ymm3) \ vaddsubpd(ymm3, ymm0, ymm0) - -// assumes values to output are in ymm0 -#define ZGEMM_OUTPUT_GS \ - vextractf128(imm(1), ymm0, xmm3) \ - vmovupd(xmm0, mem(rcx)) \ - vmovupd(xmm3, mem(rcx, rsi, 1)) \ - -#define ZGEMM_INPUT_SCALE_CS_BETA_NZ \ - vmovups(mem(rcx), ymm0) \ - vpermilpd(imm(0x5), ymm0, ymm3) \ - vmulpd(ymm1, ymm0, ymm0) \ - vmulpd(ymm2, ymm3, ymm3) \ - vaddsubpd(ymm3, ymm0, ymm0) - -#define ZGEMM_OUTPUT_CS \ - vmovupd(ymm0, mem(rcx)) \ void bli_zgemm_haswell_asm_4x3 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -1802,76 +1240,78 @@ void bli_zgemm_haswell_asm_4x3 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( z, 4, 3, false ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rax) // initialize loop by pre-loading vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) - + lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*cs_c; lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*cs_c; - + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c prefetch(0, mem(r11, 7*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, 7*8)) // prefetch c + 2*cs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.ZLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 32*16)) - + vbroadcastsd(mem(rbx, 0*8), ymm2) vbroadcastsd(mem(rbx, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 2*8), ymm2) vbroadcastsd(mem(rbx, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 4*8), ymm2) vbroadcastsd(mem(rbx, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, -2*32), ymm0) vmovapd(mem(rax, -1*32), ymm1) - + // iteration 1 vbroadcastsd(mem(rbx, 6*8), ymm2) vbroadcastsd(mem(rbx, 7*8), ymm3) @@ -1879,51 +1319,51 @@ void bli_zgemm_haswell_asm_4x3 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 8*8), ymm2) vbroadcastsd(mem(rbx, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 10*8), ymm2) vbroadcastsd(mem(rbx, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, 0*32), ymm0) vmovapd(mem(rax, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 38*16)) - + vbroadcastsd(mem(rbx, 12*8), ymm2) vbroadcastsd(mem(rbx, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 14*8), ymm2) vbroadcastsd(mem(rbx, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 16*8), ymm2) vbroadcastsd(mem(rbx, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, 2*32), ymm0) vmovapd(mem(rax, 3*32), ymm1) - + // iteration 3 vbroadcastsd(mem(rbx, 18*8), ymm2) vbroadcastsd(mem(rbx, 19*8), ymm3) @@ -1931,83 +1371,83 @@ void bli_zgemm_haswell_asm_4x3 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 20*8), ymm2) vbroadcastsd(mem(rbx, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 22*8), ymm2) vbroadcastsd(mem(rbx, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*4*16), rax) // a += 4*4 (unroll x mr) add(imm(4*3*16), rbx) // b += 4*3 (unroll x nr) - + vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.ZLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 32*16)) - + vbroadcastsd(mem(rbx, 0*8), ymm2) vbroadcastsd(mem(rbx, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 2*8), ymm2) vbroadcastsd(mem(rbx, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 4*8), ymm2) vbroadcastsd(mem(rbx, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*4*16), rax) // a += 1*4 (unroll x mr) add(imm(1*3*16), rbx) // b += 1*3 (unroll x nr) - + vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.ZPOSTACCUM) - + // permute even and odd elements // of ymm6/7, ymm10/11, ymm/14/15 vpermilpd(imm(0x5), ymm6, ymm6) @@ -2016,76 +1456,69 @@ void bli_zgemm_haswell_asm_4x3 vpermilpd(imm(0x5), ymm11, ymm11) vpermilpd(imm(0x5), ymm14, ymm14) vpermilpd(imm(0x5), ymm15, ymm15) - - + + // subtract/add even/odd elements vaddsubpd(ymm6, ymm4, ymm4) vaddsubpd(ymm7, ymm5, ymm5) - + vaddsubpd(ymm10, ymm8, ymm8) vaddsubpd(ymm11, ymm9, ymm9) - + vaddsubpd(ymm14, ymm12, ymm12) vaddsubpd(ymm15, ymm13, ymm13) - - - - + + + + mov(var(alpha), rax) // load address of alpha vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate - - + + vpermilpd(imm(0x5), ymm4, ymm3) vmulpd(ymm0, ymm4, ymm4) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm4, ymm4) - + vpermilpd(imm(0x5), ymm5, ymm3) vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm5, ymm5) - - + + vpermilpd(imm(0x5), ymm8, ymm3) vmulpd(ymm0, ymm8, ymm8) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm9, ymm3) vmulpd(ymm0, ymm9, ymm9) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm9, ymm9) - - + + vpermilpd(imm(0x5), ymm12, ymm3) vmulpd(ymm0, ymm12, ymm12) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm12, ymm12) - + vpermilpd(imm(0x5), ymm13, ymm3) vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm13, ymm13) - - - - - + + + + + mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(dcomplex) - lea(mem(, rsi, 2), rsi) - lea(mem(, rsi, 2), rdx) // rdx = 2*rs_c; - - - + + + + // now avoid loading C if beta == 0 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. @@ -2094,171 +1527,56 @@ void bli_zgemm_haswell_asm_4x3 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16. - jz(.ZCOLSTORED) // jump to row storage case - - - - label(.ZGENSTORED) - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_GS - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORED) - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_CS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_CS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_CS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_CS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_CS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_CS - - - - jmp(.ZDONE) // jump to end. - - - + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx)) + vaddpd(ymm4, ymm0, ymm0) + vmovupd(ymm0, mem(rcx)) + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx,32)) + vaddpd(ymm5, ymm0, ymm0) + vmovupd(ymm0, mem(rcx,32)) + + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11)) + vaddpd(ymm8, ymm0, ymm0) + vmovupd(ymm0, mem(r11)) + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11,32)) + vaddpd(ymm9, ymm0, ymm0) + vmovupd(ymm0, mem(r11,32)) + + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12)) + vaddpd(ymm12, ymm0, ymm0) + vmovupd(ymm0, mem(r12)) + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12,32)) + vaddpd(ymm13, ymm0, ymm0) + vmovupd(ymm0, mem(r12,32)) + + jmp(.ZDONE) // jump to end. + label(.ZBETAZERO) - - cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16. - jz(.ZCOLSTORBZ) // jump to row storage case - - - - label(.ZGENSTORBZ) - - - vmovapd(ymm4, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - vmovapd(ymm5, ymm0) - ZGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - vmovapd(ymm8, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - vmovapd(ymm9, ymm0) - ZGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - vmovapd(ymm12, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - vmovapd(ymm13, ymm0) - ZGEMM_OUTPUT_GS - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORBZ) - - - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rdx, 1)) - - vmovupd(ymm8, mem(r11)) - vmovupd(ymm9, mem(r11, rdx, 1)) - - vmovupd(ymm12, mem(r12)) - vmovupd(ymm13, mem(r12, rdx, 1)) - - - - - - + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + + vmovupd(ymm8, mem(r11)) + vmovupd(ymm9, mem(r11,32)) + + vmovupd(ymm12, mem(r12)) + vmovupd(ymm13, mem(r12,32)) + label(.ZDONE) - - - end_asm( + + + end_asm( : // output operands (none) : // input operands [k_iter] "m" (k_iter), // 0 @@ -2273,7 +1591,7 @@ void bli_zgemm_haswell_asm_4x3 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2281,6 +1599,8 @@ void bli_zgemm_haswell_asm_4x3 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/haswell/3/bli_gemmtrsm_l_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemmtrsm_l_haswell_asm_d6x8.c index ae3d67c5f3..d0d0ff2115 100644 --- a/kernels/haswell/3/bli_gemmtrsm_l_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemmtrsm_l_haswell_asm_d6x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -58,6 +58,8 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 ( + dim_t m, + dim_t n, dim_t k0, float* restrict alpha, float* restrict a10, @@ -81,23 +83,25 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 float* beta = bli_sm1; + GEMMTRSM_UKR_SETUP_CT_ANY( s, 6, 16, true ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a10), rax) // load address of a. mov(var(b01), rbx) // load address of b. - + add(imm(32*4), rbx) // initialize loop by pre-loading vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - + mov(var(b11), rcx) // load address of b11 mov(imm(16), rdi) // set rs_b = PACKNR = 16 lea(mem(, rdi, 4), rdi) // rs_b *= sizeof(float) - + // NOTE: c11, rs_c, and cs_c aren't // needed for a while, but we load // them now to avoid stalling later. @@ -106,45 +110,45 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 lea(mem(, r9 , 4), r9) // rs_c *= sizeof(float) mov(var(k_left)0, r10) // load cs_c lea(mem(, r10, 4), r10) // cs_c *= sizeof(float) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.SLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 64*4)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, -2*32), ymm0) vmovaps(mem(rbx, -1*32), ymm1) - + // iteration 1 vbroadcastss(mem(rax, 6*4), ymm2) vbroadcastss(mem(rax, 7*4), ymm3) @@ -152,51 +156,51 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 8*4), ymm2) vbroadcastss(mem(rax, 9*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 10*4), ymm2) vbroadcastss(mem(rax, 11*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 0*32), ymm0) vmovaps(mem(rbx, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 76*4)) - + vbroadcastss(mem(rax, 12*4), ymm2) vbroadcastss(mem(rax, 13*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 14*4), ymm2) vbroadcastss(mem(rax, 15*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 16*4), ymm2) vbroadcastss(mem(rax, 17*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 2*32), ymm0) vmovaps(mem(rbx, 3*32), ymm1) - + // iteration 3 vbroadcastss(mem(rax, 18*4), ymm2) vbroadcastss(mem(rax, 19*4), ymm3) @@ -204,144 +208,144 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 20*4), ymm2) vbroadcastss(mem(rax, 21*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 22*4), ymm2) vbroadcastss(mem(rax, 23*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(4*6*4), rax) // a += 4*6 (unroll x mr) add(imm(4*16*4), rbx) // b += 4*16 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.SLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*4)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(1*6*4), rax) // a += 1*6 (unroll x mr) add(imm(1*16*4), rbx) // b += 1*16 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.SPOSTACCUM) - + // ymm4..ymm15 = -a10 * b01 - - - + + + mov(var(alpha), rbx) // load address of alpha vbroadcastss(mem(rbx), ymm3) // load alpha and duplicate - - - - + + + + mov(imm(1), rsi) // load cs_b = 1 lea(mem(, rsi, 4), rsi) // cs_b *= sizeof(float) - + lea(mem(rcx, rsi, 8), rdx) // load address of b11 + 8*cs_b - + mov(rcx, r11) // save rcx = b11 for later mov(rdx, r14) // save rdx = b11+8*cs_b for later - - + + // b11 := alpha * b11 - a10 * b01 vfmsub231ps(mem(rcx), ymm3, ymm4) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm5) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm6) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm7) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm8) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm9) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm10) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm11) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm12) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm13) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm14) //add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm15) //add(rdi, rdx) - - - + + + // prefetch c11 - + #if 0 mov(r8, rcx) // load address of c11 from r8 // Note: r9 = rs_c * sizeof(float) - + lea(mem(r9 , r9 , 2), r13) // r13 = 3*rs_c; lea(mem(rcx, r13, 1), rdx) // rdx = c11 + 3*rs_c; - + prefetch(0, mem(rcx, 0*8)) // prefetch c11 + 0*rs_c prefetch(0, mem(rcx, r9, 1, 0*8)) // prefetch c11 + 1*rs_c prefetch(0, mem(rcx, r9 , 2, 0*8)) // prefetch c11 + 2*rs_c @@ -349,12 +353,12 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 prefetch(0, mem(rdx, r9, 1, 0*8)) // prefetch c11 + 4*rs_c prefetch(0, mem(rdx, r9 , 2, 0*8)) // prefetch c11 + 5*rs_c #endif - - - - + + + + // trsm computation begins here - + // Note: contents of b11 are stored as // ymm4 ymm5 = ( beta00..07 ) ( beta08..0F ) // ymm6 ymm7 = ( beta10..17 ) ( beta18..1F ) @@ -362,348 +366,378 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 // ymm10 ymm11 = ( beta30..37 ) ( beta38..3F ) // ymm12 ymm13 = ( beta40..47 ) ( beta48..4F ) // ymm14 ymm15 = ( beta50..57 ) ( beta58..5F ) - - + + mov(var(a11), rax) // load address of a11 - + mov(r11, rcx) // recall address of b11 mov(r14, rdx) // recall address of b11+8*cs_b // Note: rdi = rs_b - + // iteration 0 ------------- - + vbroadcastss(mem(0+0*6)*4(rax), ymm0) // ymm0 = (1/alpha00) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulps(ymm0, ymm4, ymm4) // ymm4 *= (1/alpha00) vmulps(ymm0, ymm5, ymm5) // ymm5 *= (1/alpha00) - +#else + vdivps(ymm0, ymm4, ymm4) // ymm4 /= alpha00 + vdivps(ymm0, ymm5, ymm5) // ymm5 /= alpha00 +#endif + vmovups(ymm4, mem(rcx)) // store ( beta00..beta07 ) = ymm4 vmovups(ymm5, mem(rdx)) // store ( beta08..beta0F ) = ymm5 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 1 ------------- - + vbroadcastss(mem(1+0*6)*4(rax), ymm0) // ymm0 = alpha10 vbroadcastss(mem(1+1*6)*4(rax), ymm1) // ymm1 = (1/alpha11) - + vmulps(ymm0, ymm4, ymm2) // ymm2 = alpha10 * ymm4 vmulps(ymm0, ymm5, ymm3) // ymm3 = alpha10 * ymm5 - + vsubps(ymm2, ymm6, ymm6) // ymm6 -= ymm2 vsubps(ymm3, ymm7, ymm7) // ymm7 -= ymm3 - - vmulps(ymm6, ymm1, ymm6) // ymm6 *= (1/alpha11) - vmulps(ymm7, ymm1, ymm7) // ymm7 *= (1/alpha11) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm1, ymm6, ymm6) // ymm6 *= (1/alpha11) + vmulps(ymm1, ymm7, ymm7) // ymm7 *= (1/alpha11) +#else + vdivps(ymm1, ymm6, ymm6) // ymm6 /= alpha11 + vdivps(ymm1, ymm7, ymm7) // ymm7 /= alpha11 +#endif + vmovups(ymm6, mem(rcx)) // store ( beta10..beta17 ) = ymm6 vmovups(ymm7, mem(rdx)) // store ( beta18..beta1F ) = ymm7 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 2 ------------- - + vbroadcastss(mem(2+0*6)*4(rax), ymm0) // ymm0 = alpha20 vbroadcastss(mem(2+1*6)*4(rax), ymm1) // ymm1 = alpha21 - + vmulps(ymm0, ymm4, ymm2) // ymm2 = alpha20 * ymm4 vmulps(ymm0, ymm5, ymm3) // ymm3 = alpha20 * ymm5 - + vbroadcastss(mem(2+2*6)*4(rax), ymm0) // ymm0 = (1/alpha22) - + vfmadd231ps(ymm1, ymm6, ymm2) // ymm2 += alpha21 * ymm6 vfmadd231ps(ymm1, ymm7, ymm3) // ymm3 += alpha21 * ymm7 - + vsubps(ymm2, ymm8, ymm8) // ymm8 -= ymm2 vsubps(ymm3, ymm9, ymm9) // ymm9 -= ymm3 - - vmulps(ymm8, ymm0, ymm8) // ymm8 *= (1/alpha22) - vmulps(ymm9, ymm0, ymm9) // ymm9 *= (1/alpha22) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm0, ymm8, ymm8) // ymm8 *= (1/alpha22) + vmulps(ymm0, ymm9, ymm9) // ymm9 *= (1/alpha22) +#else + vdivps(ymm0, ymm8, ymm8) // ymm8 /= alpha22 + vdivps(ymm0, ymm9, ymm9) // ymm9 /= alpha22 +#endif + vmovups(ymm8, mem(rcx)) // store ( beta20..beta27 ) = ymm8 vmovups(ymm9, mem(rdx)) // store ( beta28..beta2F ) = ymm9 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 3 ------------- - + vbroadcastss(mem(3+0*6)*4(rax), ymm0) // ymm0 = alpha30 vbroadcastss(mem(3+1*6)*4(rax), ymm1) // ymm1 = alpha31 - + vmulps(ymm0, ymm4, ymm2) // ymm2 = alpha30 * ymm4 vmulps(ymm0, ymm5, ymm3) // ymm3 = alpha30 * ymm5 - + vbroadcastss(mem(3+2*6)*4(rax), ymm0) // ymm0 = alpha32 - + vfmadd231ps(ymm1, ymm6, ymm2) // ymm2 += alpha31 * ymm6 vfmadd231ps(ymm1, ymm7, ymm3) // ymm3 += alpha31 * ymm7 - + vbroadcastss(mem(3+3*6)*4(rax), ymm1) // ymm0 = (1/alpha33) - + vfmadd231ps(ymm0, ymm8, ymm2) // ymm2 += alpha32 * ymm8 vfmadd231ps(ymm0, ymm9, ymm3) // ymm3 += alpha32 * ymm9 - + vsubps(ymm2, ymm10, ymm10) // ymm10 -= ymm2 vsubps(ymm3, ymm11, ymm11) // ymm11 -= ymm3 - - vmulps(ymm10, ymm1, ymm10) // ymm10 *= (1/alpha33) - vmulps(ymm11, ymm1, ymm11) // ymm11 *= (1/alpha33) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm1, ymm10, ymm10) // ymm10 *= (1/alpha33) + vmulps(ymm1, ymm11, ymm11) // ymm11 *= (1/alpha33) +#else + vdivps(ymm1, ymm10, ymm10) // ymm10 /= alpha33 + vdivps(ymm1, ymm11, ymm11) // ymm11 /= alpha33 +#endif + vmovups(ymm10, mem(rcx)) // store ( beta30..beta37 ) = ymm10 vmovups(ymm11, mem(rdx)) // store ( beta38..beta3F ) = ymm11 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 4 ------------- - + vbroadcastss(mem(4+0*6)*4(rax), ymm0) // ymm0 = alpha40 vbroadcastss(mem(4+1*6)*4(rax), ymm1) // ymm1 = alpha41 - + vmulps(ymm0, ymm4, ymm2) // ymm2 = alpha40 * ymm4 vmulps(ymm0, ymm5, ymm3) // ymm3 = alpha40 * ymm5 - + vbroadcastss(mem(4+2*6)*4(rax), ymm0) // ymm0 = alpha42 - + vfmadd231ps(ymm1, ymm6, ymm2) // ymm2 += alpha41 * ymm6 vfmadd231ps(ymm1, ymm7, ymm3) // ymm3 += alpha41 * ymm7 - + vbroadcastss(mem(4+3*6)*4(rax), ymm1) // ymm1 = alpha43 - + vfmadd231ps(ymm0, ymm8, ymm2) // ymm2 += alpha42 * ymm8 vfmadd231ps(ymm0, ymm9, ymm3) // ymm3 += alpha42 * ymm9 - + vbroadcastss(mem(4+4*6)*4(rax), ymm0) // ymm0 = (1/alpha44) - + vfmadd231ps(ymm1, ymm10, ymm2) // ymm2 += alpha43 * ymm10 vfmadd231ps(ymm1, ymm11, ymm3) // ymm3 += alpha43 * ymm11 - + vsubps(ymm2, ymm12, ymm12) // ymm12 -= ymm2 vsubps(ymm3, ymm13, ymm13) // ymm13 -= ymm3 - - vmulps(ymm12, ymm0, ymm12) // ymm12 *= (1/alpha44) - vmulps(ymm13, ymm0, ymm13) // ymm13 *= (1/alpha44) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm0, ymm12, ymm12) // ymm12 *= (1/alpha44) + vmulps(ymm0, ymm13, ymm13) // ymm13 *= (1/alpha44) +#else + vdivps(ymm0, ymm12, ymm12) // ymm12 /= alpha44 + vdivps(ymm0, ymm13, ymm13) // ymm13 /= alpha44 +#endif + vmovups(ymm12, mem(rcx)) // store ( beta40..beta47 ) = ymm12 vmovups(ymm13, mem(rdx)) // store ( beta48..beta4F ) = ymm13 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 5 ------------- - + vbroadcastss(mem(5+0*6)*4(rax), ymm0) // ymm0 = alpha50 vbroadcastss(mem(5+1*6)*4(rax), ymm1) // ymm1 = alpha51 - + vmulps(ymm0, ymm4, ymm2) // ymm2 = alpha50 * ymm4 vmulps(ymm0, ymm5, ymm3) // ymm3 = alpha50 * ymm5 - + vbroadcastss(mem(5+2*6)*4(rax), ymm0) // ymm0 = alpha52 - + vfmadd231ps(ymm1, ymm6, ymm2) // ymm2 += alpha51 * ymm6 vfmadd231ps(ymm1, ymm7, ymm3) // ymm3 += alpha51 * ymm7 - + vbroadcastss(mem(5+3*6)*4(rax), ymm1) // ymm1 = alpha53 - + vfmadd231ps(ymm0, ymm8, ymm2) // ymm2 += alpha52 * ymm8 vfmadd231ps(ymm0, ymm9, ymm3) // ymm3 += alpha52 * ymm9 - + vbroadcastss(mem(5+4*6)*4(rax), ymm0) // ymm0 = alpha54 - + vfmadd231ps(ymm1, ymm10, ymm2) // ymm2 += alpha53 * ymm10 vfmadd231ps(ymm1, ymm11, ymm3) // ymm3 += alpha53 * ymm11 - + vbroadcastss(mem(5+5*6)*4(rax), ymm1) // ymm1 = (1/alpha55) - + vfmadd231ps(ymm0, ymm12, ymm2) // ymm2 += alpha54 * ymm12 vfmadd231ps(ymm0, ymm13, ymm3) // ymm3 += alpha54 * ymm13 - + vsubps(ymm2, ymm14, ymm14) // ymm14 -= ymm2 vsubps(ymm3, ymm15, ymm15) // ymm15 -= ymm3 - - vmulps(ymm14, ymm1, ymm14) // ymm14 *= (1/alpha55) - vmulps(ymm15, ymm1, ymm15) // ymm15 *= (1/alpha55) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm1, ymm14, ymm14) // ymm14 *= (1/alpha55) + vmulps(ymm1, ymm15, ymm15) // ymm15 *= (1/alpha55) +#else + vdivps(ymm1, ymm14, ymm14) // ymm14 /= alpha55 + vdivps(ymm1, ymm15, ymm15) // ymm15 /= alpha55 +#endif + vmovups(ymm14, mem(rcx)) // store ( beta50..beta57 ) = ymm14 vmovups(ymm15, mem(rdx)) // store ( beta58..beta5F ) = ymm15 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - - - - - + + + + + mov(r8, rcx) // load address of c11 from r8 mov(r9, rdi) // load rs_c (in bytes) from r9 mov(r10, rsi) // load cs_c (in bytes) from r10 - + lea(mem(rcx, rsi, 8), rdx) // load address of c11 + 8*cs_c; lea(mem(rcx, rdi, 4), r14) // load address of c11 + 4*rs_c; - + // These are used in the macros below. lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c; lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - - - + + + cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. jz(.SROWSTORED) // jump to row storage case - - - + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. jz(.SCOLSTORED) // jump to column storage case - - - + + + // if neither row- or column- // stored, use general case. label(.SGENSTORED) - - + + vmovaps(ymm4, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm6, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm8, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm10, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm12, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm14, ymm0) SGEMM_OUTPUT_GS_BETA_NZ - - + + mov(rdx, rcx) // rcx = c11 + 8*cs_c - - + + vmovaps(ymm5, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm7, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm9, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm11, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm13, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm15, ymm0) SGEMM_OUTPUT_GS_BETA_NZ - - - + + + jmp(.SDONE) - - - + + + label(.SROWSTORED) - - + + vmovups(ymm4, mem(rcx)) add(rdi, rcx) vmovups(ymm5, mem(rdx)) add(rdi, rdx) - + vmovups(ymm6, mem(rcx)) add(rdi, rcx) vmovups(ymm7, mem(rdx)) add(rdi, rdx) - + vmovups(ymm8, mem(rcx)) add(rdi, rcx) vmovups(ymm9, mem(rdx)) add(rdi, rdx) - + vmovups(ymm10, mem(rcx)) add(rdi, rcx) vmovups(ymm11, mem(rdx)) add(rdi, rdx) - + vmovups(ymm12, mem(rcx)) add(rdi, rcx) vmovups(ymm13, mem(rdx)) add(rdi, rdx) - + vmovups(ymm14, mem(rcx)) //add(rdi, rcx) vmovups(ymm15, mem(rdx)) //add(rdi, rdx) - - + + jmp(.SDONE) - - - + + + label(.SCOLSTORED) - - + + vunpcklps(ymm6, ymm4, ymm0) vunpcklps(ymm10, ymm8, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) vmovups(xmm3, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - + + vunpckhps(ymm6, ymm4, ymm0) vunpckhps(ymm10, ymm8, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) vmovups(xmm3, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - + vunpcklps(ymm14, ymm12, ymm0) vunpckhps(ymm14, ymm12, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) vmovlpd(xmm1, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) @@ -712,46 +746,46 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) vmovlpd(xmm3, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) vmovhpd(xmm3, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - + lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - + + vunpcklps(ymm7, ymm5, ymm0) vunpcklps(ymm11, ymm9, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovups(xmm0, mem(rcx)) // store ( gamma08..gamma38 ) vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma09..gamma39 ) vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma0C..gamma3C ) vmovups(xmm3, mem(rcx, r15, 1)) // store ( gamma0D..gamma3D ) - + vunpckhps(ymm7, ymm5, ymm0) vunpckhps(ymm11, ymm9, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma0A..gamma3A ) vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma0B..gamma3B ) vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma0E..gamma3E ) vmovups(xmm3, mem(rcx, r10, 1)) // store ( gamma0F..gamma3F ) - + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - + vunpcklps(ymm15, ymm13, ymm0) vunpckhps(ymm15, ymm13, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovlpd(xmm0, mem(r14)) // store ( gamma48..gamma58 ) vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma49..gamma59 ) vmovlpd(xmm1, mem(r14, rsi, 2)) // store ( gamma4A..gamma5A ) @@ -760,33 +794,34 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma4D..gamma5D ) vmovlpd(xmm3, mem(r14, r13, 2)) // store ( gamma4E..gamma5E ) vmovhpd(xmm3, mem(r14, r10, 1)) // store ( gamma4F..gamma5F ) - + //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - - + + + + label(.SDONE) - + vzeroupper() - + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a10] "m" (a10), // 2 - [b01] "m" (b01), // 3 - [beta] "m" (beta), // 4 - [alpha] "m" (alpha), // 5 - [a11] "m" (a11), // 6 - [b11] "m" (b11), // 7 - [c11] "m" (c11), // 8 - [rs_c] "m" (rs_c), // 9 - [cs_c] "m" (cs_c) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a10] "m" (a10), // 2 + [b01] "m" (b01), // 3 + [beta] "m" (beta), // 4 + [alpha] "m" (alpha), // 5 + [a11] "m" (a11), // 6 + [b11] "m" (b11), // 7 + [c11] "m" (c11), // 8 + [rs_c] "m" (rs_c), // 9 + [cs_c] "m" (cs_c) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -794,6 +829,8 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMMTRSM_UKR_FLUSH_CT( s ); } @@ -811,17 +848,19 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vmovhpd(xmm1, mem(rcx, r10, 1))*/ void bli_dgemmtrsm_l_haswell_asm_6x8 -( - dim_t k0, - double* restrict alpha, - double* restrict a10, - double* restrict a11, - double* restrict b01, - double* restrict b11, - double* restrict c11, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx -) + ( + dim_t m, + dim_t n, + dim_t k0, + double* restrict alpha, + double* restrict a10, + double* restrict a11, + double* restrict b01, + double* restrict b11, + double* restrict c11, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) { //void* a_next = bli_auxinfo_next_a( data ); //void* b_next = bli_auxinfo_next_b( data ); @@ -835,23 +874,25 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 double* beta = bli_dm1; + GEMMTRSM_UKR_SETUP_CT_ANY( d, 6, 8, true ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a10), rax) // load address of a. mov(var(b01), rbx) // load address of b. - + add(imm(32*4), rbx) // initialize loop by pre-loading vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - + mov(var(b11), rcx) // load address of b11 mov(imm(8), rdi) // set rs_b = PACKNR = 8 lea(mem(, rdi, 8), rdi) // rs_b *= sizeof(double) - + // NOTE: c11, rs_c, and cs_c aren't // needed for a while, but we load // them now to avoid stalling later. @@ -860,97 +901,99 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 lea(mem(, r9 , 8), r9) // rs_c *= sizeof(double) mov(var(k_left)0, r10) // load cs_c lea(mem(, r10, 8), r10) // cs_c *= sizeof(double) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, -2*32), ymm0) vmovapd(mem(rbx, -1*32), ymm1) - + // iteration 1 + prefetch(0, mem(rax, 72*8)) + vbroadcastsd(mem(rax, 6*8), ymm2) vbroadcastsd(mem(rax, 7*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 8*8), ymm2) vbroadcastsd(mem(rax, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 10*8), ymm2) vbroadcastsd(mem(rax, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 0*32), ymm0) vmovapd(mem(rbx, 1*32), ymm1) - + // iteration 2 - prefetch(0, mem(rax, 76*8)) - + prefetch(0, mem(rax, 80*8)) + vbroadcastsd(mem(rax, 12*8), ymm2) vbroadcastsd(mem(rax, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 14*8), ymm2) vbroadcastsd(mem(rax, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 16*8), ymm2) vbroadcastsd(mem(rax, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 2*32), ymm0) vmovapd(mem(rbx, 3*32), ymm1) - + // iteration 3 vbroadcastsd(mem(rax, 18*8), ymm2) vbroadcastsd(mem(rax, 19*8), ymm3) @@ -958,145 +1001,145 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 20*8), ymm2) vbroadcastsd(mem(rax, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 22*8), ymm2) vbroadcastsd(mem(rax, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*6*8), rax) // a += 4*6 (unroll x mr) add(imm(4*8*8), rbx) // b += 4*8 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*6*8), rax) // a += 1*6 (unroll x mr) add(imm(1*8*8), rbx) // b += 1*8 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + // ymm4..ymm15 = -a10 * b01 - - - - + + + + mov(var(alpha), rbx) // load address of alpha vbroadcastsd(mem(rbx), ymm3) // load alpha and duplicate - - - - + + + + mov(imm(1), rsi) // set cs_b = 1 lea(mem(, rsi, 8), rsi) // cs_b *= sizeof(double) - + lea(mem(rcx, rsi, 4), rdx) // load address of b11 + 4*cs_b - + mov(rcx, r11) // save rcx = b11 for later mov(rdx, r14) // save rdx = b11+4*cs_b for later - - + + // b11 := alpha * b11 - a10 * b01 vfmsub231pd(mem(rcx), ymm3, ymm4) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm5) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm6) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm7) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm8) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm9) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm10) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm11) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm12) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm13) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm14) //add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm15) //add(rdi, rdx) - - - + + + // prefetch c11 - + #if 0 mov(r8, rcx) // load address of c11 from r8 // Note: r9 = rs_c * sizeof(double) - + lea(mem(r9 , r9 , 2), r13) // r13 = 3*rs_c; lea(mem(rcx, r13, 1), rdx) // rdx = c11 + 3*rs_c; - + prefetch(0, mem(rcx, 7*8)) // prefetch c11 + 0*rs_c prefetch(0, mem(rcx, r9, 1, 7*8)) // prefetch c11 + 1*rs_c prefetch(0, mem(rcx, r9 , 2, 7*8)) // prefetch c11 + 2*rs_c @@ -1104,12 +1147,12 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 prefetch(0, mem(rdx, r9, 1, 7*8)) // prefetch c11 + 4*rs_c prefetch(0, mem(rdx, r9 , 2, 7*8)) // prefetch c11 + 5*rs_c #endif - - - - + + + + // trsm computation begins here - + // Note: contents of b11 are stored as // ymm4 ymm5 = ( beta00..03 ) ( beta04..07 ) // ymm6 ymm7 = ( beta10..13 ) ( beta14..17 ) @@ -1117,309 +1160,339 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 // ymm10 ymm11 = ( beta30..33 ) ( beta34..37 ) // ymm12 ymm13 = ( beta40..43 ) ( beta44..47 ) // ymm14 ymm15 = ( beta50..53 ) ( beta54..57 ) - - + + mov(var(a11), rax) // load address of a11 - + mov(r11, rcx) // recall address of b11 mov(r14, rdx) // recall address of b11+4*cs_b // Note: rdi = rs_b - + // iteration 0 ------------- - + vbroadcastsd(mem(0+0*6)*8(rax), ymm0) // ymm0 = (1/alpha00) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulpd(ymm0, ymm4, ymm4) // ymm4 *= (1/alpha00) vmulpd(ymm0, ymm5, ymm5) // ymm5 *= (1/alpha00) - +#else + vdivpd(ymm0, ymm4, ymm4) // ymm4 /= alpha00 + vdivpd(ymm0, ymm5, ymm5) // ymm5 /= alpha00 +#endif + vmovupd(ymm4, mem(rcx)) // store ( beta00..beta03 ) = ymm4 vmovupd(ymm5, mem(rdx)) // store ( beta04..beta07 ) = ymm5 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 1 ------------- - + vbroadcastsd(mem(1+0*6)*8(rax), ymm0) // ymm0 = alpha10 vbroadcastsd(mem(1+1*6)*8(rax), ymm1) // ymm1 = (1/alpha11) - + vmulpd(ymm0, ymm4, ymm2) // ymm2 = alpha10 * ymm4 vmulpd(ymm0, ymm5, ymm3) // ymm3 = alpha10 * ymm5 - + vsubpd(ymm2, ymm6, ymm6) // ymm6 -= ymm2 vsubpd(ymm3, ymm7, ymm7) // ymm7 -= ymm3 - - vmulpd(ymm6, ymm1, ymm6) // ymm6 *= (1/alpha11) - vmulpd(ymm7, ymm1, ymm7) // ymm7 *= (1/alpha11) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm1, ymm6, ymm6) // ymm6 *= (1/alpha11) + vmulpd(ymm1, ymm7, ymm7) // ymm7 *= (1/alpha11) +#else + vdivpd(ymm1, ymm6, ymm6) // ymm6 /= alpha11 + vdivpd(ymm1, ymm7, ymm7) // ymm7 /= alpha11 +#endif + vmovupd(ymm6, mem(rcx)) // store ( beta10..beta13 ) = ymm6 vmovupd(ymm7, mem(rdx)) // store ( beta14..beta17 ) = ymm7 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 2 ------------- - + vbroadcastsd(mem(2+0*6)*8(rax), ymm0) // ymm0 = alpha20 vbroadcastsd(mem(2+1*6)*8(rax), ymm1) // ymm1 = alpha21 - + vmulpd(ymm0, ymm4, ymm2) // ymm2 = alpha20 * ymm4 vmulpd(ymm0, ymm5, ymm3) // ymm3 = alpha20 * ymm5 - + vbroadcastsd(mem(2+2*6)*8(rax), ymm0) // ymm0 = (1/alpha22) - + vfmadd231pd(ymm1, ymm6, ymm2) // ymm2 += alpha21 * ymm6 vfmadd231pd(ymm1, ymm7, ymm3) // ymm3 += alpha21 * ymm7 - + vsubpd(ymm2, ymm8, ymm8) // ymm8 -= ymm2 vsubpd(ymm3, ymm9, ymm9) // ymm9 -= ymm3 - - vmulpd(ymm8, ymm0, ymm8) // ymm8 *= (1/alpha22) - vmulpd(ymm9, ymm0, ymm9) // ymm9 *= (1/alpha22) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm0, ymm8, ymm8) // ymm8 *= (1/alpha22) + vmulpd(ymm0, ymm9, ymm9) // ymm9 *= (1/alpha22) +#else + vdivpd(ymm0, ymm8, ymm8) // ymm8 /= alpha22 + vdivpd(ymm0, ymm9, ymm9) // ymm9 /= alpha22 +#endif + vmovupd(ymm8, mem(rcx)) // store ( beta20..beta23 ) = ymm8 vmovupd(ymm9, mem(rdx)) // store ( beta24..beta27 ) = ymm9 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 3 ------------- - + vbroadcastsd(mem(3+0*6)*8(rax), ymm0) // ymm0 = alpha30 vbroadcastsd(mem(3+1*6)*8(rax), ymm1) // ymm1 = alpha31 - + vmulpd(ymm0, ymm4, ymm2) // ymm2 = alpha30 * ymm4 vmulpd(ymm0, ymm5, ymm3) // ymm3 = alpha30 * ymm5 - + vbroadcastsd(mem(3+2*6)*8(rax), ymm0) // ymm0 = alpha32 - + vfmadd231pd(ymm1, ymm6, ymm2) // ymm2 += alpha31 * ymm6 vfmadd231pd(ymm1, ymm7, ymm3) // ymm3 += alpha31 * ymm7 - + vbroadcastsd(mem(3+3*6)*8(rax), ymm1) // ymm1 = (1/alpha33) - + vfmadd231pd(ymm0, ymm8, ymm2) // ymm2 += alpha32 * ymm8 vfmadd231pd(ymm0, ymm9, ymm3) // ymm3 += alpha32 * ymm9 - + vsubpd(ymm2, ymm10, ymm10) // ymm10 -= ymm2 vsubpd(ymm3, ymm11, ymm11) // ymm11 -= ymm3 - - vmulpd(ymm10, ymm1, ymm10) // ymm10 *= (1/alpha33) - vmulpd(ymm11, ymm1, ymm11) // ymm11 *= (1/alpha33) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm1, ymm10, ymm10) // ymm10 *= (1/alpha33) + vmulpd(ymm1, ymm11, ymm11) // ymm11 *= (1/alpha33) +#else + vdivpd(ymm1, ymm10, ymm10) // ymm10 /= alpha33 + vdivpd(ymm1, ymm11, ymm11) // ymm11 /= alpha33 +#endif + vmovupd(ymm10, mem(rcx)) // store ( beta30..beta33 ) = ymm10 vmovupd(ymm11, mem(rdx)) // store ( beta34..beta37 ) = ymm11 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 4 ------------- - + vbroadcastsd(mem(4+0*6)*8(rax), ymm0) // ymm0 = alpha40 vbroadcastsd(mem(4+1*6)*8(rax), ymm1) // ymm1 = alpha41 - + vmulpd(ymm0, ymm4, ymm2) // ymm2 = alpha40 * ymm4 vmulpd(ymm0, ymm5, ymm3) // ymm3 = alpha40 * ymm5 - + vbroadcastsd(mem(4+2*6)*8(rax), ymm0) // ymm0 = alpha42 - + vfmadd231pd(ymm1, ymm6, ymm2) // ymm2 += alpha41 * ymm6 vfmadd231pd(ymm1, ymm7, ymm3) // ymm3 += alpha41 * ymm7 - + vbroadcastsd(mem(4+3*6)*8(rax), ymm1) // ymm1 = alpha43 - + vfmadd231pd(ymm0, ymm8, ymm2) // ymm2 += alpha42 * ymm8 vfmadd231pd(ymm0, ymm9, ymm3) // ymm3 += alpha42 * ymm9 - + vbroadcastsd(mem(4+4*6)*8(rax), ymm0) // ymm4 = (1/alpha44) - + vfmadd231pd(ymm1, ymm10, ymm2) // ymm2 += alpha43 * ymm10 vfmadd231pd(ymm1, ymm11, ymm3) // ymm3 += alpha43 * ymm11 - + vsubpd(ymm2, ymm12, ymm12) // ymm12 -= ymm2 vsubpd(ymm3, ymm13, ymm13) // ymm13 -= ymm3 - - vmulpd(ymm12, ymm0, ymm12) // ymm12 *= (1/alpha44) - vmulpd(ymm13, ymm0, ymm13) // ymm13 *= (1/alpha44) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm0, ymm12, ymm12) // ymm12 *= (1/alpha44) + vmulpd(ymm0, ymm13, ymm13) // ymm13 *= (1/alpha44) +#else + vdivpd(ymm0, ymm12, ymm12) // ymm12 /= alpha44 + vdivpd(ymm0, ymm13, ymm13) // ymm13 /= alpha44 +#endif + vmovupd(ymm12, mem(rcx)) // store ( beta40..beta43 ) = ymm12 vmovupd(ymm13, mem(rdx)) // store ( beta44..beta47 ) = ymm13 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 5 ------------- - + vbroadcastsd(mem(5+0*6)*8(rax), ymm0) // ymm0 = alpha50 vbroadcastsd(mem(5+1*6)*8(rax), ymm1) // ymm1 = alpha51 - + vmulpd(ymm0, ymm4, ymm2) // ymm2 = alpha50 * ymm4 vmulpd(ymm0, ymm5, ymm3) // ymm3 = alpha50 * ymm5 - + vbroadcastsd(mem(5+2*6)*8(rax), ymm0) // ymm0 = alpha52 - + vfmadd231pd(ymm1, ymm6, ymm2) // ymm2 += alpha51 * ymm6 vfmadd231pd(ymm1, ymm7, ymm3) // ymm3 += alpha51 * ymm7 - + vbroadcastsd(mem(5+3*6)*8(rax), ymm1) // ymm1 = alpha53 - + vfmadd231pd(ymm0, ymm8, ymm2) // ymm2 += alpha52 * ymm8 vfmadd231pd(ymm0, ymm9, ymm3) // ymm3 += alpha52 * ymm9 - + vbroadcastsd(mem(5+4*6)*8(rax), ymm0) // ymm0 = alpha54 - + vfmadd231pd(ymm1, ymm10, ymm2) // ymm2 += alpha53 * ymm10 vfmadd231pd(ymm1, ymm11, ymm3) // ymm3 += alpha53 * ymm11 - + vbroadcastsd(mem(5+5*6)*8(rax), ymm1) // ymm1 = (1/alpha55) - + vfmadd231pd(ymm0, ymm12, ymm2) // ymm2 += alpha54 * ymm12 vfmadd231pd(ymm0, ymm13, ymm3) // ymm3 += alpha54 * ymm13 - + vsubpd(ymm2, ymm14, ymm14) // ymm14 -= ymm2 vsubpd(ymm3, ymm15, ymm15) // ymm15 -= ymm3 - - vmulpd(ymm14, ymm1, ymm14) // ymm14 *= (1/alpha55) - vmulpd(ymm15, ymm1, ymm15) // ymm15 *= (1/alpha55) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm1, ymm14, ymm14) // ymm14 *= (1/alpha55) + vmulpd(ymm1, ymm15, ymm15) // ymm15 *= (1/alpha55) +#else + vdivpd(ymm1, ymm14, ymm14) // ymm14 /= alpha55 + vdivpd(ymm1, ymm15, ymm15) // ymm15 /= alpha55 +#endif + vmovupd(ymm14, mem(rcx)) // store ( beta50..beta53 ) = ymm14 vmovupd(ymm15, mem(rdx)) // store ( beta54..beta57 ) = ymm15 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - - - - + + + + mov(r8, rcx) // load address of c11 from r8 mov(r9, rdi) // load rs_c (in bytes) from r9 mov(r10, rsi) // load cs_c (in bytes) from r10 - + lea(mem(rcx, rsi, 4), rdx) // load address of c11 + 4*cs_c; lea(mem(rcx, rdi, 4), r14) // load address of c11 + 4*rs_c; - + // These are used in the macros below. lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; //lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c; //lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - - - + + + cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. jz(.DROWSTORED) // jump to row storage case - - - + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - - + + + // if neither row- or column- // stored, use general case. label(.DGENSTORED) - - + + vmovapd(ymm4, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm6, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm8, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm10, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm12, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm14, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - + + mov(rdx, rcx) // rcx = c11 + 4*cs_c - - + + vmovapd(ymm5, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm7, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm9, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm11, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm13, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm15, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - + + jmp(.DDONE) - - - + + + label(.DROWSTORED) - - + + vmovupd(ymm4, mem(rcx)) add(rdi, rcx) vmovupd(ymm5, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm6, mem(rcx)) add(rdi, rcx) vmovupd(ymm7, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm8, mem(rcx)) add(rdi, rcx) vmovupd(ymm9, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm10, mem(rcx)) add(rdi, rcx) vmovupd(ymm11, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm12, mem(rcx)) add(rdi, rcx) vmovupd(ymm13, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm14, mem(rcx)) //add(rdi, rcx) vmovupd(ymm15, mem(rdx)) //add(rdi, rdx) - - + + jmp(.DDONE) - - - + + + label(.DCOLSTORED) - - + + vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -1428,27 +1501,27 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm6) vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - + vmovupd(ymm4, mem(rcx)) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, r13, 1)) - + lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm14, ymm12, ymm0) vunpckhpd(ymm14, ymm12, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovupd(xmm0, mem(r14)) vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm3, mem(r14, r13, 1)) - + lea(mem(r14, rsi, 4), r14) - - + + vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -1457,50 +1530,49 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm7) vperm2f128(imm(0x31), ymm2, ymm0, ymm9) vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - + vmovupd(ymm5, mem(rcx)) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, r13, 1)) - + //lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm15, ymm13, ymm0) vunpckhpd(ymm15, ymm13, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovupd(xmm0, mem(r14)) vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm3, mem(r14, r13, 1)) - + //lea(mem(r14, rsi, 4), r14) - - - - - + + + + label(.DDONE) - + vzeroupper() - + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a10] "m" (a10), // 2 - [b01] "m" (b01), // 3 - [beta] "m" (beta), // 4 - [alpha] "m" (alpha), // 5 - [a11] "m" (a11), // 6 - [b11] "m" (b11), // 7 - [c11] "m" (c11), // 8 - [rs_c] "m" (rs_c), // 9 - [cs_c] "m" (cs_c) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a10] "m" (a10), // 2 + [b01] "m" (b01), // 3 + [beta] "m" (beta), // 4 + [alpha] "m" (alpha), // 5 + [a11] "m" (a11), // 6 + [b11] "m" (b11), // 7 + [c11] "m" (c11), // 8 + [rs_c] "m" (rs_c), // 9 + [cs_c] "m" (cs_c) // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -1510,6 +1582,8 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMMTRSM_UKR_FLUSH_CT( d ); } diff --git a/kernels/haswell/3/bli_gemmtrsm_u_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemmtrsm_u_haswell_asm_d6x8.c index ceb4e1e5b8..68a8c069b4 100644 --- a/kernels/haswell/3/bli_gemmtrsm_u_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemmtrsm_u_haswell_asm_d6x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -58,6 +58,8 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 ( + dim_t m, + dim_t n, dim_t k0, float* restrict alpha, float* restrict a10, @@ -81,23 +83,25 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 float* beta = bli_sm1; + GEMMTRSM_UKR_SETUP_CT_ANY( s, 6, 16, true ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a10), rax) // load address of a. mov(var(b01), rbx) // load address of b. - + add(imm(32*4), rbx) // initialize loop by pre-loading vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - + mov(var(b11), rcx) // load address of b11 mov(imm(16), rdi) // set rs_b = PACKNR = 16 lea(mem(, rdi, 4), rdi) // rs_b *= sizeof(float) - + // NOTE: c11, rs_c, and cs_c aren't // needed for a while, but we load // them now to avoid stalling later. @@ -106,45 +110,45 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 lea(mem(, r9 , 4), r9) // rs_c *= sizeof(float) mov(var(k_left)0, r10) // load cs_c lea(mem(, r10, 4), r10) // cs_c *= sizeof(float) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.SLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 64*4)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, -2*32), ymm0) vmovaps(mem(rbx, -1*32), ymm1) - + // iteration 1 vbroadcastss(mem(rax, 6*4), ymm2) vbroadcastss(mem(rax, 7*4), ymm3) @@ -152,51 +156,51 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 8*4), ymm2) vbroadcastss(mem(rax, 9*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 10*4), ymm2) vbroadcastss(mem(rax, 11*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 0*32), ymm0) vmovaps(mem(rbx, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 76*4)) - + vbroadcastss(mem(rax, 12*4), ymm2) vbroadcastss(mem(rax, 13*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 14*4), ymm2) vbroadcastss(mem(rax, 15*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 16*4), ymm2) vbroadcastss(mem(rax, 17*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 2*32), ymm0) vmovaps(mem(rbx, 3*32), ymm1) - + // iteration 3 vbroadcastss(mem(rax, 18*4), ymm2) vbroadcastss(mem(rax, 19*4), ymm3) @@ -204,144 +208,144 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 20*4), ymm2) vbroadcastss(mem(rax, 21*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 22*4), ymm2) vbroadcastss(mem(rax, 23*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(4*6*4), rax) // a += 4*6 (unroll x mr) add(imm(4*16*4), rbx) // b += 4*16 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.SLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*4)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(1*6*4), rax) // a += 1*6 (unroll x mr) add(imm(1*16*4), rbx) // b += 1*16 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.SPOSTACCUM) - + // ymm4..ymm15 = -a10 * b01 - - - + + + mov(var(alpha), rbx) // load address of alpha vbroadcastss(mem(rbx), ymm3) // load alpha and duplicate - - - - + + + + mov(imm(1), rsi) // load cs_b = 1 lea(mem(, rsi, 4), rsi) // cs_b *= sizeof(float) - + lea(mem(rcx, rsi, 8), rdx) // load address of b11 + 8*cs_b - + mov(rcx, r11) // save rcx = b11 for later mov(rdx, r14) // save rdx = b11+8*cs_b for later - - + + // b11 := alpha * b11 - a10 * b01 vfmsub231ps(mem(rcx), ymm3, ymm4) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm5) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm6) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm7) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm8) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm9) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm10) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm11) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm12) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm13) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm14) //add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm15) //add(rdi, rdx) - - - + + + // prefetch c11 - + #if 0 mov(r8, rcx) // load address of c11 from r8 // Note: r9 = rs_c * sizeof(float) - + lea(mem(r9 , r9 , 2), r13) // r13 = 3*rs_c; lea(mem(rcx, r13, 1), rdx) // rdx = c11 + 3*rs_c; - + prefetch(0, mem(rcx, 0*8)) // prefetch c11 + 0*rs_c prefetch(0, mem(rcx, r9, 1, 0*8)) // prefetch c11 + 1*rs_c prefetch(0, mem(rcx, r9 , 2, 0*8)) // prefetch c11 + 2*rs_c @@ -349,12 +353,12 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 prefetch(0, mem(rdx, r9, 1, 0*8)) // prefetch c11 + 4*rs_c prefetch(0, mem(rdx, r9 , 2, 0*8)) // prefetch c11 + 5*rs_c #endif - - - - + + + + // trsm computation begins here - + // Note: contents of b11 are stored as // ymm4 ymm5 = ( beta00..07 ) ( beta08..0F ) // ymm6 ymm7 = ( beta10..17 ) ( beta18..1F ) @@ -362,353 +366,383 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 // ymm10 ymm11 = ( beta30..37 ) ( beta38..3F ) // ymm12 ymm13 = ( beta40..47 ) ( beta48..4F ) // ymm14 ymm15 = ( beta50..57 ) ( beta58..5F ) - - + + mov(var(a11), rax) // load address of a11 - + mov(r11, rcx) // recall address of b11 mov(r14, rdx) // recall address of b11+8*cs_b - + lea(mem(rcx, rdi, 4), rcx) // rcx = b11 + (6-1)*rs_b lea(mem(rcx, rdi, 1), rcx) lea(mem(rdx, rdi, 4), rdx) // rdx = b11 + (6-1)*rs_b + 8*cs_b lea(mem(rdx, rdi, 1), rdx) - - + + // iteration 0 ------------- - + vbroadcastss(mem(5+5*6)*4(rax), ymm0) // ymm0 = (1/alpha55) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulps(ymm0, ymm14, ymm14) // ymm14 *= (1/alpha55) vmulps(ymm0, ymm15, ymm15) // ymm15 *= (1/alpha55) - +#else + vdivps(ymm0, ymm14, ymm14) // ymm14 /= alpha55 + vdivps(ymm0, ymm15, ymm15) // ymm15 /= alpha55 +#endif + vmovups(ymm14, mem(rcx)) // store ( beta50..beta57 ) = ymm14 vmovups(ymm15, mem(rdx)) // store ( beta58..beta5F ) = ymm15 sub(rdi, rcx) // rcx -= rs_b sub(rdi, rdx) // rdx -= rs_b - + // iteration 1 ------------- - + vbroadcastss(mem(4+5*6)*4(rax), ymm0) // ymm0 = alpha45 vbroadcastss(mem(4+4*6)*4(rax), ymm1) // ymm1 = (1/alpha44) - + vmulps(ymm0, ymm14, ymm2) // ymm2 = alpha45 * ymm14 vmulps(ymm0, ymm15, ymm3) // ymm3 = alpha45 * ymm15 - + vsubps(ymm2, ymm12, ymm12) // ymm12 -= ymm2 vsubps(ymm3, ymm13, ymm13) // ymm13 -= ymm3 - - vmulps(ymm12, ymm1, ymm12) // ymm12 *= (1/alpha44) - vmulps(ymm13, ymm1, ymm13) // ymm13 *= (1/alpha44) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm1, ymm12, ymm12) // ymm12 *= (1/alpha44) + vmulps(ymm1, ymm13, ymm13) // ymm13 *= (1/alpha44) +#else + vdivps(ymm1, ymm12, ymm12) // ymm12 /= alpha44 + vdivps(ymm1, ymm13, ymm13) // ymm13 /= alpha44 +#endif + vmovups(ymm12, mem(rcx)) // store ( beta40..beta47 ) = ymm12 vmovups(ymm13, mem(rdx)) // store ( beta48..beta4F ) = ymm13 sub(rdi, rcx) // rcx -= rs_b sub(rdi, rdx) // rdx -= rs_b - + // iteration 2 ------------- - + vbroadcastss(mem(3+5*6)*4(rax), ymm0) // ymm0 = alpha35 vbroadcastss(mem(3+4*6)*4(rax), ymm1) // ymm1 = alpha34 - + vmulps(ymm0, ymm14, ymm2) // ymm2 = alpha35 * ymm14 vmulps(ymm0, ymm15, ymm3) // ymm3 = alpha35 * ymm15 - + vbroadcastss(mem(3+3*6)*4(rax), ymm0) // ymm0 = (1/alpha33) - + vfmadd231ps(ymm1, ymm12, ymm2) // ymm2 += alpha34 * ymm12 vfmadd231ps(ymm1, ymm13, ymm3) // ymm3 += alpha34 * ymm13 - + vsubps(ymm2, ymm10, ymm10) // ymm10 -= ymm2 vsubps(ymm3, ymm11, ymm11) // ymm11 -= ymm3 - - vmulps(ymm10, ymm0, ymm10) // ymm10 *= (1/alpha33) - vmulps(ymm11, ymm0, ymm11) // ymm11 *= (1/alpha33) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm0, ymm10, ymm10) // ymm10 *= (1/alpha33) + vmulps(ymm0, ymm11, ymm11) // ymm11 *= (1/alpha33) +#else + vdivps(ymm0, ymm10, ymm10) // ymm10 /= alpha33 + vdivps(ymm0, ymm11, ymm11) // ymm11 /= alpha33 +#endif + vmovups(ymm10, mem(rcx)) // store ( beta30..beta37 ) = ymm10 vmovups(ymm11, mem(rdx)) // store ( beta38..beta3F ) = ymm11 sub(rdi, rcx) // rcx -= rs_b sub(rdi, rdx) // rdx -= rs_b - + // iteration 3 ------------- - + vbroadcastss(mem(2+5*6)*4(rax), ymm0) // ymm0 = alpha25 vbroadcastss(mem(2+4*6)*4(rax), ymm1) // ymm1 = alpha24 - + vmulps(ymm0, ymm14, ymm2) // ymm2 = alpha25 * ymm14 vmulps(ymm0, ymm15, ymm3) // ymm3 = alpha25 * ymm15 - + vbroadcastss(mem(2+3*6)*4(rax), ymm0) // ymm0 = alpha23 - + vfmadd231ps(ymm1, ymm12, ymm2) // ymm2 += alpha24 * ymm12 vfmadd231ps(ymm1, ymm13, ymm3) // ymm3 += alpha24 * ymm13 - + vbroadcastss(mem(2+2*6)*4(rax), ymm1) // ymm1 = (1/alpha22) - + vfmadd231ps(ymm0, ymm10, ymm2) // ymm2 += alpha23 * ymm10 vfmadd231ps(ymm0, ymm11, ymm3) // ymm3 += alpha23 * ymm11 - + vsubps(ymm2, ymm8, ymm8) // ymm8 -= ymm2 vsubps(ymm3, ymm9, ymm9) // ymm9 -= ymm3 - - vmulps(ymm8, ymm1, ymm8) // ymm8 *= (1/alpha33) - vmulps(ymm9, ymm1, ymm9) // ymm9 *= (1/alpha33) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm1, ymm8, ymm8) // ymm8 *= (1/alpha22) + vmulps(ymm1, ymm9, ymm9) // ymm9 *= (1/alpha22) +#else + vdivps(ymm1, ymm8, ymm8) // ymm8 /= alpha22 + vdivps(ymm1, ymm9, ymm9) // ymm9 /= alpha22 +#endif + vmovups(ymm8, mem(rcx)) // store ( beta20..beta27 ) = ymm8 vmovups(ymm9, mem(rdx)) // store ( beta28..beta2F ) = ymm9 sub(rdi, rcx) // rcx -= rs_b sub(rdi, rdx) // rdx -= rs_b - + // iteration 4 ------------- - + vbroadcastss(mem(1+5*6)*4(rax), ymm0) // ymm0 = alpha15 vbroadcastss(mem(1+4*6)*4(rax), ymm1) // ymm1 = alpha14 - + vmulps(ymm0, ymm14, ymm2) // ymm2 = alpha15 * ymm14 vmulps(ymm0, ymm15, ymm3) // ymm3 = alpha15 * ymm15 - + vbroadcastss(mem(1+3*6)*4(rax), ymm0) // ymm0 = alpha13 - + vfmadd231ps(ymm1, ymm12, ymm2) // ymm2 += alpha14 * ymm12 vfmadd231ps(ymm1, ymm13, ymm3) // ymm3 += alpha14 * ymm13 - + vbroadcastss(mem(1+2*6)*4(rax), ymm1) // ymm1 = alpha12 - + vfmadd231ps(ymm0, ymm10, ymm2) // ymm2 += alpha13 * ymm10 vfmadd231ps(ymm0, ymm11, ymm3) // ymm3 += alpha13 * ymm11 - + vbroadcastss(mem(1+1*6)*4(rax), ymm0) // ymm4 = (1/alpha11) - + vfmadd231ps(ymm1, ymm8, ymm2) // ymm2 += alpha12 * ymm8 vfmadd231ps(ymm1, ymm9, ymm3) // ymm3 += alpha12 * ymm9 - + vsubps(ymm2, ymm6, ymm6) // ymm6 -= ymm2 vsubps(ymm3, ymm7, ymm7) // ymm7 -= ymm3 - - vmulps(ymm6, ymm0, ymm6) // ymm6 *= (1/alpha44) - vmulps(ymm7, ymm0, ymm7) // ymm7 *= (1/alpha44) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm0, ymm6, ymm6) // ymm6 *= (1/alpha11) + vmulps(ymm0, ymm7, ymm7) // ymm7 *= (1/alpha11) +#else + vdivps(ymm0, ymm6, ymm6) // ymm6 /= alpha11 + vdivps(ymm0, ymm7, ymm7) // ymm7 /= alpha11 +#endif + vmovups(ymm6, mem(rcx)) // store ( beta10..beta17 ) = ymm6 vmovups(ymm7, mem(rdx)) // store ( beta18..beta1F ) = ymm7 sub(rdi, rcx) // rcx -= rs_b sub(rdi, rdx) // rdx -= rs_b - + // iteration 5 ------------- - + vbroadcastss(mem(0+5*6)*4(rax), ymm0) // ymm0 = alpha05 vbroadcastss(mem(0+4*6)*4(rax), ymm1) // ymm1 = alpha04 - + vmulps(ymm0, ymm14, ymm2) // ymm2 = alpha05 * ymm14 vmulps(ymm0, ymm15, ymm3) // ymm3 = alpha05 * ymm15 - + vbroadcastss(mem(0+3*6)*4(rax), ymm0) // ymm0 = alpha03 - + vfmadd231ps(ymm1, ymm12, ymm2) // ymm2 += alpha04 * ymm12 vfmadd231ps(ymm1, ymm13, ymm3) // ymm3 += alpha04 * ymm13 - + vbroadcastss(mem(0+2*6)*4(rax), ymm1) // ymm1 = alpha02 - + vfmadd231ps(ymm0, ymm10, ymm2) // ymm2 += alpha03 * ymm10 vfmadd231ps(ymm0, ymm11, ymm3) // ymm3 += alpha03 * ymm11 - + vbroadcastss(mem(0+1*6)*4(rax), ymm0) // ymm0 = alpha01 - + vfmadd231ps(ymm1, ymm8, ymm2) // ymm2 += alpha02 * ymm8 vfmadd231ps(ymm1, ymm9, ymm3) // ymm3 += alpha02 * ymm9 - + vbroadcastss(mem(0+0*6)*4(rax), ymm1) // ymm1 = (1/alpha00) - + vfmadd231ps(ymm0, ymm6, ymm2) // ymm2 += alpha01 * ymm6 vfmadd231ps(ymm0, ymm7, ymm3) // ymm3 += alpha01 * ymm7 - + vsubps(ymm2, ymm4, ymm4) // ymm4 -= ymm2 vsubps(ymm3, ymm5, ymm5) // ymm5 -= ymm3 - - vmulps(ymm4, ymm1, ymm4) // ymm4 *= (1/alpha00) - vmulps(ymm5, ymm1, ymm5) // ymm5 *= (1/alpha00) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm1, ymm4, ymm4) // ymm4 *= (1/alpha00) + vmulps(ymm1, ymm5, ymm5) // ymm5 *= (1/alpha00) +#else + vdivps(ymm1, ymm4, ymm4) // ymm4 /= alpha00 + vdivps(ymm1, ymm5, ymm5) // ymm5 /= alpha00 +#endif + vmovups(ymm4, mem(rcx)) // store ( beta00..beta07 ) = ymm4 vmovups(ymm5, mem(rdx)) // store ( beta08..beta0F ) = ymm5 sub(rdi, rcx) // rcx -= rs_b sub(rdi, rdx) // rdx -= rs_b - - - - - + + + + + mov(r8, rcx) // load address of c11 from r8 mov(r9, rdi) // load rs_c (in bytes) from r9 mov(r10, rsi) // load cs_c (in bytes) from r10 - + lea(mem(rcx, rsi, 8), rdx) // load address of c11 + 8*cs_c; lea(mem(rcx, rdi, 4), r14) // load address of c11 + 4*rs_c; - + // These are used in the macros below. lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c; lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - - - + + + cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. jz(.SROWSTORED) // jump to row storage case - - - + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. jz(.SCOLSTORED) // jump to column storage case - - - + + + // if neither row- or column- // stored, use general case. label(.SGENSTORED) - - + + vmovaps(ymm4, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm6, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm8, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm10, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm12, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm14, ymm0) SGEMM_OUTPUT_GS_BETA_NZ - - + + mov(rdx, rcx) // rcx = c11 + 8*cs_c - - + + vmovaps(ymm5, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm7, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm9, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm11, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm13, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm15, ymm0) SGEMM_OUTPUT_GS_BETA_NZ - - - + + + jmp(.SDONE) - - - + + + label(.SROWSTORED) - - + + vmovups(ymm4, mem(rcx)) add(rdi, rcx) vmovups(ymm5, mem(rdx)) add(rdi, rdx) - + vmovups(ymm6, mem(rcx)) add(rdi, rcx) vmovups(ymm7, mem(rdx)) add(rdi, rdx) - + vmovups(ymm8, mem(rcx)) add(rdi, rcx) vmovups(ymm9, mem(rdx)) add(rdi, rdx) - + vmovups(ymm10, mem(rcx)) add(rdi, rcx) vmovups(ymm11, mem(rdx)) add(rdi, rdx) - + vmovups(ymm12, mem(rcx)) add(rdi, rcx) vmovups(ymm13, mem(rdx)) add(rdi, rdx) - + vmovups(ymm14, mem(rcx)) //add(rdi, rcx) vmovups(ymm15, mem(rdx)) //add(rdi, rdx) - - + + jmp(.SDONE) - - - + + + label(.SCOLSTORED) - - + + vunpcklps(ymm6, ymm4, ymm0) vunpcklps(ymm10, ymm8, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) vmovups(xmm3, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - + + vunpckhps(ymm6, ymm4, ymm0) vunpckhps(ymm10, ymm8, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) vmovups(xmm3, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - + vunpcklps(ymm14, ymm12, ymm0) vunpckhps(ymm14, ymm12, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) vmovlpd(xmm1, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) @@ -717,46 +751,46 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) vmovlpd(xmm3, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) vmovhpd(xmm3, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - + lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - + + vunpcklps(ymm7, ymm5, ymm0) vunpcklps(ymm11, ymm9, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovups(xmm0, mem(rcx)) // store ( gamma08..gamma38 ) vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma09..gamma39 ) vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma0C..gamma3C ) vmovups(xmm3, mem(rcx, r15, 1)) // store ( gamma0D..gamma3D ) - + vunpckhps(ymm7, ymm5, ymm0) vunpckhps(ymm11, ymm9, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma0A..gamma3A ) vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma0B..gamma3B ) vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma0E..gamma3E ) vmovups(xmm3, mem(rcx, r10, 1)) // store ( gamma0F..gamma3F ) - + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - + vunpcklps(ymm15, ymm13, ymm0) vunpckhps(ymm15, ymm13, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovlpd(xmm0, mem(r14)) // store ( gamma48..gamma58 ) vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma49..gamma59 ) vmovlpd(xmm1, mem(r14, rsi, 2)) // store ( gamma4A..gamma5A ) @@ -765,32 +799,34 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma4D..gamma5D ) vmovlpd(xmm3, mem(r14, r13, 2)) // store ( gamma4E..gamma5E ) vmovhpd(xmm3, mem(r14, r10, 1)) // store ( gamma4F..gamma5F ) - + //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - - + + + + label(.SDONE) - + vzeroupper() - + + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a10] "m" (a10), // 2 - [b01] "m" (b01), // 3 - [beta] "m" (beta), // 4 - [alpha] "m" (alpha), // 5 - [a11] "m" (a11), // 6 - [b11] "m" (b11), // 7 - [c11] "m" (c11), // 8 - [rs_c] "m" (rs_c), // 9 - [cs_c] "m" (cs_c) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a10] "m" (a10), // 2 + [b01] "m" (b01), // 3 + [beta] "m" (beta), // 4 + [alpha] "m" (alpha), // 5 + [a11] "m" (a11), // 6 + [b11] "m" (b11), // 7 + [c11] "m" (c11), // 8 + [rs_c] "m" (rs_c), // 9 + [cs_c] "m" (cs_c) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -798,6 +834,8 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMMTRSM_UKR_FLUSH_CT( s ); } @@ -815,17 +853,19 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 vmovhpd(xmm1, mem(rcx, r10, 1))*/ void bli_dgemmtrsm_u_haswell_asm_6x8 -( - dim_t k0, - double* restrict alpha, - double* restrict a10, - double* restrict a11, - double* restrict b01, - double* restrict b11, - double* restrict c11, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx -) + ( + dim_t m, + dim_t n, + dim_t k0, + double* restrict alpha, + double* restrict a10, + double* restrict a11, + double* restrict b01, + double* restrict b11, + double* restrict c11, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) { //void* a_next = bli_auxinfo_next_a( data ); //void* b_next = bli_auxinfo_next_b( data ); @@ -839,23 +879,25 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 double* beta = bli_dm1; - begin_asm() - + GEMMTRSM_UKR_SETUP_CT_ANY( d, 6, 8, true ); + + begin_asm() + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a10), rax) // load address of a. mov(var(b01), rbx) // load address of b. - + add(imm(32*4), rbx) // initialize loop by pre-loading vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - + mov(var(b11), rcx) // load address of b11 mov(imm(8), rdi) // set rs_b = PACKNR = 8 lea(mem(, rdi, 8), rdi) // rs_b *= sizeof(double) - + // NOTE: c11, rs_c, and cs_c aren't // needed for a while, but we load // them now to avoid stalling later. @@ -864,97 +906,99 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 lea(mem(, r9 , 8), r9) // rs_c *= sizeof(double) mov(var(k_left)0, r10) // load cs_c lea(mem(, r10, 8), r10) // cs_c *= sizeof(double) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, -2*32), ymm0) vmovapd(mem(rbx, -1*32), ymm1) - + // iteration 1 + prefetch(0, mem(rax, 72*8)) + vbroadcastsd(mem(rax, 6*8), ymm2) vbroadcastsd(mem(rax, 7*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 8*8), ymm2) vbroadcastsd(mem(rax, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 10*8), ymm2) vbroadcastsd(mem(rax, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 0*32), ymm0) vmovapd(mem(rbx, 1*32), ymm1) - + // iteration 2 - prefetch(0, mem(rax, 76*8)) - + prefetch(0, mem(rax, 80*8)) + vbroadcastsd(mem(rax, 12*8), ymm2) vbroadcastsd(mem(rax, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 14*8), ymm2) vbroadcastsd(mem(rax, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 16*8), ymm2) vbroadcastsd(mem(rax, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 2*32), ymm0) vmovapd(mem(rbx, 3*32), ymm1) - + // iteration 3 vbroadcastsd(mem(rax, 18*8), ymm2) vbroadcastsd(mem(rax, 19*8), ymm3) @@ -962,145 +1006,145 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 20*8), ymm2) vbroadcastsd(mem(rax, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 22*8), ymm2) vbroadcastsd(mem(rax, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*6*8), rax) // a += 4*6 (unroll x mr) add(imm(4*8*8), rbx) // b += 4*8 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*6*8), rax) // a += 1*6 (unroll x mr) add(imm(1*8*8), rbx) // b += 1*8 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + // ymm4..ymm15 = -a10 * b01 - - - - + + + + mov(var(alpha), rbx) // load address of alpha vbroadcastsd(mem(rbx), ymm3) // load alpha and duplicate - - - - + + + + mov(imm(1), rsi) // set cs_b = 1 lea(mem(, rsi, 8), rsi) // cs_b *= sizeof(double) - + lea(mem(rcx, rsi, 4), rdx) // load address of b11 + 4*cs_b - + mov(rcx, r11) // save rcx = b11 for later mov(rdx, r14) // save rdx = b11+4*cs_b for later - - + + // b11 := alpha * b11 - a10 * b01 vfmsub231pd(mem(rcx), ymm3, ymm4) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm5) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm6) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm7) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm8) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm9) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm10) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm11) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm12) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm13) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm14) //add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm15) //add(rdi, rdx) - - - + + + // prefetch c11 - + #if 0 mov(r8, rcx) // load address of c11 from r8 // Note: r9 = rs_c * sizeof(double) - + lea(mem(r9 , r9 , 2), r13) // r13 = 3*rs_c; lea(mem(rcx, r13, 1), rdx) // rdx = c11 + 3*rs_c; - + prefetch(0, mem(rcx, 7*8)) // prefetch c11 + 0*rs_c prefetch(0, mem(rcx, r9, 1, 7*8)) // prefetch c11 + 1*rs_c prefetch(0, mem(rcx, r9 , 2, 7*8)) // prefetch c11 + 2*rs_c @@ -1108,12 +1152,12 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 prefetch(0, mem(rdx, r9, 1, 7*8)) // prefetch c11 + 4*rs_c prefetch(0, mem(rdx, r9 , 2, 7*8)) // prefetch c11 + 5*rs_c #endif - - - - + + + + // trsm computation begins here - + // Note: contents of b11 are stored as // ymm4 ymm5 = ( beta00..03 ) ( beta04..07 ) // ymm6 ymm7 = ( beta10..13 ) ( beta14..17 ) @@ -1121,314 +1165,344 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 // ymm10 ymm11 = ( beta30..33 ) ( beta34..37 ) // ymm12 ymm13 = ( beta40..43 ) ( beta44..47 ) // ymm14 ymm15 = ( beta50..53 ) ( beta54..57 ) - - + + mov(var(a11), rax) // load address of a11 - + mov(r11, rcx) // recall address of b11 mov(r14, rdx) // recall address of b11+4*cs_b - + lea(mem(rcx, rdi, 4), rcx) // rcx = b11 + (6-1)*rs_b lea(mem(rcx, rdi, 1), rcx) lea(mem(rdx, rdi, 4), rdx) // rdx = b11 + (6-1)*rs_b + 4*cs_b lea(mem(rdx, rdi, 1), rdx) - - + + // iteration 0 ------------- - + vbroadcastsd(mem(5+5*6)*8(rax), ymm0) // ymm0 = (1/alpha55) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulpd(ymm0, ymm14, ymm14) // ymm14 *= (1/alpha55) vmulpd(ymm0, ymm15, ymm15) // ymm15 *= (1/alpha55) - +#else + vdivpd(ymm0, ymm14, ymm14) // ymm14 /= alpha55 + vdivpd(ymm0, ymm15, ymm15) // ymm15 /= alpha55 +#endif + vmovupd(ymm14, mem(rcx)) // store ( beta50..beta53 ) = ymm14 vmovupd(ymm15, mem(rdx)) // store ( beta54..beta57 ) = ymm15 sub(rdi, rcx) // rcx -= rs_b sub(rdi, rdx) // rdx -= rs_b - + // iteration 1 ------------- - + vbroadcastsd(mem(4+5*6)*8(rax), ymm0) // ymm0 = alpha45 vbroadcastsd(mem(4+4*6)*8(rax), ymm1) // ymm1 = (1/alpha44) - + vmulpd(ymm0, ymm14, ymm2) // ymm2 = alpha45 * ymm14 vmulpd(ymm0, ymm15, ymm3) // ymm3 = alpha45 * ymm15 - + vsubpd(ymm2, ymm12, ymm12) // ymm12 -= ymm2 vsubpd(ymm3, ymm13, ymm13) // ymm13 -= ymm3 - - vmulpd(ymm12, ymm1, ymm12) // ymm12 *= (1/alpha44) - vmulpd(ymm13, ymm1, ymm13) // ymm13 *= (1/alpha44) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm1, ymm12, ymm12) // ymm12 *= (1/alpha44) + vmulpd(ymm1, ymm13, ymm13) // ymm13 *= (1/alpha44) +#else + vdivpd(ymm1, ymm12, ymm12) // ymm12 /= alpha44 + vdivpd(ymm1, ymm13, ymm13) // ymm13 /= alpha44 +#endif + vmovupd(ymm12, mem(rcx)) // store ( beta40..beta43 ) = ymm12 vmovupd(ymm13, mem(rdx)) // store ( beta44..beta47 ) = ymm13 sub(rdi, rcx) // rcx -= rs_b sub(rdi, rdx) // rdx -= rs_b - + // iteration 2 ------------- - + vbroadcastsd(mem(3+5*6)*8(rax), ymm0) // ymm0 = alpha35 vbroadcastsd(mem(3+4*6)*8(rax), ymm1) // ymm1 = alpha34 - + vmulpd(ymm0, ymm14, ymm2) // ymm2 = alpha35 * ymm14 vmulpd(ymm0, ymm15, ymm3) // ymm3 = alpha35 * ymm15 - + vbroadcastsd(mem(3+3*6)*8(rax), ymm0) // ymm0 = (1/alpha33) - + vfmadd231pd(ymm1, ymm12, ymm2) // ymm2 += alpha34 * ymm12 vfmadd231pd(ymm1, ymm13, ymm3) // ymm3 += alpha34 * ymm13 - + vsubpd(ymm2, ymm10, ymm10) // ymm10 -= ymm2 vsubpd(ymm3, ymm11, ymm11) // ymm11 -= ymm3 - - vmulpd(ymm10, ymm0, ymm10) // ymm10 *= (1/alpha33) - vmulpd(ymm11, ymm0, ymm11) // ymm11 *= (1/alpha33) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm0, ymm10, ymm10) // ymm10 *= (1/alpha33) + vmulpd(ymm0, ymm11, ymm11) // ymm11 *= (1/alpha33) +#else + vdivpd(ymm0, ymm10, ymm10) // ymm10 /= alpha33 + vdivpd(ymm0, ymm11, ymm11) // ymm11 /= alpha33 +#endif + vmovupd(ymm10, mem(rcx)) // store ( beta30..beta33 ) = ymm10 vmovupd(ymm11, mem(rdx)) // store ( beta34..beta37 ) = ymm11 sub(rdi, rcx) // rcx -= rs_b sub(rdi, rdx) // rdx -= rs_b - + // iteration 3 ------------- - + vbroadcastsd(mem(2+5*6)*8(rax), ymm0) // ymm0 = alpha25 vbroadcastsd(mem(2+4*6)*8(rax), ymm1) // ymm1 = alpha24 - + vmulpd(ymm0, ymm14, ymm2) // ymm2 = alpha25 * ymm14 vmulpd(ymm0, ymm15, ymm3) // ymm3 = alpha25 * ymm15 - + vbroadcastsd(mem(2+3*6)*8(rax), ymm0) // ymm0 = alpha23 - + vfmadd231pd(ymm1, ymm12, ymm2) // ymm2 += alpha24 * ymm12 vfmadd231pd(ymm1, ymm13, ymm3) // ymm3 += alpha24 * ymm13 - + vbroadcastsd(mem(2+2*6)*8(rax), ymm1) // ymm1 = (1/alpha22) - + vfmadd231pd(ymm0, ymm10, ymm2) // ymm2 += alpha23 * ymm10 vfmadd231pd(ymm0, ymm11, ymm3) // ymm3 += alpha23 * ymm11 - + vsubpd(ymm2, ymm8, ymm8) // ymm8 -= ymm2 vsubpd(ymm3, ymm9, ymm9) // ymm9 -= ymm3 - - vmulpd(ymm8, ymm1, ymm8) // ymm8 *= (1/alpha33) - vmulpd(ymm9, ymm1, ymm9) // ymm9 *= (1/alpha33) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm1, ymm8, ymm8) // ymm8 *= (1/alpha22) + vmulpd(ymm1, ymm9, ymm9) // ymm9 *= (1/alpha22) +#else + vdivpd(ymm1, ymm8, ymm8) // ymm8 /= alpha22 + vdivpd(ymm1, ymm9, ymm9) // ymm9 /= alpha22 +#endif + vmovupd(ymm8, mem(rcx)) // store ( beta20..beta23 ) = ymm8 vmovupd(ymm9, mem(rdx)) // store ( beta24..beta27 ) = ymm9 sub(rdi, rcx) // rcx -= rs_b sub(rdi, rdx) // rdx -= rs_b - + // iteration 4 ------------- - + vbroadcastsd(mem(1+5*6)*8(rax), ymm0) // ymm0 = alpha15 vbroadcastsd(mem(1+4*6)*8(rax), ymm1) // ymm1 = alpha14 - + vmulpd(ymm0, ymm14, ymm2) // ymm2 = alpha15 * ymm14 vmulpd(ymm0, ymm15, ymm3) // ymm3 = alpha15 * ymm15 - + vbroadcastsd(mem(1+3*6)*8(rax), ymm0) // ymm0 = alpha13 - + vfmadd231pd(ymm1, ymm12, ymm2) // ymm2 += alpha14 * ymm12 vfmadd231pd(ymm1, ymm13, ymm3) // ymm3 += alpha14 * ymm13 - + vbroadcastsd(mem(1+2*6)*8(rax), ymm1) // ymm1 = alpha12 - + vfmadd231pd(ymm0, ymm10, ymm2) // ymm2 += alpha13 * ymm10 vfmadd231pd(ymm0, ymm11, ymm3) // ymm3 += alpha13 * ymm11 - + vbroadcastsd(mem(1+1*6)*8(rax), ymm0) // ymm4 = (1/alpha11) - + vfmadd231pd(ymm1, ymm8, ymm2) // ymm2 += alpha12 * ymm8 vfmadd231pd(ymm1, ymm9, ymm3) // ymm3 += alpha12 * ymm9 - + vsubpd(ymm2, ymm6, ymm6) // ymm6 -= ymm2 vsubpd(ymm3, ymm7, ymm7) // ymm7 -= ymm3 - - vmulpd(ymm6, ymm0, ymm6) // ymm6 *= (1/alpha44) - vmulpd(ymm7, ymm0, ymm7) // ymm7 *= (1/alpha44) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm0, ymm6, ymm6) // ymm6 *= (1/alpha11) + vmulpd(ymm0, ymm7, ymm7) // ymm7 *= (1/alpha11) +#else + vdivpd(ymm0, ymm6, ymm6) // ymm6 /= alpha11 + vdivpd(ymm0, ymm7, ymm7) // ymm7 /= alpha11 +#endif + vmovupd(ymm6, mem(rcx)) // store ( beta10..beta13 ) = ymm6 vmovupd(ymm7, mem(rdx)) // store ( beta14..beta17 ) = ymm7 sub(rdi, rcx) // rcx -= rs_b sub(rdi, rdx) // rdx -= rs_b - + // iteration 5 ------------- - + vbroadcastsd(mem(0+5*6)*8(rax), ymm0) // ymm0 = alpha05 vbroadcastsd(mem(0+4*6)*8(rax), ymm1) // ymm1 = alpha04 - + vmulpd(ymm0, ymm14, ymm2) // ymm2 = alpha05 * ymm14 vmulpd(ymm0, ymm15, ymm3) // ymm3 = alpha05 * ymm15 - + vbroadcastsd(mem(0+3*6)*8(rax), ymm0) // ymm0 = alpha03 - + vfmadd231pd(ymm1, ymm12, ymm2) // ymm2 += alpha04 * ymm12 vfmadd231pd(ymm1, ymm13, ymm3) // ymm3 += alpha04 * ymm13 - + vbroadcastsd(mem(0+2*6)*8(rax), ymm1) // ymm1 = alpha02 - + vfmadd231pd(ymm0, ymm10, ymm2) // ymm2 += alpha03 * ymm10 vfmadd231pd(ymm0, ymm11, ymm3) // ymm3 += alpha03 * ymm11 - + vbroadcastsd(mem(0+1*6)*8(rax), ymm0) // ymm0 = alpha01 - + vfmadd231pd(ymm1, ymm8, ymm2) // ymm2 += alpha02 * ymm8 vfmadd231pd(ymm1, ymm9, ymm3) // ymm3 += alpha02 * ymm9 - + vbroadcastsd(mem(0+0*6)*8(rax), ymm1) // ymm1 = (1/alpha00) - + vfmadd231pd(ymm0, ymm6, ymm2) // ymm2 += alpha01 * ymm6 vfmadd231pd(ymm0, ymm7, ymm3) // ymm3 += alpha01 * ymm7 - + vsubpd(ymm2, ymm4, ymm4) // ymm4 -= ymm2 vsubpd(ymm3, ymm5, ymm5) // ymm5 -= ymm3 - - vmulpd(ymm4, ymm1, ymm4) // ymm4 *= (1/alpha00) - vmulpd(ymm5, ymm1, ymm5) // ymm5 *= (1/alpha00) - + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm1, ymm4, ymm4) // ymm4 *= (1/alpha00) + vmulpd(ymm1, ymm5, ymm5) // ymm5 *= (1/alpha00) +#else + vdivpd(ymm1, ymm4, ymm4) // ymm4 /= alpha00 + vdivpd(ymm1, ymm5, ymm5) // ymm5 /= alpha00 +#endif + vmovupd(ymm4, mem(rcx)) // store ( beta00..beta03 ) = ymm4 vmovupd(ymm5, mem(rdx)) // store ( beta04..beta07 ) = ymm5 sub(rdi, rcx) // rcx -= rs_b sub(rdi, rdx) // rdx -= rs_b - - - - + + + + mov(r8, rcx) // load address of c11 from r8 mov(r9, rdi) // load rs_c (in bytes) from r9 mov(r10, rsi) // load cs_c (in bytes) from r10 - + lea(mem(rcx, rsi, 4), rdx) // load address of c11 + 4*cs_c; lea(mem(rcx, rdi, 4), r14) // load address of c11 + 4*rs_c; - + // These are used in the macros below. lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; //lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c; //lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - - - + + + cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. jz(.DROWSTORED) // jump to row storage case - - - + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - - + + + // if neither row- or column- // stored, use general case. label(.DGENSTORED) - - + + vmovapd(ymm4, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm6, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm8, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm10, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm12, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm14, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - + + mov(rdx, rcx) // rcx = c11 + 4*cs_c - - + + vmovapd(ymm5, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm7, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm9, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm11, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm13, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm15, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - + + jmp(.DDONE) - - - + + + label(.DROWSTORED) - - + + vmovupd(ymm4, mem(rcx)) add(rdi, rcx) vmovupd(ymm5, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm6, mem(rcx)) add(rdi, rcx) vmovupd(ymm7, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm8, mem(rcx)) add(rdi, rcx) vmovupd(ymm9, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm10, mem(rcx)) add(rdi, rcx) vmovupd(ymm11, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm12, mem(rcx)) add(rdi, rcx) vmovupd(ymm13, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm14, mem(rcx)) //add(rdi, rcx) vmovupd(ymm15, mem(rdx)) //add(rdi, rdx) - - + + jmp(.DDONE) - - - + + + label(.DCOLSTORED) - - + + vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -1437,27 +1511,27 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm6) vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - + vmovupd(ymm4, mem(rcx)) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, r13, 1)) - + lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm14, ymm12, ymm0) vunpckhpd(ymm14, ymm12, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovupd(xmm0, mem(r14)) vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm3, mem(r14, r13, 1)) - + lea(mem(r14, rsi, 4), r14) - - + + vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -1466,34 +1540,35 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm7) vperm2f128(imm(0x31), ymm2, ymm0, ymm9) vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - + vmovupd(ymm5, mem(rcx)) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, r13, 1)) - + //lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm15, ymm13, ymm0) vunpckhpd(ymm15, ymm13, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovupd(xmm0, mem(r14)) vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm3, mem(r14, r13, 1)) - + //lea(mem(r14, rsi, 4), r14) - - - - - + + + + + label(.DDONE) - + vzeroupper() - + + end_asm( : // output operands (none) @@ -1518,6 +1593,8 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMMTRSM_UKR_FLUSH_CT( d ); } diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c new file mode 100644 index 0000000000..1820277d5a --- /dev/null +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c @@ -0,0 +1,1925 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. + + NOTE: These kernels implicitly support column-oriented IO, implemented + via an a high-level transposition of the entire operation. A and B will + effectively remain row- and column-stored, respectively, but C will then + effectively appear column-stored. Thus, this kernel may be used for both + rrc and crc cases. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rd_haswell_asm_6x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t n_left = n0 % 8; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 6x?m kernels, as needed. + if ( n_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rd_haswell_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 0 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 8 + jl(.DLOOP3X4J) // if jj < 8, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; + lea(mem(r14), rax) // rax = a + 3*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7) + // ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8) + // ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // ymm6[0] = sum(ymm6); ymm6[1] = sum(ymm9) + // ymm6[2] = sum(ymm12); ymm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_ii; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovsd(mem(rax, r13, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rax, r8, 4), xmm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovsd(mem(rax, r15, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + // xmm6[0:1] = sum(ymm6) sum(ymm7) + // xmm8[0:1] = sum(ymm8) sum(ymm9) + // xmm10[0:1] = sum(ymm10) sum(ymm11) + // xmm12[0:1] = sum(ymm12) sum(ymm13) + // xmm14[0:1] = sum(ymm14) sum(ymm15) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rd_haswell_asm_3x2 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x2 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c new file mode 100644 index 0000000000..e720e7da1c --- /dev/null +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c @@ -0,0 +1,2385 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. + + NOTE: These kernels implicitly support column-oriented IO, implemented + via an a high-level transposition of the entire operation. A and B will + effectively remain row- and column-stored, respectively, but C will then + effectively appear column-stored. Thus, this kernel may be used for both + rrc and crc cases. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rd_haswell_asm_6x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t m_left = m0 % 6; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other ?x8m kernels, as needed. + if ( m_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + +#if 1 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m0 ) + { + dgemmsup_ker_ft ker_fp1 = NULL; + dgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + // These kernels don't make any attempt to optimize the cases of + // inflated MR blocksizes because they don't benefit from the + // load balancing that the "rv" kernels do. That is, if m0 = 7, + // there is no benefit to executing that case as 4x8n followed + // by 3x8n because 4x8n isn't implemented, and more generally + // because these kernels are implemented as loops over their + // true blocksizes, which are MR=3 NR=4. + if ( m0 == 7 ) + { + mr1 = 6; mr2 = 1; + ker_fp1 = bli_dgemmsup_rd_haswell_asm_6x8n; + ker_fp2 = bli_dgemmsup_rd_haswell_asm_1x8n; + } + else if ( m0 == 8 ) + { + mr1 = 6; mr2 = 2; + ker_fp1 = bli_dgemmsup_rd_haswell_asm_6x8n; + ker_fp2 = bli_dgemmsup_rd_haswell_asm_2x8n; + } + else // if ( m0 == 9 ) + { + mr1 = 6; mr2 = 3; + ker_fp1 = bli_dgemmsup_rd_haswell_asm_6x8n; + ker_fp2 = bli_dgemmsup_rd_haswell_asm_3x8n; + } + + ker_fp1 + ( + conja, conjb, mr1, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rd_haswell_asm_3x8n + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8n + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + #if 1 + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8n + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_dgemv_ex + ( + BLIS_TRANSPOSE, conja, k0, n0, + alpha, bj, rs_b0, cs_b0, ai, cs_a0, + beta, cij, cs_c0, cntx, NULL + ); + #endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r9) // ii = 0; + + label(.DLOOP3X4I) // LOOP OVER ii = [ 0 1 ... ] + + + + mov(var(a), rdx) // load address of a + mov(var(b), r14) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(rdi, rsi) // rsi *= rs_c + lea(mem(r12, rsi, 1), r12) // r12 = c + 3*ii*rs_c; + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(r8, rsi) // rsi *= rs_a; + lea(mem(rdx, rsi, 1), rdx) // rax = a + 3*ii*rs_a; + + + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + //lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7) + // ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8) + // ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // ymm6[0] = sum(ymm6); ymm6[1] = sum(ymm9) + // ymm6[2] = sum(ymm12); ymm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. + + + + add(imm(3), r9) // ii += 3; + cmp(imm(6), r9) // compare ii to 6 + jl(.DLOOP3X4I) // if ii < 6, jump to beginning + // of ii loop; otherwise, loop ends. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the n dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 6; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_6x1 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + //lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7) + // ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8) + // ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // ymm6[0] = sum(ymm6); ymm6[1] = sum(ymm9) + // ymm6[2] = sum(ymm12); ymm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. + + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the n dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_3x2 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_3x1 + //bli_dgemmsup_r_haswell_ref_3x1 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_2x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + //lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7) + // ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8) + // ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. + + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the n dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 2; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_2x1 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_1x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c +#endif + //lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7) + // ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. + + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the n dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 1; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_1x2 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x1 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_ddotxv_ex + ( + conja, conjb, k0, + alpha, ai, cs_a0, bj, rs_b0, + beta, cij, cntx, NULL + ); + #endif + } + } +} + diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16m.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16m.c new file mode 100644 index 0000000000..f764bc613e --- /dev/null +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16m.c @@ -0,0 +1,3215 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. + + NOTE: These kernels implicitly support column-oriented IO, implemented + via an a high-level transposition of the entire operation. A and B will + effectively remain row- and column-stored, respectively, but C will then + effectively appear column-stored. Thus, this kernel may be used for both + rrc and crc cases. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rd_haswell_asm_6x16m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t n_left = n0 % 16; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 6x?m kernels, as needed. + if ( n_left ) + { + float* restrict cij = c; + float* restrict bj = b; + float* restrict ai = a; + + if ( 12 <= n_left ) + { + const dim_t nr_cur = 12; + + bli_sgemmsup_rd_haswell_asm_6x12m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rd_haswell_asm_6x8m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rd_haswell_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rd_haswell_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 0 + const dim_t nr_cur = 1; + + bli_sgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c*sizeof(float) = 1*4 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*4)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm6) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(16), r15) // compare jj to 16 + jl(.SLOOP3X4J) // if jj < 16, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 16; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict bj = b; + float* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_sgemmsup_rd_haswell_asm_2x16 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_sgemmsup_rd_haswell_asm_1x16 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_sgemmsup_rd_haswell_asm_6x12m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c*sizeof(float) = 1*4 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*4)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm6) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(12), r15) // compare jj to 12 + jl(.SLOOP3X4J) // if jj < 12, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 12; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict bj = b; + float* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_sgemmsup_rd_haswell_asm_2x12 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_sgemmsup_rd_haswell_asm_1x12 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_sgemmsup_rd_haswell_asm_6x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c*sizeof(float) = 1*4 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*4)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm6) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 8 + jl(.SLOOP3X4J) // if jj < 8, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict bj = b; + float* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_sgemmsup_rd_haswell_asm_2x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_sgemmsup_rd_haswell_asm_1x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_sgemmsup_rd_haswell_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_ii; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*4)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm6) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict bj = b; + float* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_sgemmsup_rd_haswell_asm_2x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_sgemmsup_rd_haswell_asm_1x4 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_sgemmsup_rd_haswell_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_ii; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*4)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 1*4)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 1*4)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 1*4)) // prefetch c + 5*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rbx ), xmm0) + vmovss(mem(rbx, r11, 1), xmm1) + add(imm(1*4), rbx) // b += 8*rs_b = 8*4; + + vmovss(mem(rax ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovss(mem(rax, r8, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovss(mem(rax, r8, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + vmovss(mem(rax, r13, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovss(mem(rax, r8, 4), xmm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + + vmovss(mem(rax, r15, 1), xmm3) + add(imm(1*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddps( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm4 ) + + vhaddps( ymm7, ymm6, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm6 ) + + vhaddps( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm8 ) + + vhaddps( ymm11, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm10 ) + + vhaddps( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm12 ) + + vhaddps( ymm15, ymm14, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm14 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + // xmm6[0:1] = sum(ymm6) sum(ymm7) + // xmm8[0:1] = sum(ymm8) sum(ymm9) + // xmm10[0:1] = sum(ymm10) sum(ymm11) + // xmm12[0:1] = sum(ymm12) sum(ymm13) + // xmm14[0:1] = sum(ymm14) sum(ymm15) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) + vmovsd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) + vmovsd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm14) + vmovsd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict bj = b; + float* restrict ai = a + i_edge*rs_a; + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_sgemmsup_rd_haswell_asm_3x2 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_sgemmsup_rd_haswell_asm_2x2 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_sgemmsup_rd_haswell_asm_1x2 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16n.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16n.c new file mode 100644 index 0000000000..1fe862a8d1 --- /dev/null +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16n.c @@ -0,0 +1,2416 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. + + NOTE: These kernels implicitly support column-oriented IO, implemented + via an a high-level transposition of the entire operation. A and B will + effectively remain row- and column-stored, respectively, but C will then + effectively appear column-stored. Thus, this kernel may be used for both + rrc and crc cases. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rd_haswell_asm_6x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t m_left = m0 % 6; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other ?x8m kernels, as needed. + if ( m_left ) + { + float* restrict cij = c; + float* restrict bj = b; + float* restrict ai = a; + +#if 1 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m0 ) + { + sgemmsup_ker_ft ker_fp1 = NULL; + sgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + // These kernels don't make any attempt to optimize the cases of + // inflated MR blocksizes because they don't benefit from the + // load balancing that the "rv" kernels do. That is, if m0 = 7, + // there is no benefit to executing that case as 4x16n followed + // by 3x16n because 4x16n isn't implemented, and more generally + // because these kernels are implemented as loops over their + // true blocksizes, which are MR=3 NR=4. + if ( m0 == 7 ) + { + mr1 = 6; mr2 = 1; + ker_fp1 = bli_sgemmsup_rd_haswell_asm_6x16n; + ker_fp2 = bli_sgemmsup_rd_haswell_asm_1x16n; + } + else if ( m0 == 8 ) + { + mr1 = 6; mr2 = 2; + ker_fp1 = bli_sgemmsup_rd_haswell_asm_6x16n; + ker_fp2 = bli_sgemmsup_rd_haswell_asm_2x16n; + } + else // if ( m0 == 9 ) + { + mr1 = 6; mr2 = 3; + ker_fp1 = bli_sgemmsup_rd_haswell_asm_6x16n; + ker_fp2 = bli_sgemmsup_rd_haswell_asm_3x16n; + } + + ker_fp1 + ( + conja, conjb, mr1, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_sgemmsup_rd_haswell_asm_3x16n + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_sgemmsup_rd_haswell_asm_2x16n + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + #if 1 + const dim_t mr_cur = 1; + + bli_sgemmsup_rd_haswell_asm_1x16n + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_sgemv_ex + ( + BLIS_TRANSPOSE, conja, k0, n0, + alpha, bj, rs_b0, cs_b0, ai, cs_a0, + beta, cij, cs_c0, cntx, NULL + ); + #endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + //mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r9) // ii = 0; + + label(.SLOOP3X4I) // LOOP OVER ii = [ 0 1 ... ] + + + + mov(var(b), r14) // load address of b + mov(var(c), r12) // load address of c + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(rdi, rsi) // rsi *= rs_c + lea(mem(r12, rsi, 1), r12) // r12 = c + 3*ii*rs_c; + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(r8, rsi) // rsi *= rs_a; + lea(mem(rdx, rsi, 1), rdx) // rax = a + 3*ii*rs_a; + + + + mov(var(n_iter), r15) // jj = n_iter; + + label(.SLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*4)) // prefetch c + 2*rs_c +#endif + //lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(32*4), r10) // r10 += 32*rs_b = 32*4; +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm6) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + add(imm(4*4), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.SLOOP3X4J) // iterate again if jj != 0. + + + + add(imm(3), r9) // ii += 3; + cmp(imm(6), r9) // compare jj to 6 + jl(.SLOOP3X4I) // if ii < 6, jump to beginning + // of ii loop; otherwise, loop ends. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the n dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 6; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rd_haswell_asm_6x2 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_sgemmsup_rd_haswell_asm_6x1 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + +void bli_sgemmsup_rd_haswell_asm_3x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.SLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*4)) // prefetch c + 2*rs_c +#endif + //lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(32*4), r10) // r10 += 32*rs_b = 32*4; +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm6) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + add(imm(4*4), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.SLOOP3X4J) // iterate again if jj != 0. + + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the n dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rd_haswell_asm_3x2 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_sgemmsup_rd_haswell_asm_3x1 + //bli_sgemmsup_r_haswell_ref_3x1 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + +void bli_sgemmsup_rd_haswell_asm_2x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.SLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c +#endif + //lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(32*4), r10) // r10 += 32*rs_b = 32*4; +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + add(imm(4*4), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.SLOOP3X4J) // iterate again if jj != 0. + + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the n dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 2; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rd_haswell_asm_2x2 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_sgemmsup_rd_haswell_asm_2x1 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + +void bli_sgemmsup_rd_haswell_asm_1x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.SLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm13, ymm13, ymm13) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c +#endif + //lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(32*4), r10) // r10 += 32*rs_b = 32*4; +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + add(imm(4*4), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.SLOOP3X4J) // iterate again if jj != 0. + + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the n dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 1; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rd_haswell_asm_1x2 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_sgemmsup_rd_haswell_asm_1x1 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_sdotxv_ex + ( + conja, conjb, k0, + alpha, ai, cs_a0, bj, rs_b0, + beta, cij, cntx, NULL + ); + #endif + } + } +} + diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c new file mode 100644 index 0000000000..1637e97667 --- /dev/null +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -0,0 +1,3054 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rv_haswell_asm_6x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t n_left = n0 % 8; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 6x?m kernels, as needed. + if ( n_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_dgemmsup_rv_haswell_asm_6x6m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + dim_t ps_a0 = bli_auxinfo_ps_a( data ); + + if ( ps_a0 == 6 * rs_a0 ) + { + // Since A is not packed, we can use one gemv. + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + else + { + const dim_t mr = 6; + + // Since A is packed into row panels, we must use a loop over + // gemv. + dim_t m_iter = ( m0 + mr - 1 ) / mr; + dim_t m_left = m0 % mr; + + double* restrict ai_ii = ai; + double* restrict cij_ii = cij; + + for ( dim_t ii = 0; ii < m_iter; ii += 1 ) + { + dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) + ? mr : m_left ); + + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, + beta, cij_ii, rs_c0, cntx, NULL + ); + cij_ii += mr*rs_c0; ai_ii += ps_a0; + } + } +#endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a8), rax) // load ps_a8 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a8 + + dec(r11) // ii -= 1; + jne(.DLOOP6X8I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + //double* restrict ai = a + i_edge*rs_a; + //double* restrict ai = a + ( i_edge / 6 ) * ps_a; + double* restrict ai = a + m_iter * ps_a; + double* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m_left ) + { + dgemmsup_ker_ft ker_fp1 = NULL; + dgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_3x8; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_4x8; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_5x8; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x8, + bli_dgemmsup_rv_haswell_asm_2x8, + bli_dgemmsup_rv_haswell_asm_3x8, + bli_dgemmsup_rv_haswell_asm_4x8, + bli_dgemmsup_rv_haswell_asm_5x8 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + +void bli_dgemmsup_rv_haswell_asm_6x6m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm1, ymm1, ymm1) // zero ymm1 since we only use the lower + vxorpd(ymm4, ymm4, ymm4) // half (xmm1), and nans/infs may slow us + vxorpd(ymm5, ymm5, ymm5) // down. + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(xmm0, xmm13, xmm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(xmm0, xmm15, xmm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm13) + vmovupd(xmm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm15) + vmovupd(xmm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(xmm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(xmm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a8), rax) // load ps_a8 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a8 + + dec(r11) // ii -= 1; + jne(.DLOOP6X8I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 6; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + //double* restrict ai = a + i_edge*rs_a; + //double* restrict ai = a + ( i_edge / 6 ) * ps_a; + double* restrict ai = a + m_iter * ps_a; + double* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m_left ) + { + dgemmsup_ker_ft ker_fp1 = NULL; + dgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x6; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_3x6; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x6; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_4x6; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x6; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_5x6; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x6, + bli_dgemmsup_rv_haswell_asm_2x6, + bli_dgemmsup_rv_haswell_asm_3x6, + bli_dgemmsup_rv_haswell_asm_4x6, + bli_dgemmsup_rv_haswell_asm_5x6 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + +void bli_dgemmsup_rv_haswell_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm14, ymm14, ymm14) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a8), rax) // load ps_a8 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a8 + + dec(r11) // ii -= 1; + jne(.DLOOP6X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + //double* restrict ai = a + i_edge*rs_a; + //double* restrict ai = a + ( i_edge / 6 ) * ps_a; + double* restrict ai = a + m_iter * ps_a; + double* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m_left ) + { + dgemmsup_ker_ft ker_fp1 = NULL; + dgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x4; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_3x4; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x4; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_4x4; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x4; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_5x4; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x4, + bli_dgemmsup_rv_haswell_asm_2x4, + bli_dgemmsup_rv_haswell_asm_3x4, + bli_dgemmsup_rv_haswell_asm_4x4, + bli_dgemmsup_rv_haswell_asm_5x4 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + +void bli_dgemmsup_rv_haswell_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X2I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(xmm4, xmm4, xmm4) + vxorpd(xmm6, xmm6, xmm6) + vxorpd(xmm8, xmm8, xmm8) + vxorpd(xmm10, xmm10, xmm10) + vxorpd(xmm12, xmm12, xmm12) + vxorpd(xmm14, xmm14, xmm14) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 1*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) + vmovupd(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm14) + vmovupd(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a8), rax) // load ps_a8 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a8 + + dec(r11) // ii -= 1; + jne(.DLOOP6X2I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + //double* restrict ai = a + i_edge*rs_a; + //double* restrict ai = a + ( i_edge / 6 ) * ps_a; + double* restrict ai = a + m_iter * ps_a; + double* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m_left ) + { + dgemmsup_ker_ft ker_fp1 = NULL; + dgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x2; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_3x2; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x2; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_4x2; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x2; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_5x2; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x2, + bli_dgemmsup_rv_haswell_asm_2x2, + bli_dgemmsup_rv_haswell_asm_3x2, + bli_dgemmsup_rv_haswell_asm_4x2, + bli_dgemmsup_rv_haswell_asm_5x2 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c new file mode 100644 index 0000000000..5ecef06e8b --- /dev/null +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c @@ -0,0 +1,4114 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rv_haswell_asm_6x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t m_left = m0 % 6; + + // First check whether this is a edge case in the m dimension. If so, + // dispatch other ?x8m kernels, as needed. + if ( m_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + +#if 1 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m0 ) + { + dgemmsup_ker_ft ker_fp1 = NULL; + dgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m0 == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8n; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_3x8n; + } + else if ( m0 == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8n; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_4x8n; + } + else // if ( m0 == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8n; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_5x8n; + } + + ker_fp1 + ( + conja, conjb, mr1, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x8n, + bli_dgemmsup_rv_haswell_asm_2x8n, + bli_dgemmsup_rv_haswell_asm_3x8n, + bli_dgemmsup_rv_haswell_asm_4x8n, + bli_dgemmsup_rv_haswell_asm_5x8n + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b8 = ps_b * sizeof( double ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.DLOOP6X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + +#if 1 + mov(var(ps_b8), rdx) // load ps_b8 + lea(mem(rbx, rdx, 1), rdx) // rdx = b + ps_b8 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) // b_prefetch += rs_b; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + + //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b8), rbx) // load ps_b8 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b8 + + dec(r11) // jj -= 1; + jne(.DLOOP6X8J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b8] "m" (ps_b8), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 6; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + //double* restrict bj = b + j_edge*cs_b; + //double* restrict bj = b + ( j_edge / 8 ) * ps_b; + double* restrict bj = b + n_iter * ps_b; + + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_dgemmsup_rv_haswell_asm_6x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_6x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_6x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_6x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + +void bli_dgemmsup_rv_haswell_asm_5x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b8 = ps_b * sizeof( double ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.DLOOP6X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + +#if 1 + mov(var(ps_b8), rdx) // load ps_b8 + lea(mem(rbx, rdx, 1), rdx) // rdx = b + ps_b8 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) // b_prefetch += rs_b; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm13, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovupd(ymm12, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovupd(ymm13, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + + //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b8), rbx) // load ps_b8 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b8 + + dec(r11) // jj -= 1; + jne(.DLOOP6X8J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b8] "m" (ps_b8), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 5; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + //double* restrict bj = b + j_edge*cs_b; + //double* restrict bj = b + ( j_edge / 8 ) * ps_b; + double* restrict bj = b + n_iter * ps_b; + + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_dgemmsup_rv_haswell_asm_5x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_5x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_5x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_5x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + +void bli_dgemmsup_rv_haswell_asm_4x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b8 = ps_b * sizeof( double ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.DLOOP4X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + +#if 1 + mov(var(ps_b8), rdx) // load ps_b8 + lea(mem(rbx, rdx, 1), rdx) // rdx = b + ps_b8 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) // b_prefetch += rs_b; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + + //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b8), rbx) // load ps_b8 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b8 + + dec(r11) // jj -= 1; + jne(.DLOOP4X8J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b8] "m" (ps_b8), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 4; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + //double* restrict bj = b + j_edge*cs_b; + //double* restrict bj = b + ( j_edge / 8 ) * ps_b; + double* restrict bj = b + n_iter * ps_b; + + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_dgemmsup_rv_haswell_asm_4x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_4x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_4x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_4x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rv_haswell_asm_3x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b8 = ps_b * sizeof( double ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.DLOOP4X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(r12, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + +#if 1 + mov(var(ps_b8), rdx) // load ps_b8 + lea(mem(rbx, rdx, 1), rdx) // rdx = b + ps_b8 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) // b_prefetch += rs_b; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx ), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + vextractf128(imm(0x1), ymm9, xmm14) + vextractf128(imm(0x1), ymm11, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm9) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm11) + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx ), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + vextractf128(imm(0x1), ymm9, xmm14) + vextractf128(imm(0x1), ymm11, xmm15) + + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + + //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b8), rbx) // load ps_b8 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b8 + + dec(r11) // jj -= 1; + jne(.DLOOP4X8J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b8] "m" (ps_b8), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + //double* restrict bj = b + j_edge*cs_b; + //double* restrict bj = b + ( j_edge / 8 ) * ps_b; + double* restrict bj = b + n_iter * ps_b; + + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_dgemmsup_rv_haswell_asm_3x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_3x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_3x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rv_haswell_asm_2x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b8 = ps_b * sizeof( double ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.DLOOP2X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(r12, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + +#if 1 + mov(var(ps_b8), rdx) // load ps_b8 + lea(mem(rbx, rdx, 1), rdx) // rdx = b + ps_b8 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) // b_prefetch += rs_b; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + + //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b8), rbx) // load ps_b8 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b8 + + dec(r11) // jj -= 1; + jne(.DLOOP2X8J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b8] "m" (ps_b8), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 2; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + //double* restrict bj = b + j_edge*cs_b; + //double* restrict bj = b + ( j_edge / 8 ) * ps_b; + double* restrict bj = b + n_iter * ps_b; + + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_dgemmsup_rv_haswell_asm_2x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_2x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_2x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rv_haswell_asm_1x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b8 = ps_b * sizeof( double ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.DLOOP1X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(r12, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + +#if 1 + mov(var(ps_b8), rdx) // load ps_b8 + lea(mem(rbx, rdx, 1), rdx) // rdx = b + ps_b8 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) // b_prefetch += rs_b; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm5, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vmovupd(ymm5, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + + //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b8), rbx) // load ps_b8 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b8 + + dec(r11) // jj -= 1; + jne(.DLOOP1X8J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b8] "m" (ps_b8), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 1; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + //double* restrict bj = b + j_edge*cs_b; + //double* restrict bj = b + ( j_edge / 8 ) * ps_b; + double* restrict bj = b + n_iter * ps_b; + + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_dgemmsup_rv_haswell_asm_1x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_1x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 1 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_1x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_ddotxv_ex + ( + conja, conjb, k0, + alpha, ai, cs_a0, bj, rs_b0, + beta, cij, cntx, NULL + ); +#endif + } + } +} + diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c new file mode 100644 index 0000000000..426e5157e1 --- /dev/null +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c @@ -0,0 +1,4748 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rv_haswell_asm_6x16m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t n_left = n0 % 16; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 6x?m kernels, as needed. + if ( n_left ) + { + float* restrict cij = c; + float* restrict bj = b; + float* restrict ai = a; + + if ( 12 <= n_left ) + { + const dim_t nr_cur = 12; + + bli_sgemmsup_rv_haswell_asm_6x12m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_haswell_asm_6x8m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_sgemmsup_rv_haswell_asm_6x6m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_haswell_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_haswell_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + bli_sgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + dim_t ps_a0 = bli_auxinfo_ps_a( data ); + + if ( ps_a0 == 6 * rs_a0 ) + { + // Since A is not packed, we can use one gemv. + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + else + { + const dim_t mr = 6; + + // Since A is packed into row panels, we must use a loop over + // gemv. + dim_t m_iter = ( m0 + mr - 1 ) / mr; + dim_t m_left = m0 % mr; + + float* restrict ai_ii = ai; + float* restrict cij_ii = cij; + + for ( dim_t ii = 0; ii < m_iter; ii += 1 ) + { + dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) + ? mr : m_left ); + + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, + beta, cij_ii, rs_c0, cntx, NULL + ); + cij_ii += mr*rs_c0; ai_ii += ps_a0; + } + } + } + return; +#endif + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 15*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,15*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2,15*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 15*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1,15*4)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2,15*4)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 12*cs_c + lea(mem(r12, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 15*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm13, ymm13) + vmulps(ymm0, ymm14, ymm14) + vmulps(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm5) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm7) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm9) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm11) + vmovups(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm12) + vmovups(ymm12, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm13) + vmovups(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm14) + vmovups(ymm14, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm15) + vmovups(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(rdx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rdx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(rdx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(rdx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rdx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(rdx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + vmovups(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx, 0*32)) + vmovups(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm14, mem(rcx, 0*32)) + vmovups(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X8I) // iterate again if ii != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 16; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + //float* restrict ai = a + i_edge*rs_a; + //float* restrict ai = a + ( i_edge / 6 ) * ps_a; + float* restrict ai = a + m_iter * ps_a; + float* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m_left ) + { + sgemmsup_ker_ft ker_fp1 = NULL; + sgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x16; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_3x16; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x16; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_4x16; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x16; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_5x16; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_haswell_asm_1x16, + bli_sgemmsup_rv_haswell_asm_2x16, + bli_sgemmsup_rv_haswell_asm_3x16, + bli_sgemmsup_rv_haswell_asm_4x16, + bli_sgemmsup_rv_haswell_asm_5x16 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + +void bli_sgemmsup_rv_haswell_asm_6x12m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 11*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,11*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2,11*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 11*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1,11*4)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2,11*4)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 11*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(xmm0, xmm7, xmm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(xmm0, xmm9, xmm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(xmm0, xmm11, xmm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(xmm0, xmm13, xmm13) + vmulps(ymm0, ymm14, ymm14) + vmulps(xmm0, xmm15, xmm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm5) + vmovups(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm7) + vmovups(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm9) + vmovups(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm11) + vmovups(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm12) + vmovups(ymm12, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm13) + vmovups(xmm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm14) + vmovups(ymm14, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm15) + vmovups(xmm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(rdx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rdx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(rdx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-11 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm15, ymm13, ymm0) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + + vunpckhps(ymm15, ymm13, ymm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + vmovups(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx, 0*32)) + vmovups(xmm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm14, mem(rcx, 0*32)) + vmovups(xmm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-11 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm15, ymm13, ymm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + + vunpckhps(ymm15, ymm13, ymm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X8I) // iterate again if ii != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 12; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + //float* restrict ai = a + i_edge*rs_a; + //float* restrict ai = a + ( i_edge / 6 ) * ps_a; + float* restrict ai = a + m_iter * ps_a; + float* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m_left ) + { + sgemmsup_ker_ft ker_fp1 = NULL; + sgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x16; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_3x16; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x16; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_4x16; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x16; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_5x16; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_haswell_asm_1x12, + bli_sgemmsup_rv_haswell_asm_2x12, + bli_sgemmsup_rv_haswell_asm_3x12, + bli_sgemmsup_rv_haswell_asm_4x12, + bli_sgemmsup_rv_haswell_asm_5x12 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + +void bli_sgemmsup_rv_haswell_asm_6x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm14, ymm14, ymm14) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*4)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*4)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm14, ymm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm12) + vmovups(ymm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm14) + vmovups(ymm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(rdx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rdx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(rdx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X8I) // iterate again if ii != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + //float* restrict ai = a + i_edge*rs_a; + //float* restrict ai = a + ( i_edge / 6 ) * ps_a; + float* restrict ai = a + m_iter * ps_a; + float* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m_left ) + { + sgemmsup_ker_ft ker_fp1 = NULL; + sgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x8; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_3x8; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x8; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_4x8; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x8; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_5x8; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_haswell_asm_1x8, + bli_sgemmsup_rv_haswell_asm_2x8, + bli_sgemmsup_rv_haswell_asm_3x8, + bli_sgemmsup_rv_haswell_asm_4x8, + bli_sgemmsup_rv_haswell_asm_5x8 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + +void bli_sgemmsup_rv_haswell_asm_6x6m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm14, ymm14, ymm14) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 5*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 5*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*4)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*4)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm14, ymm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vextractf128(imm(0x1), ymm4, xmm5) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm5) + vmovsd(xmm5, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm6, xmm7) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm6) + vmovups(xmm6, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm7) + vmovsd(xmm7, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm8, xmm9) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm8) + vmovups(xmm8, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm9) + vmovsd(xmm9, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm10, xmm11) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm10) + vmovups(xmm10, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm11) + vmovsd(xmm11, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm12, xmm13) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm12) + vmovups(xmm12, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm13) + vmovsd(xmm13, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm14, xmm15) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm14) + vmovups(xmm14, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm15) + vmovsd(xmm15, mem(rcx, 4*4)) + + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-5 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(rdx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rdx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vextractf128(imm(0x1), ymm4, xmm5) + vmovups(xmm4, mem(rcx, 0*4)) + vmovsd(xmm5, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm6, xmm7) + vmovups(xmm6, mem(rcx, 0*4)) + vmovsd(xmm7, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm8, xmm9) + vmovups(xmm8, mem(rcx, 0*4)) + vmovsd(xmm9, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm10, xmm11) + vmovups(xmm10, mem(rcx, 0*4)) + vmovsd(xmm11, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm12, xmm13) + vmovups(xmm12, mem(rcx, 0*4)) + vmovsd(xmm13, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm14, xmm15) + vmovups(xmm14, mem(rcx, 0*4)) + vmovsd(xmm15, mem(rcx, 4*4)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-5 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X8I) // iterate again if ii != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 6; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + //float* restrict ai = a + i_edge*rs_a; + //float* restrict ai = a + ( i_edge / 6 ) * ps_a; + float* restrict ai = a + m_iter * ps_a; + float* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m_left ) + { + sgemmsup_ker_ft ker_fp1 = NULL; + sgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x6; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_3x6; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x6; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_4x6; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x6; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_5x6; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_haswell_asm_1x6, + bli_sgemmsup_rv_haswell_asm_2x6, + bli_sgemmsup_rv_haswell_asm_3x6, + bli_sgemmsup_rv_haswell_asm_4x6, + bli_sgemmsup_rv_haswell_asm_5x6 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + +void bli_sgemmsup_rv_haswell_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm14, ymm14, ymm14) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 3*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 3*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 3*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*4)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 3*4)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + //lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovups(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovups(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + vmovups(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + vmovups(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm14) + vmovups(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + + vunpckhps(xmm6, xmm4, xmm0) + vunpckhps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(xmm14, xmm12, xmm0) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + + vunpckhps(xmm14, xmm12, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-3 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm14, ymm12, ymm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + + vunpckhps(ymm14, ymm12, ymm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X8I) // iterate again if ii != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + //float* restrict ai = a + i_edge*rs_a; + //float* restrict ai = a + ( i_edge / 6 ) * ps_a; + float* restrict ai = a + m_iter * ps_a; + float* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m_left ) + { + sgemmsup_ker_ft ker_fp1 = NULL; + sgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x4; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_3x4; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x4; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_4x4; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x4; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_5x4; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_haswell_asm_1x4, + bli_sgemmsup_rv_haswell_asm_2x4, + bli_sgemmsup_rv_haswell_asm_3x4, + bli_sgemmsup_rv_haswell_asm_4x4, + bli_sgemmsup_rv_haswell_asm_5x4 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + +void bli_sgemmsup_rv_haswell_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm14, ymm14, ymm14) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 1*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*4)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 1*4)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + //lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + //lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovsd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovsd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovsd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + vmovsd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + vmovsd(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm14) + vmovsd(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + + // begin I/O on columns 0-1 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(xmm14, xmm12, xmm0) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovsd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-3 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm14, ymm12, ymm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X8I) // iterate again if ii != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + //float* restrict ai = a + i_edge*rs_a; + //float* restrict ai = a + ( i_edge / 6 ) * ps_a; + float* restrict ai = a + m_iter * ps_a; + float* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m_left ) + { + sgemmsup_ker_ft ker_fp1 = NULL; + sgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x16; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_3x16; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x16; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_4x16; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x16; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_5x16; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_haswell_asm_1x2, + bli_sgemmsup_rv_haswell_asm_2x2, + bli_sgemmsup_rv_haswell_asm_3x2, + bli_sgemmsup_rv_haswell_asm_4x2, + bli_sgemmsup_rv_haswell_asm_5x2 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16n.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16n.c new file mode 100644 index 0000000000..7463707cc9 --- /dev/null +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16n.c @@ -0,0 +1,4882 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rv_haswell_asm_6x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t m_left = m0 % 6; + + // First check whether this is a edge case in the m dimension. If so, + // dispatch other ?x8m kernels, as needed. + if ( m_left ) + { + float* restrict cij = c; + float* restrict bj = b; + float* restrict ai = a; + +#if 1 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m0 ) + { + sgemmsup_ker_ft ker_fp1 = NULL; + sgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m0 == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x16n; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_3x16n; + } + else if ( m0 == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x16n; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_4x16n; + } + else // if ( m0 == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_sgemmsup_rv_haswell_asm_4x16n; + ker_fp2 = bli_sgemmsup_rv_haswell_asm_5x16n; + } + + ker_fp1 + ( + conja, conjb, mr1, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_haswell_asm_1x16n, + bli_sgemmsup_rv_haswell_asm_2x16n, + bli_sgemmsup_rv_haswell_asm_3x16n, + bli_sgemmsup_rv_haswell_asm_4x16n, + bli_sgemmsup_rv_haswell_asm_5x16n + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 16; + uint64_t n_left = n0 % 16; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.SLOOP6X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 15*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,15*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2,15*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 15*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1,15*4)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2,15*4)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 12*cs_c + lea(mem(r12, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 15*cs_c + + label(.SPOSTPFETCH) // done prefetching c + +#if 1 + mov(var(ps_b4), rdx) // load ps_b4 + lea(mem(rbx, rdx, 1), rdx) // rdx = a + ps_b4 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm13, ymm13) + vmulps(ymm0, ymm14, ymm14) + vmulps(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm5) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm7) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm9) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm11) + vmovups(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm12) + vmovups(ymm12, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm13) + vmovups(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm14) + vmovups(ymm14, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm15) + vmovups(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(rdx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rdx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(rdx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(rdx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rdx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(rdx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + vmovups(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx, 0*32)) + vmovups(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm14, mem(rcx, 0*32)) + vmovups(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + + lea(mem(r12, rsi, 8), r12) // + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 16*cs_c + + //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b4), rbx) // load ps_b4 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b4 + + dec(r11) // jj -= 1; + jne(.SLOOP6X8J) // iterate again if jj != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 6; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + //float* restrict bj = b + j_edge*cs_b; + //float* restrict bj = b + ( j_edge / 8 ) * ps_b; + float* restrict bj = b + n_iter * ps_b; + + if ( 12 <= n_left ) + { + const dim_t nr_cur = 12; + + bli_sgemmsup_rv_haswell_asm_6x12 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_haswell_asm_6x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_sgemmsup_rv_haswell_asm_6x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_haswell_asm_6x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_haswell_asm_6x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_sgemmsup_r_haswell_ref_6x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + +void bli_sgemmsup_rv_haswell_asm_5x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 16; + uint64_t n_left = n0 % 16; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.SLOOP6X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 15*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,15*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2,15*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 15*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1,15*4)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 4*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 4*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 4*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 4*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 4*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 4*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 4*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 4*4)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 4*4)) // prefetch c + 12*cs_c + lea(mem(r12, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*4)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*4)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rcx, 1, 4*4)) // prefetch c + 15*cs_c + + label(.SPOSTPFETCH) // done prefetching c + +#if 1 + mov(var(ps_b4), rdx) // load ps_b4 + lea(mem(rbx, rdx, 1), rdx) // rdx = a + ps_b4 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm13, ymm13) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm5) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm7) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm9) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm11) + vmovups(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm12) + vmovups(ymm12, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm13) + vmovups(ymm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rdx, rax, 2), xmm1) + vmovss(mem(rdx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rdx, rax, 2), xmm1) + vmovss(mem(rdx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + vmovups(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx, 0*32)) + vmovups(ymm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + + lea(mem(r12, rsi, 8), r12) // + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 16*cs_c + + //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b4), rbx) // load ps_b4 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b4 + + dec(r11) // jj -= 1; + jne(.SLOOP6X8J) // iterate again if jj != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 5; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + //float* restrict bj = b + j_edge*cs_b; + //float* restrict bj = b + ( j_edge / 8 ) * ps_b; + float* restrict bj = b + n_iter * ps_b; + + if ( 12 <= n_left ) + { + const dim_t nr_cur = 12; + + bli_sgemmsup_rv_haswell_asm_5x12 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_haswell_asm_5x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_sgemmsup_rv_haswell_asm_5x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_haswell_asm_5x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_haswell_asm_5x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_sgemmsup_r_haswell_ref_5x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + +void bli_sgemmsup_rv_haswell_asm_4x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 16; + uint64_t n_left = n0 % 16; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.SLOOP6X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 15*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,15*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2,15*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 15*4)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 3*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 3*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 3*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 3*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 3*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 3*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 3*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 3*4)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 3*4)) // prefetch c + 12*cs_c + lea(mem(r12, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*4)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*4)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rcx, 1, 3*4)) // prefetch c + 15*cs_c + + label(.SPOSTPFETCH) // done prefetching c + +#if 1 + mov(var(ps_b4), rdx) // load ps_b4 + lea(mem(rbx, rdx, 1), rdx) // rdx = a + ps_b4 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm5) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm7) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm9) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm11) + vmovups(ymm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + vmovups(ymm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + + lea(mem(r12, rsi, 8), r12) // + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 16*cs_c + + //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b4), rbx) // load ps_b4 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b4 + + dec(r11) // jj -= 1; + jne(.SLOOP6X8J) // iterate again if jj != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 4; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + //float* restrict bj = b + j_edge*cs_b; + //float* restrict bj = b + ( j_edge / 8 ) * ps_b; + float* restrict bj = b + n_iter * ps_b; + + if ( 12 <= n_left ) + { + const dim_t nr_cur = 12; + + bli_sgemmsup_rv_haswell_asm_4x12 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_haswell_asm_4x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_sgemmsup_rv_haswell_asm_4x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_haswell_asm_4x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_haswell_asm_4x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_sgemmsup_r_haswell_ref_4x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + +void bli_sgemmsup_rv_haswell_asm_3x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 16; + uint64_t n_left = n0 % 16; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.SLOOP6X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(r12, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 15*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,15*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2,15*4)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 2*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 2*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 2*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 2*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 2*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 2*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 2*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 2*4)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 2*4)) // prefetch c + 12*cs_c + lea(mem(r12, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*4)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*4)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rcx, 1, 2*4)) // prefetch c + 15*cs_c + + label(.SPOSTPFETCH) // done prefetching c + +#if 1 + mov(var(ps_b4), rdx) // load ps_b4 + lea(mem(rbx, rdx, 1), rdx) // rdx = a + ps_b4 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm5) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm7) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm9) + vmovups(ymm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(mem(rcx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm8, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rdx, rax, 2), xmm1) + vmovss(mem(rdx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(mem(rcx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm9, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rdx, rax, 2), xmm1) + vmovss(mem(rdx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(ymm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm8, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm9, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.SDONE) + + + + + lea(mem(r12, rsi, 8), r12) // + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 16*cs_c + + //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b4), rbx) // load ps_b4 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b4 + + dec(r11) // jj -= 1; + jne(.SLOOP6X8J) // iterate again if jj != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + //float* restrict bj = b + j_edge*cs_b; + //float* restrict bj = b + ( j_edge / 8 ) * ps_b; + float* restrict bj = b + n_iter * ps_b; + + if ( 12 <= n_left ) + { + const dim_t nr_cur = 12; + + bli_sgemmsup_rv_haswell_asm_3x12 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_haswell_asm_3x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_sgemmsup_rv_haswell_asm_3x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_haswell_asm_3x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_haswell_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_sgemmsup_r_haswell_ref_3x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + +void bli_sgemmsup_rv_haswell_asm_2x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 16; + uint64_t n_left = n0 % 16; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.SLOOP6X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(r12, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 15*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,15*4)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 1*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 1*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 1*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 1*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 1*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 1*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 1*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 1*4)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 1*4)) // prefetch c + 12*cs_c + lea(mem(r12, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*4)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*4)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rcx, 1, 1*4)) // prefetch c + 15*cs_c + + label(.SPOSTPFETCH) // done prefetching c + +#if 1 + mov(var(ps_b4), rdx) // load ps_b4 + lea(mem(rbx, rdx, 1), rdx) // rdx = a + ps_b4 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm5) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm7) + vmovups(ymm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(mem(rcx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(mem(rcx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(ymm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + + lea(mem(r12, rsi, 8), r12) // + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 16*cs_c + + //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b4), rbx) // load ps_b4 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b4 + + dec(r11) // jj -= 1; + jne(.SLOOP6X8J) // iterate again if jj != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 2; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + //float* restrict bj = b + j_edge*cs_b; + //float* restrict bj = b + ( j_edge / 8 ) * ps_b; + float* restrict bj = b + n_iter * ps_b; + + if ( 12 <= n_left ) + { + const dim_t nr_cur = 12; + + bli_sgemmsup_rv_haswell_asm_2x12 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_haswell_asm_2x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_sgemmsup_rv_haswell_asm_2x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_haswell_asm_2x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_haswell_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_sgemmsup_r_haswell_ref_2x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + +void bli_sgemmsup_rv_haswell_asm_1x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 16; + uint64_t n_left = n0 % 16; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.SLOOP6X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(r12, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 15*4)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 0*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 0*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 0*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 0*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 0*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 0*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 0*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 0*4)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 0*4)) // prefetch c + 12*cs_c + lea(mem(r12, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*4)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*4)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rcx, 1, 0*4)) // prefetch c + 15*cs_c + + label(.SPOSTPFETCH) // done prefetching c + +#if 1 + mov(var(ps_b4), rdx) // load ps_b4 + lea(mem(rbx, rdx, 1), rdx) // rdx = a + ps_b4 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm5) + vmovups(ymm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vmovups(ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rcx ), xmm1) + vmovss(mem(rcx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rcx, rsi, 2), xmm1) + vmovss(mem(rcx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rcx, rsi, 4), xmm1) + vmovss(mem(rcx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rcx, rax, 2), xmm1) + vmovss(mem(rcx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-15 + vmovups(ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rcx ), xmm1) + vmovss(mem(rcx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rcx, rsi, 2), xmm1) + vmovss(mem(rcx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rcx, rsi, 4), xmm1) + vmovss(mem(rcx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rcx, rax, 2), xmm1) + vmovss(mem(rcx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(ymm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-7 + vmovups(ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-15 + vmovups(ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + + lea(mem(r12, rsi, 8), r12) // + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 16*cs_c + + //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b4), rbx) // load ps_b4 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b4 + + dec(r11) // jj -= 1; + jne(.SLOOP6X8J) // iterate again if jj != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 1; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + //float* restrict bj = b + j_edge*cs_b; + //float* restrict bj = b + ( j_edge / 8 ) * ps_b; + float* restrict bj = b + n_iter * ps_b; + + if ( 12 <= n_left ) + { + const dim_t nr_cur = 12; + + bli_sgemmsup_rv_haswell_asm_1x12 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_haswell_asm_1x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_sgemmsup_rv_haswell_asm_1x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_haswell_asm_1x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_haswell_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_sgemmsup_r_haswell_ref_1x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + } +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c new file mode 100644 index 0000000000..69d543a99d --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c @@ -0,0 +1,158 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +// NOTE: Normally, for any "?x1" kernel, we would call the reference kernel. +// However, at least one other subconfiguration (zen) uses this kernel set, so +// we need to be able to call a set of "?x1" kernels that we know will actually +// exist regardless of which subconfiguration these kernels were used by. Thus, +// the compromise employed here is to inline the reference kernel so it gets +// compiled as part of the haswell kernel set, and hence can unconditionally be +// called by other kernels within that kernel set. + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, mdim ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + for ( dim_t i = 0; i < mdim; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + /* for ( dim_t j = 0; j < 1; ++j ) */ \ + { \ + ctype* restrict cij = ci /*[ j*cs_c ]*/ ; \ + ctype* restrict bj = b /*[ j*cs_b ]*/ ; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(d,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ +} + +GENTFUNC( double, d, gemmsup_r_haswell_ref_6x1, 6 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_5x1, 5 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_4x1, 4 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_3x1, 3 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_2x1, 2 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_1x1, 1 ) + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c new file mode 100644 index 0000000000..457ef9f22d --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c @@ -0,0 +1,1698 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rd_haswell_asm_6x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovsd(mem(rax, r8, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm8) + + vmovsd(mem(rax, r13, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rax, r8, 4), xmm3) + vfmadd231pd(ymm0, ymm3, ymm12) + + vmovsd(mem(rax, r15, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 + // ymm6 + // ymm8 + // ymm10 + // ymm12 + // ymm14 + + vhaddpd( ymm4, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + vhaddpd( ymm6, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm8, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm10, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm12, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm14, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4[0] = sum(ymm4) + // xmm6[0] = sum(ymm6) + // xmm8[0] = sum(ymm8) + // xmm10[0] = sum(ymm10) + // xmm12[0] = sum(ymm12) + // xmm14[0] = sum(ymm14) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm8) + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm10) + vmovsd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm12) + vmovsd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm14) + vmovsd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_3x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm8, ymm8, ymm8) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovsd(mem(rax, r8, 2), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm8) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 + // ymm6 + // ymm8 + + vhaddpd( ymm4, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + vhaddpd( ymm6, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm8, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + // xmm4[0] = sum(ymm4) + // xmm6[0] = sum(ymm6) + // xmm8[0] = sum(ymm8) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm8) + vmovsd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm6, ymm6, ymm6) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 + // ymm6 + + vhaddpd( ymm4, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + vhaddpd( ymm6, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4[0] = sum(ymm4) + // xmm6[0] = sum(ymm6) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 + + vhaddpd( ymm4, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + // xmm4[0] = sum(ymm4) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovsd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c new file mode 100644 index 0000000000..af498eb0ee --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c @@ -0,0 +1,1794 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rd_haswell_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 6; + //uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovsd(mem(rax, r13, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rax, r8, 4), xmm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovsd(mem(rax, r15, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + // xmm6[0:1] = sum(ymm6) sum(ymm7) + // xmm8[0:1] = sum(ymm8) sum(ymm9) + // xmm10[0:1] = sum(ymm10) sum(ymm11) + // xmm12[0:1] = sum(ymm12) sum(ymm13) + // xmm14[0:1] = sum(ymm14) sum(ymm15) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 6; + //uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + // xmm6[0:1] = sum(ymm6) sum(ymm7) + // xmm8[0:1] = sum(ymm8) sum(ymm9) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 6; + //uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + // xmm6[0:1] = sum(ymm6) sum(ymm7) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 6; + //uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c new file mode 100644 index 0000000000..516bfced54 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c @@ -0,0 +1,1438 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rd_haswell_asm_6x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_ii; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // xmm4[0:3] = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // xmm5[0:3] = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // xmm6[0:3] = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + //lea(mem(r12), rcx) // rcx = c; + //lea(mem(r14), rax) // rax = a; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // xmm4[0:3] = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // xmm5[0:3] = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + + //lea(mem(r12), rcx) // rcx = c; + //lea(mem(r14), rax) // rax = a; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + //prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // xmm4[0:3] = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c new file mode 100644 index 0000000000..571444bed3 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c @@ -0,0 +1,1617 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rd_haswell_asm_6x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t n_left = n0 % 8; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 6x?m kernels, as needed. + if ( n_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rd_haswell_asm_6x4 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_6x1 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 8 + jl(.DLOOP3X4J) // if jj < 8, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 8 + jl(.DLOOP3X4J) // if jj < 8, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 8 + jl(.DLOOP3X4J) // if jj < 8, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c new file mode 100644 index 0000000000..eb1118196b --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c @@ -0,0 +1,2496 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rv_haswell_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 1*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) + vmovupd(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm14) + vmovupd(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // r13 = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) + vmovupd(xmm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + + vfmadd213pd(xmm12, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovupd(xmm12, xmm0) + + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) + vmovupd(xmm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) + vmovupd(xmm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-1 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-1 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + + vfmadd213pd(xmm4, xmm3, xmm0) + + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-1 + vmovlpd(xmm4, mem(rcx )) + vmovhpd(xmm4, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c new file mode 100644 index 0000000000..bdcf833e3d --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c @@ -0,0 +1,2600 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rv_haswell_asm_6x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 5*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 4*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovupd(ymm12, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 3*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 2*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx ), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 1*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 0*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + // begin I/O on columns 0-3 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c new file mode 100644 index 0000000000..9da1e7b838 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c @@ -0,0 +1,3090 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rv_haswell_asm_6x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(xmm0, xmm13, xmm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(xmm0, xmm15, xmm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm13) + vmovupd(xmm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm15) + vmovupd(xmm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(xmm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(xmm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(xmm0, xmm13, xmm13) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm13) + vmovupd(xmm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + + vfmadd213pd(xmm13, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(xmm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovupd(ymm12, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovupd(ymm13, ymm0) + + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(xmm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx ), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + //vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm9) + //vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm11) + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx ), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(xmm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-5 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + + vfmadd213pd(xmm5, xmm3, xmm0) + + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-5 + vmovupd(xmm5, xmm0) + + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c new file mode 100644 index 0000000000..a6c8f0e43d --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c @@ -0,0 +1,3260 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + +// Define parameters and variables for edge case kernel map. +#define NUM_MR 4 +#define NUM_NR 4 +#define FUNCPTR_T dgemmsup_ker_ft + +static dim_t mrs[NUM_MR] = { 6, 4, 2, 1 }; +static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; +static FUNCPTR_T kmap[NUM_MR][NUM_NR] = +{ /* 8 4 2 1 */ +/* 6 */ { bli_dgemmsup_rv_haswell_asm_6x8, bli_dgemmsup_rv_haswell_asm_6x4, bli_dgemmsup_rv_haswell_asm_6x2, bli_dgemmsup_r_haswell_ref_6x1 }, +/* 4 */ { bli_dgemmsup_rv_haswell_asm_4x8, bli_dgemmsup_rv_haswell_asm_4x4, bli_dgemmsup_rv_haswell_asm_4x2, bli_dgemmsup_r_haswell_ref_4x1 }, +/* 2 */ { bli_dgemmsup_rv_haswell_asm_2x8, bli_dgemmsup_rv_haswell_asm_2x4, bli_dgemmsup_rv_haswell_asm_2x2, bli_dgemmsup_r_haswell_ref_2x1 }, +/* 1 */ { bli_dgemmsup_rv_haswell_asm_1x8, bli_dgemmsup_rv_haswell_asm_1x4, bli_dgemmsup_rv_haswell_asm_1x2, bli_dgemmsup_r_haswell_ref_1x1 }, +}; + + +void bli_dgemmsup_rv_haswell_asm_6x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + // Use a reference kernel if this is an edge case in the m or n + // dimensions. + if ( m0 < 6 || n0 < 8 ) + { +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + dim_t n_left = n0; + double* restrict cj = c; + double* restrict bj = b; + + // Iterate across columns (corresponding to elements of nrs) until + // n_left is zero. + for ( dim_t j = 0; n_left != 0; ++j ) + { + const dim_t nr_cur = nrs[ j ]; + + // Once we find the value of nrs that is less than (or equal to) + // n_left, we use the kernels in that column. + if ( nr_cur <= n_left ) + { + dim_t m_left = m0; + double* restrict cij = cj; + double* restrict ai = a; + + // Iterate down the current column (corresponding to elements + // of mrs) until m_left is zero. + for ( dim_t i = 0; m_left != 0; ++i ) + { + const dim_t mr_cur = mrs[ i ]; + + // Once we find the value of mrs that is less than (or equal + // to) m_left, we select that kernel. + if ( mr_cur <= m_left ) + { + FUNCPTR_T ker_fp = kmap[i][j]; + + //printf( "executing %d x %d sup kernel.\n", (int)mr_cur, (int)nr_cur ); + + // Call the kernel using current mrs and nrs values. + ker_fp + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + // Advance C and A pointers by the mrs and nrs we just + // used, and decrement m_left. + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + } + + // Advance C and B pointers by the mrs and nrs we just used, and + // decrement n_left. + cj += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + } + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm13, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovupd(ymm12, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovupd(ymm13, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx ), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + vextractf128(imm(0x1), ymm9, xmm14) + vextractf128(imm(0x1), ymm11, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm9) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm11) + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx ), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + vextractf128(imm(0x1), ymm9, xmm14) + vextractf128(imm(0x1), ymm11, xmm15) + + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi","rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi","rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm5, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vmovupd(ymm5, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi","rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rd_haswell_asm_d6x8.c b/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rd_haswell_asm_d6x8.c new file mode 100644 index 0000000000..87ef7309b3 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rd_haswell_asm_d6x8.c @@ -0,0 +1,4566 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + +#if 0 +// Define parameters and variables for edge case kernel map. +#define NUM_MR 4 +#define NUM_NR 4 +#define FUNCPTR_T dgemmsup_ker_ft + +static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 }; +static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; +static FUNCPTR_T kmap[NUM_MR][NUM_NR] = +{ /* 8 4 2 1 */ +/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8m, bli_dgemmsup_rd_haswell_asm_6x4m, bli_dgemmsup_rd_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref_6x1 }, +/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8m, bli_dgemmsup_rd_haswell_asm_3x4m, bli_dgemmsup_rd_haswell_asm_3x2m, bli_dgemmsup_r_haswell_ref_3x1 }, +/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8m, bli_dgemmsup_rd_haswell_asm_2x4m, bli_dgemmsup_rd_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref_2x1 }, +/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8m, bli_dgemmsup_rd_haswell_asm_1x4m, bli_dgemmsup_rd_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref_1x1 } +}; +#endif + + +void bli_dgemmsup_rd_haswell_asm_6x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t n_left = n0 % 8; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 6x?m kernels, as needed. + if ( n_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rd_haswell_asm_6x4 + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2 + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; + lea(mem(r14), rax) // rax = a + 3*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | | | + -------- -- -- -- ... | | | | + -------- += -- -- -- | | | | + -------- | | | | + -------- : + -------- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 6*ii*rs_c; + lea(mem(r14), rax) // rax = a + 6*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovsd(mem(rax, r13, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rax, r8, 4), xmm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovsd(mem(rax, r15, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + // xmm10 = sum(ymm10) sum(ymm11) + // xmm12 = sum(ymm12) sum(ymm13) + // xmm14 = sum(ymm14) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rd_haswell_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + // xmm4 = sum(ymm4) sum(ymm5) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rv_haswell_asm_d6x8.c b/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rv_haswell_asm_d6x8.c new file mode 100644 index 0000000000..fe61fbc313 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rv_haswell_asm_d6x8.c @@ -0,0 +1,11048 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + +// Define parameters and variables for edge case kernel map. +#define NUM_MR 4 +#define NUM_NR 4 +#define FUNCPTR_T dgemmsup_ker_ft + +static dim_t mrs[NUM_MR] = { 6, 4, 2, 1 }; +static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; +static FUNCPTR_T kmap[NUM_MR][NUM_NR] = +{ /* 8 4 2 1 */ +/* 6 */ { bli_dgemmsup_rv_haswell_asm_6x8, bli_dgemmsup_rv_haswell_asm_6x4, bli_dgemmsup_rv_haswell_asm_6x2, bli_dgemmsup_r_haswell_ref_6x1 }, +/* 4 */ { bli_dgemmsup_rv_haswell_asm_4x8, bli_dgemmsup_rv_haswell_asm_4x4, bli_dgemmsup_rv_haswell_asm_4x2, bli_dgemmsup_r_haswell_ref_4x1 }, +/* 2 */ { bli_dgemmsup_rv_haswell_asm_2x8, bli_dgemmsup_rv_haswell_asm_2x4, bli_dgemmsup_rv_haswell_asm_2x2, bli_dgemmsup_r_haswell_ref_2x1 }, +/* 1 */ { bli_dgemmsup_rv_haswell_asm_1x8, bli_dgemmsup_rv_haswell_asm_1x4, bli_dgemmsup_rv_haswell_asm_1x2, bli_dgemmsup_r_haswell_ref_1x1 }, +}; + + +void bli_dgemmsup_rv_haswell_asm_6x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + // Use a reference kernel if this is an edge case in the m or n + // dimensions. + if ( m0 < 6 || n0 < 8 ) + { +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + dim_t n_left = n0; + double* restrict cj = c; + double* restrict bj = b; + + // Iterate across columns (corresponding to elements of nrs) until + // n_left is zero. + for ( dim_t j = 0; n_left != 0; ++j ) + { + const dim_t nr_cur = nrs[ j ]; + + // Once we find the value of nrs that is less than (or equal to) + // n_left, we use the kernels in that column. + if ( nr_cur <= n_left ) + { + dim_t m_left = m0; + double* restrict cij = cj; + double* restrict ai = a; + + // Iterate down the current column (corresponding to elements + // of mrs) until m_left is zero. + for ( dim_t i = 0; m_left != 0; ++i ) + { + const dim_t mr_cur = mrs[ i ]; + + // Once we find the value of mrs that is less than (or equal + // to) m_left, we select that kernel. + if ( mr_cur <= m_left ) + { + FUNCPTR_T ker_fp = kmap[i][j]; + + //printf( "executing %d x %d sup kernel.\n", (int)mr_cur, (int)nr_cur ); + + // Call the kernel using current mrs and nrs values. + ker_fp + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + // Advance C and A pointers by the mrs and nrs we just + // used, and decrement m_left. + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + } + + // Advance C and B pointers by the mrs and nrs we just used, and + // decrement n_left. + cj += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + } + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx)) + vmovupd(ymm15, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm13, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovupd(ymm12, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) +#endif + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovupd(ymm13, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + vextractf128(imm(0x1), ymm9, xmm14) + vextractf128(imm(0x1), ymm11, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm9) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm11) + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + vextractf128(imm(0x1), ymm9, xmm14) + vextractf128(imm(0x1), ymm11, xmm15) + + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm5, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vmovupd(ymm5, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_6x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(xmm0, xmm13, xmm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(xmm0, xmm15, xmm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm13) + vmovupd(xmm13, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm15) + vmovupd(xmm15, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + //vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + //vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + //vextractf128(imm(0x1), ymm0, xmm2) + //vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + //vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + //vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + //vmovupd(xmm2, mem(rdx, rsi, 2)) + //vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(xmm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(xmm13, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx)) + vmovupd(xmm15, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + //vextractf128(imm(0x1), ymm0, xmm2) + //vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + //vmovupd(xmm2, mem(rdx, rsi, 2)) + //vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(xmm0, xmm13, xmm13) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm13) + vmovupd(xmm13, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + //vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + //vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + //vextractf128(imm(0x1), ymm0, xmm2) + //vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + //vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + //vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + //vmovupd(xmm2, mem(rdx, rsi, 2)) + //vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + //vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + //vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + //vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(xmm13, xmm3, xmm0) + //vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + //vmovlpd(xmm1, mem(rdx, rsi, 2)) + //vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(xmm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(xmm13, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovupd(ymm12, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + //vextractf128(imm(0x1), ymm0, xmm2) + //vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + //vmovupd(xmm2, mem(rdx, rsi, 2)) + //vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovupd(ymm13, ymm0) + + //vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + //vmovlpd(xmm1, mem(rdx, rsi, 2)) + //vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + //vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + //vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(xmm11, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + //vextractf128(imm(0x1), ymm9, xmm14) + //vextractf128(imm(0x1), ymm11, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm9) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm11) + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + //vmovupd(xmm9, mem(rcx, rsi, 2)) + //vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + //vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + //vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + //vmovsd(xmm14, mem(rdx, rsi, 2)) + //vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + vmovupd(xmm9, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + //vextractf128(imm(0x1), ymm9, xmm14) + //vextractf128(imm(0x1), ymm11, xmm15) + + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + //vmovupd(xmm9, mem(rcx, rsi, 2)) + //vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + //vmovsd(xmm14, mem(rdx, rsi, 2)) + //vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + //vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + //vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + +#if 1 + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + //vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + //vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + //vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(xmm5, xmm3, xmm0) + + //vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + //vmovlpd(xmm1, mem(rcx, rsi, 2)) + //vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(xmm5, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vmovupd(xmm5, xmm0) + + //vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + //vmovlpd(xmm1, mem(rcx, rsi, 2)) + //vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_6x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm14, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + +#if 0 + lea(mem(rax, r9, 8), rdx) // use rdx for prefetching b. + lea(mem(rdx, r9, 8), rdx) // rdx = b + 16*rs_b; +#else + #if 1 + mov(r9, rsi) // rsi = rs_b; + sal(imm(5), rsi) // rsi = 16*rs_b; + lea(mem(rax, rsi, 1), rdx) // rdx = b + 16*rs_b; + #endif +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovupd(ymm12, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 1*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rcx, rsi, 2), rdx) // + //lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // r13 = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + + vfmadd213pd(xmm12, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + + vmovupd(xmm12, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) +#else + vmovupd(xmm12, xmm0) + + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rcx, rsi, 2), rdx) // + //lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + //vextractf128(imm(0x1), ymm8, xmm14) + //vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + //vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + //vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + //vmovupd(xmm8, mem(rcx, rsi, 2)) + //vmovupd(xmm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + //vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + //vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + //vmovsd(xmm14, mem(rdx, rsi, 2)) + //vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + //vextractf128(imm(0x1), ymm8, xmm14) + //vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + //vmovupd(xmm8, mem(rcx, rsi, 2)) + //vmovupd(xmm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + //vmovsd(xmm14, mem(rdx, rsi, 2)) + //vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rcx, rsi, 2), rdx) // + //lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + + vfmadd231pd(mem(rcx), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rcx, rsi, 2), rdx) // + //lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + + vfmadd213pd(xmm4, xmm3, xmm0) + + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vmovlpd(xmm4, mem(rcx)) + vmovhpd(xmm4, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +// ----------------------------------------------------------------------------- + +// NOTE: Normally, for any "?x1" kernel, we would call the reference kernel. +// However, at least one other subconfiguration (zen) uses this kernel set, so +// we need to be able to call a set of "?x1" kernels that we know will actually +// exist regardless of which subconfiguration these kernels were used by. Thus, +// the compromise employed here is to inline the reference kernel so it gets +// compiled as part of the haswell kernel set, and hence can unconditionally be +// called by other kernels within that kernel set. + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, mdim ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + for ( dim_t i = 0; i < mdim; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + /* for ( dim_t j = 0; j < 1; ++j ) */ \ + { \ + ctype* restrict cij = ci /*[ j*cs_c ]*/ ; \ + ctype* restrict bj = b /*[ j*cs_b ]*/ ; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(d,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ +} + +GENTFUNC( double, d, gemmsup_r_haswell_ref_6x1, 6 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_5x1, 5 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_4x1, 4 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_3x1, 3 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_2x1, 2 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_1x1, 1 ) + diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8.c b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8.c new file mode 100644 index 0000000000..c5addd9cf2 --- /dev/null +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8.c @@ -0,0 +1,5249 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_r_haswell_ref ) + +// Define parameters and variables for edge case kernel map. +#define NUM_MR 4 +#define NUM_NR 4 +#define FUNCPTR_T dgemmsup_ker_ft + +static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 }; +static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; +static FUNCPTR_T kmap[NUM_MR][NUM_NR] = +{ /* 8 4 2 1 */ +/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8, bli_dgemmsup_rd_haswell_asm_6x4, bli_dgemmsup_rd_haswell_asm_6x2, bli_dgemmsup_r_haswell_ref }, +/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8, bli_dgemmsup_rd_haswell_asm_3x4, bli_dgemmsup_rd_haswell_asm_3x2, bli_dgemmsup_r_haswell_ref }, +/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8, bli_dgemmsup_rd_haswell_asm_2x4, bli_dgemmsup_rd_haswell_asm_2x2, bli_dgemmsup_r_haswell_ref }, +/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8, bli_dgemmsup_rd_haswell_asm_1x4, bli_dgemmsup_rd_haswell_asm_1x2, bli_dgemmsup_r_haswell_ref } +}; + + +void bli_dgemmsup_rd_haswell_asm_6x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Use a reference kernel if this is an edge case in the m or n + // dimensions. + if ( m0 < 6 || n0 < 8 ) + { + dim_t n_left = n0; + double* restrict cj = c; + double* restrict bj = b; + + // Iterate across columns (corresponding to elements of nrs) until + // n_left is zero. + for ( dim_t j = 0; n_left != 0; ++j ) + { + const dim_t nr_cur = nrs[ j ]; + + // Once we find the value of nrs that is less than (or equal to) + // n_left, we use the kernels in that column. + if ( nr_cur <= n_left ) + { + dim_t m_left = m0; + double* restrict cij = cj; + double* restrict ai = a; + + // Iterate down the current column (corresponding to elements + // of mrs) until m_left is zero. + for ( dim_t i = 0; m_left != 0; ++i ) + { + const dim_t mr_cur = mrs[ i ]; + + // Once we find the value of mrs that is less than (or equal + // to) m_left, we select that kernel. + if ( mr_cur <= m_left ) + { + FUNCPTR_T ker_fp = kmap[i][j]; + + //printf( "executing %d x %d sup kernel.\n", (int)mr_cur, (int)nr_cur ); + + // Call the kernel using current mrs and nrs values. + ker_fp + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + // Advance C and A pointers by the mrs and nrs we just + // used, and decrement m_left. + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + } + + // Advance C and B pointers by the mrs and nrs we just used, and + // decrement n_left. + cj += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + } + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r12) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r10) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r10 = rcx = c + // r12 = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + +#if 1 + mov(imm(0), r9) // ii = 0; + + label(.DLOOP3X4I) // LOOP OVER ii = [ 0 1 ... ] + + + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(rdi, rsi) // rsi *= rs_c; + lea(mem(r10, rsi, 1), rdx) // rdx = c_jj + 3*ii*rs_c; + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(r8, rsi) // rsi *= rs_a; + lea(mem(r12, rsi, 1), r12) // rax = a + 3*ii*rs_a; + + + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + vzeroall() // zero all xmm/ymm registers. + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(rdx, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(r14, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem( , r12, 1), rax) // rax = a_ii; +#endif + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + +#if 1 + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + add(imm(3), r9) // ii += 3; + cmp(imm(3), r9) // compare ii to 3 + jle(.DLOOP3X4I) // if ii <= 3, jump to beginning +#endif + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_3x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r12) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r10) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r10 = rcx = c + // r12 = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + vzeroall() // zero all xmm/ymm registers. + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r10, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(r14, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem( , r12, 1), rax) // rax = a; + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r12) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r10) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r10 = rcx = c + // r12 = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + vzeroall() // zero all xmm/ymm registers. + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r10, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(r14, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem( , r12, 1), rax) // rax = a; + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r12) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r10) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r10 = rcx = c + // r12 = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + vzeroall() // zero all xmm/ymm registers. + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r10, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(r14, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem( , r12, 1), rax) // rax = a; + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r12) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r10) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r10 = rcx = c + // r12 = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r9) // ii = 0; + + label(.DLOOP3X4I) // LOOP OVER ii = [ 0 1 ... ] + + + + vzeroall() // zero all xmm/ymm registers. + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(rdi, rsi) // rsi *= rs_c; + lea(mem(r10, rsi, 1), rcx) // rcx = c + 3*ii*rs_c; + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(r8, rsi) // rsi *= rs_a; + lea(mem(r12, rsi, 1), rax) // rax = a + 3*ii*rs_a; + + lea(mem( , r14, 1), rbx) // rbx = b; + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(3), r9) // ii += 3; + cmp(imm(3), r9) // compare ii to 3 + jle(.DLOOP3X4I) // if ii <= 3, jump to beginning + // of ii loop; otherwise, loop ends. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | | | + -------- -- -- -- ... | | | | + -------- += -- -- -- | | | | + -------- | | | | + -------- : + -------- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovsd(mem(rax, r13, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rax, r8, 4), xmm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovsd(mem(rax, r15, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + // xmm10 = sum(ymm10) sum(ymm11) + // xmm12 = sum(ymm12) sum(ymm13) + // xmm14 = sum(ymm14) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + // xmm4 = sum(ymm4) sum(ymm5) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c new file mode 100644 index 0000000000..55ae6d0f91 --- /dev/null +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c @@ -0,0 +1,5543 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_r_haswell_ref ) + +// Define parameters and variables for edge case kernel map. +#define NUM_MR 4 +#define NUM_NR 4 +#define FUNCPTR_T dgemmsup_ker_ft + +#if 0 +static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 }; +static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; +static FUNCPTR_T kmap[NUM_MR][NUM_NR] = +{ /* 8 4 2 1 */ +/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8m, bli_dgemmsup_rd_haswell_asm_6x4m, bli_dgemmsup_rd_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref }, +/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8m, bli_dgemmsup_rd_haswell_asm_3x4m, bli_dgemmsup_rd_haswell_asm_3x2m, bli_dgemmsup_r_haswell_ref }, +/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8m, bli_dgemmsup_rd_haswell_asm_2x4m, bli_dgemmsup_rd_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref }, +/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8m, bli_dgemmsup_rd_haswell_asm_1x4m, bli_dgemmsup_rd_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref } +}; +#endif + + +void bli_dgemmsup_rd_haswell_asm_6x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t n_left = n0 % 8; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 6x?m kernels, as needed. + if ( n_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rd_haswell_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + // r10 = unused + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; + lea(mem(r14), rax) // rax = a + 3*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x4m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x4m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | | | + -------- -- -- -- ... | | | | + -------- += -- -- -- | | | | + -------- | | | | + -------- : + -------- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 6*ii*rs_c; + lea(mem(r14), rax) // rax = a + 6*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovsd(mem(rax, r13, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rax, r8, 4), xmm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovsd(mem(rax, r15, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + // xmm10 = sum(ymm10) sum(ymm11) + // xmm12 = sum(ymm12) sum(ymm13) + // xmm14 = sum(ymm14) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rd_haswell_asm_3x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + // xmm4 = sum(ymm4) sum(ymm5) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.newji b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.newji new file mode 100644 index 0000000000..c1cb372142 --- /dev/null +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.newji @@ -0,0 +1,5628 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_r_haswell_ref ) + +// Define parameters and variables for edge case kernel map. +#define NUM_MR 4 +#define NUM_NR 4 +#define FUNCPTR_T dgemmsup_ker_ft + +#if 0 +static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 }; +static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; +static FUNCPTR_T kmap[NUM_MR][NUM_NR] = +{ /* 8 4 2 1 */ +/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8m, bli_dgemmsup_rd_haswell_asm_6x4m, bli_dgemmsup_rd_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref }, +/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8m, bli_dgemmsup_rd_haswell_asm_3x4m, bli_dgemmsup_rd_haswell_asm_3x2m, bli_dgemmsup_r_haswell_ref }, +/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8m, bli_dgemmsup_rd_haswell_asm_2x4m, bli_dgemmsup_rd_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref }, +/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8m, bli_dgemmsup_rd_haswell_asm_1x4m, bli_dgemmsup_rd_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref } +}; +#endif + + +void bli_dgemmsup_rd_haswell_asm_6x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t n_left = n0 % 8; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 6x?m kernels, as needed. + if ( n_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rd_haswell_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + // r10 = unused + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; + lea(mem(r14), rax) // rax = a + 3*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x4m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x4m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | | | + -------- -- -- -- ... | | | | + -------- += -- -- -- | | | | + -------- | | | | + -------- : + -------- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 6*ii*rs_c; + lea(mem(r14), rax) // rax = a + 6*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovsd(mem(rax, r13, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rax, r8, 4), xmm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovsd(mem(rax, r15, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + // xmm10 = sum(ymm10) sum(ymm11) + // xmm12 = sum(ymm12) sum(ymm13) + // xmm14 = sum(ymm14) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rd_haswell_asm_3x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + // xmm4 = sum(ymm4) sum(ymm5) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.worksij b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.worksij new file mode 100644 index 0000000000..fd1c2ae657 --- /dev/null +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.worksij @@ -0,0 +1,5634 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_r_haswell_ref ) + +// Define parameters and variables for edge case kernel map. +#define NUM_MR 4 +#define NUM_NR 4 +#define FUNCPTR_T dgemmsup_ker_ft + +#if 0 +static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 }; +static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; +static FUNCPTR_T kmap[NUM_MR][NUM_NR] = +{ /* 8 4 2 1 */ +/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8m, bli_dgemmsup_rd_haswell_asm_6x4m, bli_dgemmsup_rd_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref }, +/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8m, bli_dgemmsup_rd_haswell_asm_3x4m, bli_dgemmsup_rd_haswell_asm_3x2m, bli_dgemmsup_r_haswell_ref }, +/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8m, bli_dgemmsup_rd_haswell_asm_2x4m, bli_dgemmsup_rd_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref }, +/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8m, bli_dgemmsup_rd_haswell_asm_1x4m, bli_dgemmsup_rd_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref } +}; +#endif + + +void bli_dgemmsup_rd_haswell_asm_6x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t n_left = n0 % 8; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 6x?m kernels, as needed. + if ( n_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rd_haswell_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a_ii; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + // r10 = unused + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; + lea(mem(r14), rax) // rax = a + 3*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x4m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x4m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | | | + -------- -- -- -- ... | | | | + -------- += -- -- -- | | | | + -------- | | | | + -------- : + -------- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 6*ii*rs_c; + lea(mem(r14), rax) // rax = a + 6*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovsd(mem(rax, r13, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rax, r8, 4), xmm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovsd(mem(rax, r15, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + // xmm10 = sum(ymm10) sum(ymm11) + // xmm12 = sum(ymm12) sum(ymm13) + // xmm14 = sum(ymm14) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rd_haswell_asm_3x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + // xmm4 = sum(ymm4) sum(ymm5) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8n.c b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8n.c new file mode 100644 index 0000000000..a23764f8d4 --- /dev/null +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8n.c @@ -0,0 +1,5836 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_r_haswell_ref ) + +// Define parameters and variables for edge case kernel map. +#define NUM_MR 4 +#define NUM_NR 4 +#define FUNCPTR_T dgemmsup_ker_ft + +#if 0 +static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 }; +static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; +static FUNCPTR_T kmap[NUM_MR][NUM_NR] = +{ /* 8 4 2 1 */ +/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8n, bli_dgemmsup_rd_haswell_asm_6x4n, bli_dgemmsup_rd_haswell_asm_6x2n, bli_dgemmsup_r_haswell_ref }, +/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8n, bli_dgemmsup_rd_haswell_asm_3x4n, bli_dgemmsup_rd_haswell_asm_3x2n, bli_dgemmsup_r_haswell_ref }, +/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8n, bli_dgemmsup_rd_haswell_asm_2x4n, bli_dgemmsup_rd_haswell_asm_2x2n, bli_dgemmsup_r_haswell_ref }, +/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8n, bli_dgemmsup_rd_haswell_asm_1x4n, bli_dgemmsup_rd_haswell_asm_1x2n, bli_dgemmsup_r_haswell_ref } +}; +#endif + + +void bli_dgemmsup_rd_haswell_asm_6x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); return; +#endif + uint64_t m_left = m0 % 6; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other ?x8m kernels, as needed. + if ( m_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rd_haswell_asm_3x8n + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8n + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { +#if 0 + const dim_t mr_cur = 1; + + //bli_dgemmsup_r_haswell_ref + bli_dgemmsup_rd_haswell_asm_1x8n + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_TRANSPOSE, conja, k0, n0, + alpha, bj, rs_b0, cs_b0, ai, cs_a0, + beta, cij, cs_c0, cntx, NULL + ); +#endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + //mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r9) // ii = 0; + + label(.DLOOP3X4I) // LOOP OVER ii = [ 0 1 ... ] + + + + mov(var(b), r14) // load address of b + mov(var(c), r12) // load address of c + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(rdi, rsi) // rsi *= rs_c + lea(mem(r12, rsi, 1), r12) // r12 = c + 3*ii*rs_c; + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(r8, rsi) // rsi *= rs_a; + lea(mem(rdx, rsi, 1), rdx) // rax = a + 3*ii*rs_a; + + + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; +#else + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. + + + + add(imm(3), r9) // ii += 3; + cmp(imm(3), r9) // compare ii to 3 + jle(.DLOOP3X4I) // if ii <= 3, jump to beginning + // of ii loop; otherwise, loop ends. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 6; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2n + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + //bli_dgemmsup_rd_haswell_asm_6x1n + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; +#else + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_3x2n + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + //bli_dgemmsup_rd_haswell_asm_3x1n + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_2x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; +#else + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 2; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2n + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + //bli_dgemmsup_rd_haswell_asm_2x1n + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_1x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; +#else + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 1; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_1x2n + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + //bli_dgemmsup_rd_haswell_asm_1x1n + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_ddotxv_ex + ( + conja, conjb, k0, + alpha, ai, cs_a0, bj, rs_b0, + beta, cij, cntx, NULL + ); +#endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_6x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + // r10 = unused + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; + lea(mem(r14), rax) // rax = a + 3*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x4n + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x4n + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | | | + -------- -- -- -- ... | | | | + -------- += -- -- -- | | | | + -------- | | | | + -------- : + -------- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x2n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovsd(mem(rax, r13, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rax, r8, 4), xmm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovsd(mem(rax, r15, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + // xmm10 = sum(ymm10) sum(ymm11) + // xmm12 = sum(ymm12) sum(ymm13) + // xmm14 = sum(ymm14) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_3x2n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x2n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + // xmm4 = sum(ymm4) sum(ymm5) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_r_haswell_ref_sMx1.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_r_haswell_ref_sMx1.c new file mode 100644 index 0000000000..dad5458b9a --- /dev/null +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_r_haswell_ref_sMx1.c @@ -0,0 +1,224 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + +// NOTE: Normally, for any "?x1" kernel, we would call the reference kernel. +// However, at least one other subconfiguration (zen) uses this kernel set, so +// we need to be able to call a set of "?x1" kernels that we know will actually +// exist regardless of which subconfiguration these kernels were used by. Thus, +// the compromise employed here is to inline the reference kernel so it gets +// compiled as part of the haswell kernel set, and hence can unconditionally be +// called by other kernels within that kernel set. + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, mdim ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + for ( dim_t i = 0; i < mdim; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + /* for ( dim_t j = 0; j < 1; ++j ) */ \ + { \ + ctype* restrict cij = ci /*[ j*cs_c ]*/ ; \ + ctype* restrict bj = b /*[ j*cs_b ]*/ ; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(d,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ +} + +GENTFUNC( float, s, gemmsup_r_haswell_ref_6x1, 6 ) +GENTFUNC( float, s, gemmsup_r_haswell_ref_5x1, 5 ) +GENTFUNC( float, s, gemmsup_r_haswell_ref_4x1, 4 ) +GENTFUNC( float, s, gemmsup_r_haswell_ref_3x1, 3 ) +GENTFUNC( float, s, gemmsup_r_haswell_ref_2x1, 2 ) +GENTFUNC( float, s, gemmsup_r_haswell_ref_1x1, 1 ) + +// ----------------------------------------------------------------------------- + +#if 0 +// Temporary definition of general-purpose sup kernel. + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(d,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ +} + +GENTFUNC( float, s, gemmsup_r_haswell_ref ) +#endif diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx1.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx1.c new file mode 100644 index 0000000000..1eb8d926c9 --- /dev/null +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx1.c @@ -0,0 +1,1725 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rd_haswell_asm_6x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm14, ymm14, ymm14) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 0*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 0*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 0*4)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 0*4)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 0*4)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 0*4)) // prefetch c + 5*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rbx ), xmm0) + add(imm(1*4), rbx) // b += 8*rs_b = 8*4; + + vmovss(mem(rax ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovss(mem(rax, r8, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vmovss(mem(rax, r8, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm8) + + vmovss(mem(rax, r13, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovss(mem(rax, r8, 4), xmm3) + vfmadd231ps(ymm0, ymm3, ymm12) + + vmovss(mem(rax, r15, 1), xmm3) + add(imm(1*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 + // ymm6 + // ymm8 + // ymm10 + // ymm12 + // ymm14 + + vhaddps( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm4 ) + + vhaddps( ymm7, ymm6, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm6 ) + + vhaddps( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm8 ) + + vhaddps( ymm11, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm10 ) + + vhaddps( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm12 ) + + vhaddps( ymm15, ymm14, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm14 ) + + // xmm4[0] = sum(ymm4) + // xmm6[0] = sum(ymm6) + // xmm8[0] = sum(ymm8) + // xmm10[0] = sum(ymm10) + // xmm12[0] = sum(ymm12) + // xmm14[0] = sum(ymm14) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vmovss(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmovss(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovss(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovss(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovss(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmovss(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovss(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) + vmovss(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovss(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) + vmovss(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovss(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm14) + vmovss(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovss(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovss(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovss(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovss(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovss(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovss(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.SDONE) + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_haswell_asm_3x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm8, ymm8, ymm8) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 0*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 0*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 0*4)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm8) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm8) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm8) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm8) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm8) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rbx ), xmm0) + add(imm(1*4), rbx) // b += 8*rs_b = 8*4; + + vmovss(mem(rax ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovss(mem(rax, r8, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vmovss(mem(rax, r8, 2), xmm3) + add(imm(1*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm8) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 + // ymm6 + // ymm8 + + vhaddps( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm4 ) + + vhaddps( ymm7, ymm6, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm6 ) + + vhaddps( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm8 ) + + // xmm4[0] = sum(ymm4) + // xmm6[0] = sum(ymm6) + // xmm8[0] = sum(ymm8) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vmovss(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmovss(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovss(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovss(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovss(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmovss(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovss(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovss(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovss(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.SDONE) + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_haswell_asm_2x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm6, ymm6, ymm6) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 0*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 0*4)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm6) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm6) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rbx ), xmm0) + add(imm(1*4), rbx) // b += 8*rs_b = 8*4; + + vmovss(mem(rax ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovss(mem(rax, r8, 1), xmm3) + add(imm(1*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm6) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 + // ymm6 + + vhaddps( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm4 ) + + vhaddps( ymm7, ymm6, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm6 ) + + // xmm4[0] = sum(ymm4) + // xmm6[0] = sum(ymm6) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vmovss(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmovss(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovss(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovss(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovss(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovss(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.SDONE) + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_haswell_asm_1x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 0*4)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm4) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm4) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm4) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm4) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm4) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rbx ), xmm0) + add(imm(1*4), rbx) // b += 8*rs_b = 8*4; + + vmovss(mem(rax ), xmm3) + add(imm(1*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm4) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 + // ymm6 + + vhaddps( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm4 ) + + // xmm4[0] = sum(ymm4) + // xmm6[0] = sum(ymm6) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vmovss(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmovss(xmm4, mem(rcx)) + add(rdi, rcx) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovss(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.SDONE) + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx12.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx12.c new file mode 100644 index 0000000000..1d3d88309f --- /dev/null +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx12.c @@ -0,0 +1,1526 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rd_haswell_asm_6x12 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c*sizeof(float) = 1*4 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*4)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm6) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(12), r15) // compare jj to 12 + jl(.SLOOP3X4J) // if jj < 12, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_haswell_asm_2x12 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c*sizeof(float) = 1*4 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(12), r15) // compare jj to 12 + jl(.SLOOP3X4J) // if jj < 12, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_haswell_asm_1x12 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm13, ymm13, ymm13) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c*sizeof(float) = 1*4 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(12), r15) // compare jj to 12 + jl(.SLOOP3X4J) // if jj < 12, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx16.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx16.c new file mode 100644 index 0000000000..bbb75a6fcd --- /dev/null +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx16.c @@ -0,0 +1,1648 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rd_haswell_asm_6x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t n_left = n0 % 16; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 6x?m kernels, as needed. + if ( n_left ) + { + float* restrict cij = c; + float* restrict bj = b; + float* restrict ai = a; + + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rd_haswell_asm_6x8 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rd_haswell_asm_6x4 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rd_haswell_asm_6x2 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_sgemmsup_rd_haswell_asm_6x1 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c*sizeof(float) = 1*4 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*4)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm6) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(16), r15) // compare jj to 16 + jl(.SLOOP3X4J) // if jj < 16, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 16; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict bj = b; + float* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_sgemmsup_rd_haswell_asm_2x16 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_sgemmsup_rd_haswell_asm_1x16 + //bli_sgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_sgemmsup_rd_haswell_asm_2x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c*sizeof(float) = 1*4 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(16), r15) // compare jj to 16 + jl(.SLOOP3X4J) // if jj < 16, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_haswell_asm_1x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm13, ymm13, ymm13) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c*sizeof(float) = 1*4 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(16), r15) // compare jj to 16 + jl(.SLOOP3X4J) // if jj < 16, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx2.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx2.c new file mode 100644 index 0000000000..1e3240350b --- /dev/null +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx2.c @@ -0,0 +1,1828 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rd_haswell_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*4)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 1*4)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 1*4)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 1*4)) // prefetch c + 5*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rbx ), xmm0) + vmovss(mem(rbx, r11, 1), xmm1) + add(imm(1*4), rbx) // b += 8*rs_b = 8*4; + + vmovss(mem(rax ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovss(mem(rax, r8, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovss(mem(rax, r8, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + vmovss(mem(rax, r13, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovss(mem(rax, r8, 4), xmm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + + vmovss(mem(rax, r15, 1), xmm3) + add(imm(1*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddps( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm4 ) + + vhaddps( ymm7, ymm6, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm6 ) + + vhaddps( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm8 ) + + vhaddps( ymm11, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm10 ) + + vhaddps( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm12 ) + + vhaddps( ymm15, ymm14, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm14 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + // xmm6[0:1] = sum(ymm6) sum(ymm7) + // xmm8[0:1] = sum(ymm8) sum(ymm9) + // xmm10[0:1] = sum(ymm10) sum(ymm11) + // xmm12[0:1] = sum(ymm12) sum(ymm13) + // xmm14[0:1] = sum(ymm14) sum(ymm15) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) + vmovsd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) + vmovsd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm14) + vmovsd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.SDONE) + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_haswell_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*4)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rbx ), xmm0) + vmovss(mem(rbx, r11, 1), xmm1) + add(imm(1*4), rbx) // b += 8*rs_b = 8*4; + + vmovss(mem(rax ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovss(mem(rax, r8, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vmovss(mem(rax, r8, 2), xmm3) + add(imm(1*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + + vhaddps( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm4 ) + + vhaddps( ymm7, ymm6, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm6 ) + + vhaddps( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm8 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + // xmm6[0:1] = sum(ymm6) sum(ymm7) + // xmm8[0:1] = sum(ymm8) sum(ymm9) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.SDONE) + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_haswell_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rbx ), xmm0) + vmovss(mem(rbx, r11, 1), xmm1) + add(imm(1*4), rbx) // b += 8*rs_b = 8*4; + + vmovss(mem(rax ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovss(mem(rax, r8, 1), xmm3) + add(imm(1*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddps( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm4 ) + + vhaddps( ymm7, ymm6, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm6 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + // xmm6[0:1] = sum(ymm6) sum(ymm7) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.SDONE) + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_haswell_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rbx ), xmm0) + vmovss(mem(rbx, r11, 1), xmm1) + add(imm(1*4), rbx) // b += 8*rs_b = 8*4; + + vmovss(mem(rax ), xmm3) + add(imm(1*4), rax) // a += 8*cs_a = 8*4; + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm5 + + vhaddps( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm4 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovsd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.SDONE) + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx4.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx4.c new file mode 100644 index 0000000000..9d4e9d51d2 --- /dev/null +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx4.c @@ -0,0 +1,1457 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rd_haswell_asm_6x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_ii; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*4)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm6) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_haswell_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) +#endif + + + //lea(mem(r12), rcx) // rcx = c; + //lea(mem(r14), rax) // rax = a; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_haswell_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm13, ymm13, ymm13) +#endif + + + //lea(mem(r12), rcx) // rcx = c; + //lea(mem(r14), rax) // rax = a; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx8.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx8.c new file mode 100644 index 0000000000..788912ecf6 --- /dev/null +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx8.c @@ -0,0 +1,1526 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rd_haswell_asm_6x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c*sizeof(float) = 1*4 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*4)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm6) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 8 + jl(.SLOOP3X4J) // if jj < 8, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_haswell_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c*sizeof(float) = 1*4 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*4)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm5) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 8 + jl(.SLOOP3X4J) // if jj < 8, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_haswell_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + //lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm13, ymm13, ymm13) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c*sizeof(float) = 1*4 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + + label(.SLOOPKITER32) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + + + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a +#endif + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 8*cs_a = 8*4; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 8*rs_b = 8*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + + + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + add(imm(1*4), rax) // a += 1*cs_a = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.SPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vhaddps( xmm1, xmm0, xmm0 ) + vpermilps(imm(0xd8), xmm0, xmm0) + vhaddps( xmm0, xmm0, xmm0 ) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm3 ) + vhaddps( xmm3, xmm2, xmm2 ) + vpermilps(imm(0xd8), xmm2, xmm2) + vhaddps( xmm2, xmm2, xmm2 ) + + vshufps(imm(0x44), xmm2, xmm0, xmm4) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(float) + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.SDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 8 + jl(.SLOOP3X4J) // if jj < 8, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.SRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx12.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx12.c new file mode 100644 index 0000000000..1bea78ee73 --- /dev/null +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx12.c @@ -0,0 +1,3518 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rv_haswell_asm_6x12 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 5*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 5*8)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*8)) // prefetch c + 8*cs_c + lea(mem(rcx, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rbp, 1, 5*8)) // prefetch c + 11*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(xmm0, xmm7, xmm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(xmm0, xmm9, xmm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(xmm0, xmm11, xmm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(xmm0, xmm13, xmm13) + vmulps(ymm0, ymm14, ymm14) + vmulps(xmm0, xmm15, xmm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm5) + vmovups(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm7) + vmovups(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm9) + vmovups(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm11) + vmovups(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm12) + vmovups(ymm12, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm13) + vmovups(xmm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm14) + vmovups(ymm14, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm15) + vmovups(xmm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(rdx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rdx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(rdx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-11 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm15, ymm13, ymm0) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + + vunpckhps(ymm15, ymm13, ymm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + vmovups(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx, 0*32)) + vmovups(xmm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm14, mem(rcx, 0*32)) + vmovups(xmm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-11 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm15, ymm13, ymm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + + vunpckhps(ymm15, ymm13, ymm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_5x12 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 4*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 4*8)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 4*8)) // prefetch c + 8*cs_c + lea(mem(rcx, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rbp, 1, 4*8)) // prefetch c + 11*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(xmm0, xmm7, xmm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(xmm0, xmm9, xmm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(xmm0, xmm11, xmm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(xmm0, xmm13, xmm13) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm5) + vmovups(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm7) + vmovups(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm9) + vmovups(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm11) + vmovups(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm12) + vmovups(ymm12, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm13) + vmovups(xmm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rdx, rax, 2), xmm1) + vmovss(mem(rdx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-11 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm13, ymm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + vmovups(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx, 0*32)) + vmovups(xmm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-11 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm13, ymm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_4x12 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 3*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 3*8)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 3*8)) // prefetch c + 8*cs_c + lea(mem(rcx, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rbp, 1, 3*8)) // prefetch c + 11*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(xmm0, xmm7, xmm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(xmm0, xmm9, xmm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(xmm0, xmm11, xmm11) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm5) + vmovups(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm7) + vmovups(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm9) + vmovups(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm11) + vmovups(xmm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-11 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + vmovups(xmm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-11 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_3x12 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 2*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 2*8)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 2*8)) // prefetch c + 8*cs_c + lea(mem(rcx, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rbp, 1, 2*8)) // prefetch c + 11*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(xmm0, xmm7, xmm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(xmm0, xmm9, xmm9) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm5) + vmovups(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm7) + vmovups(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm9) + vmovups(xmm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(mem(rcx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm8, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rdx, rax, 2), xmm1) + vmovss(mem(rdx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-11 + vunpcklps(ymm7, ymm5, ymm0) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vunpckhps(ymm7, ymm5, ymm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm9, ymm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(xmm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm8, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-11 + vunpcklps(ymm7, ymm5, ymm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vunpckhps(ymm7, ymm5, ymm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm9, ymm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_2x12 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 1*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 1*8)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 1*8)) // prefetch c + 8*cs_c + lea(mem(rcx, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rbp, 1, 1*8)) // prefetch c + 11*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(xmm0, xmm7, xmm7) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm5) + vmovups(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm7) + vmovups(xmm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(mem(rcx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-11 + vunpcklps(ymm7, ymm5, ymm0) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vunpckhps(ymm7, ymm5, ymm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(xmm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-11 + vunpcklps(ymm7, ymm5, ymm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vunpckhps(ymm7, ymm5, ymm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_1x12 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 0*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 0*8)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 0*8)) // prefetch c + 8*cs_c + lea(mem(rcx, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rbp, 1, 0*8)) // prefetch c + 11*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), xmm3, xmm5) + vmovups(xmm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vmovups(ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rcx ), xmm1) + vmovss(mem(rcx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rcx, rsi, 2), xmm1) + vmovss(mem(rcx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rcx, rsi, 4), xmm1) + vmovss(mem(rcx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rcx, rax, 2), xmm1) + vmovss(mem(rcx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-11 + vmovups(ymm5, ymm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rcx ), xmm1) + vmovss(mem(rcx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rcx, rsi, 2), xmm1) + vmovss(mem(rcx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(xmm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vmovups(ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-11 + vmovups(ymm5, ymm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx16.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx16.c new file mode 100644 index 0000000000..6a08cecd43 --- /dev/null +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx16.c @@ -0,0 +1,3855 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + +// Define parameters and variables for edge case kernel map. +#define NUM_MR 4 +#define NUM_NR 6 +#define FUNCPTR_T sgemmsup_ker_ft + +static dim_t mrs[NUM_MR] = { 6, 4, 2, 1 }; +static dim_t nrs[NUM_NR] = { 16, 12, 8, 4, 2, 1 }; +static FUNCPTR_T kmap[NUM_MR][NUM_NR] = +{ /* 16 12 8 4 2 1 */ +/* 6 */ { bli_sgemmsup_rv_haswell_asm_6x16, bli_sgemmsup_rv_haswell_asm_6x12, bli_sgemmsup_rv_haswell_asm_6x8, bli_sgemmsup_rv_haswell_asm_6x4, bli_sgemmsup_rv_haswell_asm_6x2, bli_sgemmsup_r_haswell_ref_6x1 }, +/* 4 */ { bli_sgemmsup_rv_haswell_asm_4x16, bli_sgemmsup_rv_haswell_asm_4x12, bli_sgemmsup_rv_haswell_asm_4x8, bli_sgemmsup_rv_haswell_asm_4x4, bli_sgemmsup_rv_haswell_asm_4x2, bli_sgemmsup_r_haswell_ref_4x1 }, +/* 2 */ { bli_sgemmsup_rv_haswell_asm_2x16, bli_sgemmsup_rv_haswell_asm_2x12, bli_sgemmsup_rv_haswell_asm_2x8, bli_sgemmsup_rv_haswell_asm_2x4, bli_sgemmsup_rv_haswell_asm_2x2, bli_sgemmsup_r_haswell_ref_2x1 }, +/* 1 */ { bli_sgemmsup_rv_haswell_asm_1x16, bli_sgemmsup_rv_haswell_asm_1x12, bli_sgemmsup_rv_haswell_asm_1x8, bli_sgemmsup_rv_haswell_asm_1x4, bli_sgemmsup_rv_haswell_asm_1x2, bli_sgemmsup_r_haswell_ref_1x1 } +}; + + +void bli_sgemmsup_rv_haswell_asm_6x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_sgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + // Use a reference kernel if this is an edge case in the m or n + // dimensions. + if ( m0 < 6 || n0 < 16 ) + { +#if 0 + bli_sgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + dim_t n_left = n0; + float* restrict cj = c; + float* restrict bj = b; + + // Iterate across columns (corresponding to elements of nrs) until + // n_left is zero. + for ( dim_t j = 0; n_left != 0; ++j ) + { + const dim_t nr_cur = nrs[ j ]; + + // Once we find the value of nrs that is less than (or equal to) + // n_left, we use the kernels in that column. + if ( nr_cur <= n_left ) + { + dim_t m_left = m0; + float* restrict cij = cj; + float* restrict ai = a; + + // Iterate down the current column (corresponding to elements + // of mrs) until m_left is zero. + for ( dim_t i = 0; m_left != 0; ++i ) + { + const dim_t mr_cur = mrs[ i ]; + + // Once we find the value of mrs that is less than (or equal + // to) m_left, we select that kernel. + if ( mr_cur <= m_left ) + { + FUNCPTR_T ker_fp = kmap[i][j]; + + //printf( "executing %d x %d sup kernel.\n", (int)mr_cur, (int)nr_cur ); + + // Call the kernel using current mrs and nrs values. + ker_fp + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + // Advance C and A pointers by the mrs and nrs we just + // used, and decrement m_left. + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + } + + // Advance C and B pointers by the mrs and nrs we just used, and + // decrement n_left. + cj += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + } + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 5*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 5*8)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*8)) // prefetch c + 8*cs_c + lea(mem(rcx, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rbp, 1, 5*8)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*8)) // prefetch c + 12*cs_c + lea(mem(rcx, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rbp, 1, 5*8)) // prefetch c + 15*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm13, ymm13) + vmulps(ymm0, ymm14, ymm14) + vmulps(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm5) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm7) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm9) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm11) + vmovups(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm12) + vmovups(ymm12, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm13) + vmovups(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm14) + vmovups(ymm14, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm15) + vmovups(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(rdx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rdx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(rdx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(rdx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rdx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(rdx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + vmovups(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx, 0*32)) + vmovups(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm14, mem(rcx, 0*32)) + vmovups(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_5x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 4*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 4*8)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 4*8)) // prefetch c + 8*cs_c + lea(mem(rcx, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rbp, 1, 4*8)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 4*8)) // prefetch c + 12*cs_c + lea(mem(rcx, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rbp, 1, 4*8)) // prefetch c + 15*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm13, ymm13) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm5) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm7) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm9) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm11) + vmovups(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm12) + vmovups(ymm12, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm13) + vmovups(ymm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rdx, rax, 2), xmm1) + vmovss(mem(rdx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rdx, rax, 2), xmm1) + vmovss(mem(rdx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + vmovups(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx, 0*32)) + vmovups(ymm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_4x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 3*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 3*8)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 3*8)) // prefetch c + 8*cs_c + lea(mem(rcx, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rbp, 1, 3*8)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 3*8)) // prefetch c + 12*cs_c + lea(mem(rcx, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rbp, 1, 3*8)) // prefetch c + 15*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm5) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm7) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm9) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm11) + vmovups(ymm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + vmovups(ymm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_3x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 2*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 2*8)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 2*8)) // prefetch c + 8*cs_c + lea(mem(rcx, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rbp, 1, 2*8)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 2*8)) // prefetch c + 12*cs_c + lea(mem(rcx, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rbp, 1, 2*8)) // prefetch c + 15*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm5) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm7) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm9) + vmovups(ymm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(mem(rcx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm8, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rdx, rax, 2), xmm1) + vmovss(mem(rdx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(mem(rcx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm9, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rdx, rax, 2), xmm1) + vmovss(mem(rdx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + vmovups(ymm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm8, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm9, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_2x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 1*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 1*8)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 1*8)) // prefetch c + 8*cs_c + lea(mem(rcx, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rbp, 1, 1*8)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 1*8)) // prefetch c + 12*cs_c + lea(mem(rcx, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rbp, 1, 1*8)) // prefetch c + 15*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm5) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm7) + vmovups(ymm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(mem(rcx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(mem(rcx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + vmovups(ymm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-15 + vunpcklps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm7, ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_1x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 0*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 0*8)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 0*8)) // prefetch c + 8*cs_c + lea(mem(rcx, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rbp, 1, 0*8)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 0*8)) // prefetch c + 12*cs_c + lea(mem(rcx, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rbp, 1, 0*8)) // prefetch c + 15*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vfmadd231ps(mem(rcx, 1*32), ymm3, ymm5) + vmovups(ymm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vmovups(ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rcx ), xmm1) + vmovss(mem(rcx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rcx, rsi, 2), xmm1) + vmovss(mem(rcx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rcx, rsi, 4), xmm1) + vmovss(mem(rcx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rcx, rax, 2), xmm1) + vmovss(mem(rcx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-15 + vmovups(ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rcx ), xmm1) + vmovss(mem(rcx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rcx, rsi, 2), xmm1) + vmovss(mem(rcx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rcx, rsi, 4), xmm1) + vmovss(mem(rcx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rcx, rax, 2), xmm1) + vmovss(mem(rcx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + vmovups(ymm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vmovups(ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + // begin I/O on columns 8-15 + vmovups(ymm5, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c new file mode 100644 index 0000000000..6090f8b0b9 --- /dev/null +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c @@ -0,0 +1,2504 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rv_haswell_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 1*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + //lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovsd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovsd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovsd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + vmovsd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + vmovsd(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm14) + vmovsd(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(xmm14, xmm12, xmm0) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovsd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-1 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(xmm14, xmm12, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_5x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + //lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovsd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovsd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovsd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + vmovsd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + vmovsd(xmm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(xmm12, xmm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovsd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-1 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(xmm12, xmm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_4x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + //lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovsd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovsd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovsd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + vmovsd(xmm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovsd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-1 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + //lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovsd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovsd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovsd(xmm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklps(xmm6, xmm4, xmm0) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(xmm8, xmm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovsd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-1 + vunpcklps(xmm6, xmm4, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(xmm8, xmm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + //lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovsd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovsd(xmm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklps(xmm6, xmm4, xmm0) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovsd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovsd(xmm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-1 + vunpcklps(xmm6, xmm4, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovsd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + //lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovsd(xmm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-1 + vmovups(xmm4, xmm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rcx ), xmm1) + vmovss(mem(rcx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma00 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma01 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovsd(xmm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-1 + vmovups(xmm4, xmm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma00 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma01 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx4.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx4.c new file mode 100644 index 0000000000..512fd60525 --- /dev/null +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx4.c @@ -0,0 +1,2668 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rv_haswell_asm_6x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 5*8)) // prefetch c + 3*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovups(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovups(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + vmovups(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + vmovups(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm14) + vmovups(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + + // begin I/O on columns 0-3 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + vunpckhps(xmm6, xmm4, xmm0) + vunpckhps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(xmm14, xmm12, xmm0) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + + vunpckhps(xmm14, xmm12, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-3 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + vunpckhps(xmm6, xmm4, xmm0) + vunpckhps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(xmm14, xmm12, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + + vunpckhps(xmm14, xmm12, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_5x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 4*8)) // prefetch c + 3*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovups(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovups(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + vmovups(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + vmovups(xmm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + + // begin I/O on columns 0-3 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + vunpckhps(xmm6, xmm4, xmm0) + vunpckhps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(xmm12, xmm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-3 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + + vunpckhps(xmm6, xmm4, xmm0) + vunpckhps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(xmm12, xmm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_4x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 3*8)) // prefetch c + 3*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovups(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovups(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + vmovups(xmm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + + // begin I/O on columns 0-3 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + vunpckhps(xmm6, xmm4, xmm0) + vunpckhps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-3 + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + + + vunpckhps(xmm6, xmm4, xmm0) + vunpckhps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 2*8)) // prefetch c + 3*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovups(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovups(xmm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + + // begin I/O on columns 0-3 + vunpcklps(xmm6, xmm4, xmm0) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vunpckhps(xmm6, xmm4, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(xmm8, xmm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma20 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma21 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma22 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma23 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-3 + vunpcklps(xmm6, xmm4, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vunpckhps(xmm6, xmm4, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(xmm8, xmm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma20 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma21 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma22 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma23 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 1*8)) // prefetch c + 3*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovups(xmm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + + // begin I/O on columns 0-3 + vunpcklps(xmm6, xmm4, xmm0) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vunpckhps(xmm6, xmm4, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(xmm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-3 + vunpcklps(xmm6, xmm4, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vunpckhps(xmm6, xmm4, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 0*8)) // prefetch c + 3*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + + // begin I/O on columns 0-3 + vmovups(xmm4, xmm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rcx ), xmm1) + vmovss(mem(rcx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma00 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma01 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rcx, rsi, 2), xmm1) + vmovss(mem(rcx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma02 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma03 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(xmm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-3 + vmovups(xmm4, xmm0) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma00 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma01 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma02 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma03 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx6.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx6.c new file mode 100644 index 0000000000..ac4e1ee0b0 --- /dev/null +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx6.c @@ -0,0 +1,3395 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rv_haswell_asm_6x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*4)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*4)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm14, ymm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vextractf128(imm(0x1), ymm4, xmm5) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm5) + vmovsd(xmm5, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm6, xmm7) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm6) + vmovups(xmm6, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm7) + vmovsd(xmm7, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm8, xmm9) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm8) + vmovups(xmm8, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm9) + vmovsd(xmm9, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm10, xmm11) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm10) + vmovups(xmm10, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm11) + vmovsd(xmm11, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm12, xmm13) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm12) + vmovups(xmm12, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm13) + vmovsd(xmm13, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm14, xmm15) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm14) + vmovups(xmm14, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm15) + vmovsd(xmm15, mem(rcx, 4*4)) + + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-5 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(rdx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rdx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vextractf128(imm(0x1), ymm4, xmm5) + vmovups(xmm4, mem(rcx, 0*4)) + vmovsd(xmm5, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm6, xmm7) + vmovups(xmm6, mem(rcx, 0*4)) + vmovsd(xmm7, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm8, xmm9) + vmovups(xmm8, mem(rcx, 0*4)) + vmovsd(xmm9, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm10, xmm11) + vmovups(xmm10, mem(rcx, 0*4)) + vmovsd(xmm11, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm12, xmm13) + vmovups(xmm12, mem(rcx, 0*4)) + vmovsd(xmm13, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm14, xmm15) + vmovups(xmm14, mem(rcx, 0*4)) + vmovsd(xmm15, mem(rcx, 4*4)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-5 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_5x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*4)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 4*4)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*4)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*4)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 4*4)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 4*4)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*4)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm12, ymm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vextractf128(imm(0x1), ymm4, xmm5) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm5) + vmovsd(xmm5, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm6, xmm7) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm6) + vmovups(xmm6, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm7) + vmovsd(xmm7, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm8, xmm9) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm8) + vmovups(xmm8, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm9) + vmovsd(xmm9, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm10, xmm11) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm10) + vmovups(xmm10, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm11) + vmovsd(xmm11, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm12, xmm13) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm12) + vmovups(xmm12, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm13) + vmovsd(xmm13, mem(rcx, 4*4)) + + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-5 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vmovups(ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vextractf128(imm(0x1), ymm4, xmm5) + vmovups(xmm4, mem(rcx, 0*4)) + vmovsd(xmm5, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm6, xmm7) + vmovups(xmm6, mem(rcx, 0*4)) + vmovsd(xmm7, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm8, xmm9) + vmovups(xmm8, mem(rcx, 0*4)) + vmovsd(xmm9, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm10, xmm11) + vmovups(xmm10, mem(rcx, 0*4)) + vmovsd(xmm11, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm12, xmm13) + vmovups(xmm12, mem(rcx, 0*4)) + vmovsd(xmm13, mem(rcx, 4*4)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-5 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_4x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*4)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 3*4)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*4)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*4)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 3*4)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 3*4)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*4)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vextractf128(imm(0x1), ymm4, xmm5) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm5) + vmovsd(xmm5, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm6, xmm7) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm6) + vmovups(xmm6, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm7) + vmovsd(xmm7, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm8, xmm9) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm8) + vmovups(xmm8, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm9) + vmovsd(xmm9, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm10, xmm11) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm10) + vmovups(xmm10, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm11) + vmovsd(xmm11, mem(rcx, 4*4)) + + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-5 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vextractf128(imm(0x1), ymm4, xmm5) + vmovups(xmm4, mem(rcx, 0*4)) + vmovsd(xmm5, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm6, xmm7) + vmovups(xmm6, mem(rcx, 0*4)) + vmovsd(xmm7, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm8, xmm9) + vmovups(xmm8, mem(rcx, 0*4)) + vmovsd(xmm9, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm10, xmm11) + vmovups(xmm10, mem(rcx, 0*4)) + vmovsd(xmm11, mem(rcx, 4*4)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-5 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_3x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*4)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*4)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 2*4)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*4)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*4)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 2*4)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 2*4)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*4)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vextractf128(imm(0x1), ymm4, xmm5) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm5) + vmovsd(xmm5, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm6, xmm7) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm6) + vmovups(xmm6, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm7) + vmovsd(xmm7, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm8, xmm9) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm8) + vmovups(xmm8, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm9) + vmovsd(xmm9, mem(rcx, 4*4)) + + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-5 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm8, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vextractf128(imm(0x1), ymm4, xmm5) + vmovups(xmm4, mem(rcx, 0*4)) + vmovsd(xmm5, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm6, xmm7) + vmovups(xmm6, mem(rcx, 0*4)) + vmovsd(xmm7, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm8, xmm9) + vmovups(xmm8, mem(rcx, 0*4)) + vmovsd(xmm9, mem(rcx, 4*4)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-5 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm8, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_2x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*4)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*4)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 1*4)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*4)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*4)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 1*4)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 1*4)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*4)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vextractf128(imm(0x1), ymm4, xmm5) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm5) + vmovsd(xmm5, mem(rcx, 4*4)) + + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm6, xmm7) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm6) + vmovups(xmm6, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm7) + vmovsd(xmm7, mem(rcx, 4*4)) + + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-5 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vextractf128(imm(0x1), ymm4, xmm5) + vmovups(xmm4, mem(rcx, 0*4)) + vmovsd(xmm5, mem(rcx, 4*4)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm6, xmm7) + vmovups(xmm6, mem(rcx, 0*4)) + vmovsd(xmm7, mem(rcx, 4*4)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-5 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_1x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*4)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 0*4)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*4)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*4)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 0*4)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 0*4)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*4)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*4), xmm0) + vmovsd(mem(rbx, 4*4), xmm1) + vinsertf128(imm(0x1), xmm1, ymm0, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + //lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vextractf128(imm(0x1), ymm4, xmm5) + vfmadd231ps(mem(rcx, 0*4), xmm3, xmm4) + vmovups(xmm4, mem(rcx, 0*4)) + + vmovsd(mem(rcx, 4*4), xmm1) + vfmadd231ps(xmm1, xmm3, xmm5) + vmovsd(xmm5, mem(rcx, 4*4)) + + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-5 + vmovups(ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rcx ), xmm1) + vmovss(mem(rcx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rcx, rsi, 2), xmm1) + vmovss(mem(rcx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rcx, rsi, 4), xmm1) + vmovss(mem(rcx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vextractf128(imm(0x1), ymm4, xmm5) + vmovups(xmm4, mem(rcx, 0*4)) + vmovsd(xmm5, mem(rcx, 4*4)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vmovups(ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +#if 0 + +void bli_sgemmsup_rv_haswell_asm_1x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 0*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 0*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vmovups(ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rcx ), xmm1) + vmovss(mem(rcx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rcx, rsi, 2), xmm1) + vmovss(mem(rcx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rcx, rsi, 4), xmm1) + vmovss(mem(rcx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rcx, rax, 2), xmm1) + vmovss(mem(rcx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vmovups(ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} +#endif diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx8.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx8.c new file mode 100644 index 0000000000..2b1a221ada --- /dev/null +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx8.c @@ -0,0 +1,2895 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) + + +void bli_sgemmsup_rv_haswell_asm_6x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 5*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 5*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm14, ymm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm12) + vmovups(ymm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm14) + vmovups(ymm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx ), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(rdx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rdx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(rdx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx )) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rdx, rbx, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rax, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(rdx, rax, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rbp, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_5x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 4*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 4*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm12, ymm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm12) + vmovups(ymm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rdx, rax, 2), xmm1) + vmovss(mem(rdx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_4x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 3*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 3*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx ), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbx, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, rax, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rax, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, rbp, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx )) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma35 ) + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rax, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_3x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 2*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 2*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(mem(rcx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm8, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rdx ), xmm1) + vmovss(mem(rdx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rdx, rsi, 2), xmm1) + vmovss(mem(rdx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rdx, rsi, 4), xmm1) + vmovss(mem(rdx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rdx, rax, 2), xmm1) + vmovss(mem(rdx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + vmovups(ymm8, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rdx )) // store ( gamma40 ) + vmovss(xmm4, mem(rdx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rdx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rdx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rdx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rdx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rdx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rdx, rsi, 8), rdx) // rdx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 1*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 1*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx ), xmm1, xmm1) + vmovhpd(mem(rcx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(mem(rcx, rsi, 4), xmm1, xmm1) + vmovhpd(mem(rcx, rbx, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(mem(rcx, rax, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rbp, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vunpcklps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx )) // store ( gamma00..gamma10 ) + vmovhpd(xmm0, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovlpd(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma14 ) + vmovhpd(xmm2, mem(rcx, rbx, 1)) // store ( gamma05..gamma15 ) + + vunpckhps(ymm6, ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovhpd(xmm0, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + vmovlpd(xmm2, mem(rcx, rax, 2)) // store ( gamma06..gamma16 ) + vmovhpd(xmm2, mem(rcx, rbp, 1)) // store ( gamma07..gamma17 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_haswell_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rsi, 4, 0*8)) // prefetch c + 4*cs_c + lea(mem(rcx, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rbp, 1, 0*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + + + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + + label(.SPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + lea(mem(rax, rsi, 4), rbp) // rbp = 7*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORED) + + // begin I/O on columns 0-7 + vmovups(ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(mem(rcx ), xmm1) + vmovss(mem(rcx, rsi, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(mem(rcx, rsi, 2), xmm1) + vmovss(mem(rcx, rax, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(mem(rcx, rsi, 4), xmm1) + vmovss(mem(rcx, rbx, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(mem(rcx, rax, 2), xmm1) + vmovss(mem(rcx, rbp, 1), xmm6) + vfmadd231ps(xmm1, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + jmp(.SDONE) // jump to end. + + + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + + + label(.SROWSTORBZ) + + + vmovups(ymm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + + label(.SCOLSTORBZ) + + // begin I/O on columns 0-7 + vmovups(ymm4, ymm0) + vextractf128(imm(0x1), ymm0, xmm8) + + vpermilps(imm(0xe4), xmm0, xmm2) + vpermilps(imm(0x39), xmm0, xmm4) + vmovss(xmm2, mem(rcx )) // store ( gamma40 ) + vmovss(xmm4, mem(rcx, rsi, 1)) // store ( gamma41 ) + + vpermilps(imm(0x4e), xmm0, xmm2) + vpermilps(imm(0x93), xmm0, xmm4) + vmovss(xmm2, mem(rcx, rsi, 2)) // store ( gamma42 ) + vmovss(xmm4, mem(rcx, rax, 1)) // store ( gamma43 ) + + vpermilps(imm(0xe4), xmm8, xmm2) + vpermilps(imm(0x39), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rsi, 4)) // store ( gamma44 ) + vmovss(xmm4, mem(rcx, rbx, 1)) // store ( gamma45 ) + + vpermilps(imm(0x4e), xmm8, xmm2) + vpermilps(imm(0x93), xmm8, xmm4) + vmovss(xmm2, mem(rcx, rax, 2)) // store ( gamma46 ) + vmovss(xmm4, mem(rcx, rbp, 1)) // store ( gamma47 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + + + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/bli_kernels_haswell.h b/kernels/haswell/bli_kernels_haswell.h index 53d434dff8..1c35122a4e 100644 --- a/kernels/haswell/bli_kernels_haswell.h +++ b/kernels/haswell/bli_kernels_haswell.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,7 +33,23 @@ */ -// -- level-3 -- +// -- level-1m ----------------------------------------------------------------- + +// packm (asm) +PACKM_KER_PROT( float, s, packm_haswell_asm_6xk ) +PACKM_KER_PROT( float, s, packm_haswell_asm_16xk ) + +PACKM_KER_PROT( double, d, packm_haswell_asm_6xk ) +PACKM_KER_PROT( double, d, packm_haswell_asm_8xk ) + +PACKM_KER_PROT( scomplex, c, packm_haswell_asm_3xk ) +PACKM_KER_PROT( scomplex, c, packm_haswell_asm_8xk ) + +PACKM_KER_PROT( dcomplex, z, packm_haswell_asm_3xk ) +PACKM_KER_PROT( dcomplex, z, packm_haswell_asm_4xk ) + + +// -- level-3 ------------------------------------------------------------------ // gemm (asm d6x8) GEMM_UKR_PROT( float, s, gemm_haswell_asm_6x16 ) @@ -61,3 +78,214 @@ GEMMTRSM_UKR_PROT( double, d, gemmtrsm_u_haswell_asm_6x8 ) //GEMM_UKR_PROT( scomplex, c, gemm_haswell_asm_8x3 ) //GEMM_UKR_PROT( dcomplex, z, gemm_haswell_asm_4x3 ) + +// -- level-3 sup -------------------------------------------------------------- + +// -- single real -- + +// gemmsup_r + +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref_6x1 ) +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref_5x1 ) +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref_4x1 ) +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref_3x1 ) +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref_2x1 ) +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref_1x1 ) + +// gemmsup_rv + +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_6x16 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_5x16 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_4x16 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_3x16 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_2x16 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_1x16 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_6x12 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_5x12 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_4x12 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_3x12 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_2x12 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_1x12 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_6x8 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_5x8 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_4x8 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_3x8 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_2x8 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_1x8 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_6x6 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_5x6 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_4x6 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_3x6 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_2x6 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_1x6 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_6x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_5x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_4x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_3x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_2x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_1x4 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_6x2 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_5x2 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_4x2 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_3x2 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_2x2 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_1x2 ) + +// gemmsup_rv (mkernel in m dim) + +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_6x16m ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_6x12m ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_6x8m ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_6x6m ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_6x4m ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_6x2m ) + +// gemmsup_rv (mkernel in n dim) + +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_6x16n ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_5x16n ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_4x16n ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_3x16n ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_2x16n ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_haswell_asm_1x16n ) + +// gemmsup_rd + +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_6x16 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_2x16 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_1x16 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_6x12 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_2x12 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_1x12 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_6x8 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_2x8 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_1x8 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_6x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_2x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_1x4 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_6x2 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_3x2 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_2x2 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_1x2 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_6x1 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_3x1 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_2x1 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_1x1 ) + +// gemmsup_rd (mkernel in m dim) + +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_6x16m ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_6x12m ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_6x8m ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_6x4m ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_6x2m ) + +// gemmsup_rd (mkernel in n dim) + +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_6x16n ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_3x16n ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_2x16n ) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_haswell_asm_1x16n ) + + + +// -- double real -- + +// gemmsup_r + +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_6x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_5x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_4x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_3x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_2x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_1x1 ) + +// gemmsup_rv + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x8 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x6 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x6 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x6 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x6 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x6 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x6 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x4 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x2 ) + +// gemmsup_rv (mkernel in m dim) + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x6m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x4m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x2m ) + +// gemmsup_rv (mkernel in n dim) + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x8n ) + +// gemmsup_rd + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x8 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x4 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_3x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x2 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_3x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x1 ) + +// gemmsup_rd (mkernel in m dim) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x4m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x2m ) + +// gemmsup_rd (mkernel in n dim) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_3x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x8n ) + diff --git a/kernels/knc/3/bli_dgemm_knc_asm_30x8.c b/kernels/knc/3/bli_dgemm_knc_asm_30x8.c index 880632ae07..f20e43f7cc 100644 --- a/kernels/knc/3/bli_dgemm_knc_asm_30x8.c +++ b/kernels/knc/3/bli_dgemm_knc_asm_30x8.c @@ -256,6 +256,8 @@ extern int offsets[16]; //#define LOOPMON void bli_dgemm_knc_asm_30x8 ( + dim_t m, + dim_t n, dim_t k, double* restrict alpha, double* restrict a, @@ -273,80 +275,82 @@ void bli_dgemm_knc_asm_30x8 uint64_t k64 = k; + GEMM_UKR_SETUP_CT( d, 30, 8, true ); + #ifdef MONITORS int toph, topl, both, botl, midl, midh, mid2l, mid2h; #endif #ifdef LOOPMON int tlooph, tloopl, blooph, bloopl; #endif - + __asm { #ifdef MONITORS rdtsc mov topl, eax - mov toph, edx + mov toph, edx #endif vpxord zmm0, zmm0, zmm0 vmovaps zmm1, zmm0 //clear out registers - vmovaps zmm2, zmm0 + vmovaps zmm2, zmm0 mov rsi, k64 //loop index - vmovaps zmm3, zmm0 + vmovaps zmm3, zmm0 mov r11, rs_c //load row stride - vmovaps zmm4, zmm0 + vmovaps zmm4, zmm0 sal r11, 3 //scale row stride - vmovaps zmm5, zmm0 + vmovaps zmm5, zmm0 mov r15, a //load address of a - vmovaps zmm6, zmm0 + vmovaps zmm6, zmm0 mov rbx, b //load address of b - vmovaps zmm7, zmm0 + vmovaps zmm7, zmm0 - vmovaps zmm8, zmm0 + vmovaps zmm8, zmm0 lea r10, [r11 + 2*r11 + 0] //r10 has 3 * r11 vmovaps zmm9, zmm0 - vmovaps zmm10, zmm0 - mov rdi, r11 - vmovaps zmm11, zmm0 + vmovaps zmm10, zmm0 + mov rdi, r11 + vmovaps zmm11, zmm0 sal rdi, 2 //rdi has 4*r11 - vmovaps zmm12, zmm0 + vmovaps zmm12, zmm0 mov rcx, c //load address of c for prefetching - vmovaps zmm13, zmm0 - vmovaps zmm14, zmm0 + vmovaps zmm13, zmm0 + vmovaps zmm14, zmm0 mov r8, k64 - vmovaps zmm15, zmm0 + vmovaps zmm15, zmm0 vmovaps zmm16, zmm0 vmovaps zmm17, zmm0 mov r13, L2_PREFETCH_DIST*8*8 - vmovaps zmm18, zmm0 + vmovaps zmm18, zmm0 mov r14, L2_PREFETCH_DIST*8*32 - vmovaps zmm19, zmm0 - vmovaps zmm20, zmm0 - vmovaps zmm21, zmm0 - vmovaps zmm22, zmm0 + vmovaps zmm19, zmm0 + vmovaps zmm20, zmm0 + vmovaps zmm21, zmm0 + vmovaps zmm22, zmm0 - vmovaps zmm23, zmm0 + vmovaps zmm23, zmm0 sub r8, 30 + L2_PREFETCH_DIST //Check if we have over 40 operations to do. - vmovaps zmm24, zmm0 + vmovaps zmm24, zmm0 mov r8, 30 - vmovaps zmm25, zmm0 + vmovaps zmm25, zmm0 mov r9, 8*8 //amount to increment b* by each iteration - vmovaps zmm26, zmm0 + vmovaps zmm26, zmm0 mov r12, 32*8 //amount to increment a* by each iteration - vmovaps zmm27, zmm0 - vmovaps zmm28, zmm0 - vmovaps zmm29, zmm0 + vmovaps zmm27, zmm0 + vmovaps zmm28, zmm0 + vmovaps zmm29, zmm0 #ifdef MONITORS rdtsc mov midl, eax - mov midh, edx + mov midh, edx #endif jle CONSIDER_UNDER_40 sub rsi, 30 + L2_PREFETCH_DIST - + //First 30 iterations LOOPREFECHCL2: ONE_ITER_PC_L2(rcx) @@ -357,26 +361,26 @@ void bli_dgemm_knc_asm_30x8 LOOPMAIN: ONE_ITER_MAIN_LOOP(rcx, rsi) jne LOOPMAIN - + //Penultimate 22 iterations. //Break these off from the main loop to avoid prefetching extra shit. mov r14, a_next mov r13, b_next sub r14, r15 sub r13, rbx - + mov rsi, L2_PREFETCH_DIST-10 LOOPMAIN2: ONE_ITER_MAIN_LOOP(rcx, rsi) jne LOOPMAIN2 - - + + //Last 10 iterations mov r8, 10 LOOPREFETCHCL1: ONE_ITER_PC_L1(rcx) jne LOOPREFETCHCL1 - + jmp POSTACCUM @@ -403,14 +407,8 @@ void bli_dgemm_knc_asm_30x8 mov r9, c //load address of c for update mov r12, alpha //load address of alpha - // Check if C is row stride. If not, jump to the slow scattered update - mov r14, cs_c - dec r14 - jne SCATTEREDUPDATE - mov r14, beta - vbroadcastsd zmm31, 0[r14] - + vbroadcastsd zmm31, 0[r14] vmulpd zmm0, zmm0, 0[r12]{1to8} vmulpd zmm1, zmm1, 0[r12]{1to8} @@ -467,7 +465,7 @@ void bli_dgemm_knc_asm_30x8 vmovapd [r9+2*r11+0], zmm14 vmovapd [r9+r10+0], zmm15 add r9, rdi - + vmulpd zmm16, zmm16, 0[r12]{1to8} vmulpd zmm17, zmm17, 0[r12]{1to8} vmulpd zmm18, zmm18, 0[r12]{1to8} @@ -516,47 +514,6 @@ void bli_dgemm_knc_asm_30x8 vfmadd231pd zmm29, zmm31, [r9+r11+0] vmovapd [r9+0], zmm28 vmovapd [r9+r11+0], zmm29 - - jmp END - - SCATTEREDUPDATE: - mov r10, offsetPtr - vmovapd zmm31, 0[r10] - vpbroadcastd zmm30, cs_c - mov r13, beta - vpmulld zmm30, zmm31, zmm30 - - mov ebx, 255 - UPDATE_C_ROW_SCATTERED(zmm0, 0, r9) - UPDATE_C_ROW_SCATTERED(zmm1, 1, r9) - UPDATE_C_ROW_SCATTERED(zmm2, 2, r9) - UPDATE_C_ROW_SCATTERED(zmm3, 3, r9) - UPDATE_C_ROW_SCATTERED(zmm4, 4, r9) - UPDATE_C_ROW_SCATTERED(zmm5, 5, r9) - UPDATE_C_ROW_SCATTERED(zmm6, 6, r9) - UPDATE_C_ROW_SCATTERED(zmm7, 7, r9) - UPDATE_C_ROW_SCATTERED(zmm8, 8, r9) - UPDATE_C_ROW_SCATTERED(zmm9, 9, r9) - UPDATE_C_ROW_SCATTERED(zmm10, 10, r9) - UPDATE_C_ROW_SCATTERED(zmm11, 11, r9) - UPDATE_C_ROW_SCATTERED(zmm12, 12, r9) - UPDATE_C_ROW_SCATTERED(zmm13, 13, r9) - UPDATE_C_ROW_SCATTERED(zmm14, 14, r9) - UPDATE_C_ROW_SCATTERED(zmm15, 15, r9) - UPDATE_C_ROW_SCATTERED(zmm16, 16, r9) - UPDATE_C_ROW_SCATTERED(zmm17, 17, r9) - UPDATE_C_ROW_SCATTERED(zmm18, 18, r9) - UPDATE_C_ROW_SCATTERED(zmm19, 19, r9) - UPDATE_C_ROW_SCATTERED(zmm20, 20, r9) - UPDATE_C_ROW_SCATTERED(zmm21, 21, r9) - UPDATE_C_ROW_SCATTERED(zmm22, 22, r9) - UPDATE_C_ROW_SCATTERED(zmm23, 23, r9) - UPDATE_C_ROW_SCATTERED(zmm24, 24, r9) - UPDATE_C_ROW_SCATTERED(zmm25, 25, r9) - UPDATE_C_ROW_SCATTERED(zmm26, 26, r9) - UPDATE_C_ROW_SCATTERED(zmm27, 27, r9) - UPDATE_C_ROW_SCATTERED(zmm28, 28, r9) - UPDATE_C_ROW_SCATTERED(zmm29, 29, r9) END: #ifdef MONITORS @@ -566,6 +523,8 @@ void bli_dgemm_knc_asm_30x8 #endif } + GEMM_UKR_FLUSH_CT( d ); + #ifdef LOOPMON printf("looptime = \t%d\n", bloopl - tloopl); #endif diff --git a/kernels/knc/3/bli_sgemm_knc_asm_30x16.c b/kernels/knc/3/bli_sgemm_knc_asm_30x16.c index 866cb62ec1..18a8e5e2ee 100644 --- a/kernels/knc/3/bli_sgemm_knc_asm_30x16.c +++ b/kernels/knc/3/bli_sgemm_knc_asm_30x16.c @@ -256,6 +256,8 @@ int offsets[16] __attribute__((aligned(0x1000))) = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9 //#define LOOPMON void bli_sgemm_knc_asm_30x16 ( + dim_t m, + dim_t n, dim_t k, float* restrict alpha, float* restrict a, @@ -273,80 +275,82 @@ void bli_sgemm_knc_asm_30x16 uint64_t k64 = k; + GEMM_UKR_SETUP_CT( s, 30, 16, true ); + #ifdef MONITORS int toph, topl, both, botl, midl, midh, mid2l, mid2h; #endif #ifdef LOOPMON int tlooph, tloopl, blooph, bloopl; #endif - + __asm { #ifdef MONITORS rdtsc mov topl, eax - mov toph, edx + mov toph, edx #endif vpxord zmm0, zmm0, zmm0 vmovaps zmm1, zmm0 //clear out registers - vmovaps zmm2, zmm0 + vmovaps zmm2, zmm0 mov rsi, k64 //loop index - vmovaps zmm3, zmm0 + vmovaps zmm3, zmm0 mov r11, rs_c //load row stride - vmovaps zmm4, zmm0 + vmovaps zmm4, zmm0 sal r11, 2 //scale row stride - vmovaps zmm5, zmm0 + vmovaps zmm5, zmm0 mov r15, a //load address of a - vmovaps zmm6, zmm0 + vmovaps zmm6, zmm0 mov rbx, b //load address of b - vmovaps zmm7, zmm0 + vmovaps zmm7, zmm0 - vmovaps zmm8, zmm0 + vmovaps zmm8, zmm0 lea r10, [r11 + 2*r11 + 0] //r10 has 3 * r11 vmovaps zmm9, zmm0 - vmovaps zmm10, zmm0 - mov rdi, r11 - vmovaps zmm11, zmm0 + vmovaps zmm10, zmm0 + mov rdi, r11 + vmovaps zmm11, zmm0 sal rdi, 2 //rdi has 4*r11 - vmovaps zmm12, zmm0 + vmovaps zmm12, zmm0 mov rcx, c //load address of c for prefetching - vmovaps zmm13, zmm0 - vmovaps zmm14, zmm0 + vmovaps zmm13, zmm0 + vmovaps zmm14, zmm0 mov r8, k64 - vmovaps zmm15, zmm0 + vmovaps zmm15, zmm0 vmovaps zmm16, zmm0 vmovaps zmm17, zmm0 mov r13, L2_PREFETCH_DIST*4*16 - vmovaps zmm18, zmm0 + vmovaps zmm18, zmm0 mov r14, L2_PREFETCH_DIST*4*32 - vmovaps zmm19, zmm0 - vmovaps zmm20, zmm0 - vmovaps zmm21, zmm0 - vmovaps zmm22, zmm0 + vmovaps zmm19, zmm0 + vmovaps zmm20, zmm0 + vmovaps zmm21, zmm0 + vmovaps zmm22, zmm0 - vmovaps zmm23, zmm0 + vmovaps zmm23, zmm0 sub r8, 30 + L2_PREFETCH_DIST //Check if we have over 40 operations to do. - vmovaps zmm24, zmm0 + vmovaps zmm24, zmm0 mov r8, 30 - vmovaps zmm25, zmm0 + vmovaps zmm25, zmm0 mov r9, 16*4 //amount to increment b* by each iteration - vmovaps zmm26, zmm0 + vmovaps zmm26, zmm0 mov r12, 32*4 //amount to increment a* by each iteration - vmovaps zmm27, zmm0 - vmovaps zmm28, zmm0 - vmovaps zmm29, zmm0 + vmovaps zmm27, zmm0 + vmovaps zmm28, zmm0 + vmovaps zmm29, zmm0 #ifdef MONITORS rdtsc mov midl, eax - mov midh, edx + mov midh, edx #endif jle CONSIDER_UNDER_40 sub rsi, 30 + L2_PREFETCH_DIST - + //First 30 iterations LOOPREFECHCL2: ONE_ITER_PC_L2(rcx) @@ -357,26 +361,26 @@ void bli_sgemm_knc_asm_30x16 LOOPMAIN: ONE_ITER_MAIN_LOOP(rcx, rsi) jne LOOPMAIN - + //Penultimate 22 iterations. //Break these off from the main loop to avoid prefetching extra shit. mov r14, a_next mov r13, b_next sub r14, r15 sub r13, rbx - + mov rsi, L2_PREFETCH_DIST-10 LOOPMAIN2: ONE_ITER_MAIN_LOOP(rcx, rsi) jne LOOPMAIN2 - - + + //Last 10 iterations mov r8, 10 LOOPREFETCHCL1: ONE_ITER_PC_L1(rcx) jne LOOPREFETCHCL1 - + jmp POSTACCUM @@ -384,7 +388,7 @@ void bli_sgemm_knc_asm_30x16 //Used when <= 40 iterations CONSIDER_UNDER_40: mov rsi, k64 - test rsi, rsi + test rsi, rsi je POSTACCUM LOOP_UNDER_40: ONE_ITER_MAIN_LOOP(rcx, rsi) @@ -403,13 +407,8 @@ void bli_sgemm_knc_asm_30x16 mov r9, c //load address of c for update mov r12, alpha //load address of alpha - // Check if C is row stride. If not, jump to the slow scattered update - mov r14, cs_c - dec r14 - jne SCATTEREDUPDATE - mov r14, beta - vbroadcastss zmm31, 0[r14] + vbroadcastss zmm31, 0[r14] vmulps zmm0, zmm0, 0[r12]{1to16} @@ -467,7 +466,7 @@ void bli_sgemm_knc_asm_30x16 vmovaps [r9+2*r11+0], zmm14 vmovaps [r9+r10+0], zmm15 add r9, rdi - + vmulps zmm16, zmm16, 0[r12]{1to16} vmulps zmm17, zmm17, 0[r12]{1to16} vmulps zmm18, zmm18, 0[r12]{1to16} @@ -516,48 +515,6 @@ void bli_sgemm_knc_asm_30x16 vfmadd231ps zmm29, zmm31, [r9+r11+0] vmovaps [r9+0], zmm28 vmovaps [r9+r11+0], zmm29 - - jmp END - - SCATTEREDUPDATE: - - mov r10, offsetPtr - vmovaps zmm31, 0[r10] - vpbroadcastd zmm30, cs_c - mov r13, beta - vpmulld zmm30, zmm31, zmm30 - - mov ebx, 0xFFFF - UPDATE_C_ROW_SCATTERED(zmm0, 0, r9) - UPDATE_C_ROW_SCATTERED(zmm1, 1, r9) - UPDATE_C_ROW_SCATTERED(zmm2, 2, r9) - UPDATE_C_ROW_SCATTERED(zmm3, 3, r9) - UPDATE_C_ROW_SCATTERED(zmm4, 4, r9) - UPDATE_C_ROW_SCATTERED(zmm5, 5, r9) - UPDATE_C_ROW_SCATTERED(zmm6, 6, r9) - UPDATE_C_ROW_SCATTERED(zmm7, 7, r9) - UPDATE_C_ROW_SCATTERED(zmm8, 8, r9) - UPDATE_C_ROW_SCATTERED(zmm9, 9, r9) - UPDATE_C_ROW_SCATTERED(zmm10, 10, r9) - UPDATE_C_ROW_SCATTERED(zmm11, 11, r9) - UPDATE_C_ROW_SCATTERED(zmm12, 12, r9) - UPDATE_C_ROW_SCATTERED(zmm13, 13, r9) - UPDATE_C_ROW_SCATTERED(zmm14, 14, r9) - UPDATE_C_ROW_SCATTERED(zmm15, 15, r9) - UPDATE_C_ROW_SCATTERED(zmm16, 16, r9) - UPDATE_C_ROW_SCATTERED(zmm17, 17, r9) - UPDATE_C_ROW_SCATTERED(zmm18, 18, r9) - UPDATE_C_ROW_SCATTERED(zmm19, 19, r9) - UPDATE_C_ROW_SCATTERED(zmm20, 20, r9) - UPDATE_C_ROW_SCATTERED(zmm21, 21, r9) - UPDATE_C_ROW_SCATTERED(zmm22, 22, r9) - UPDATE_C_ROW_SCATTERED(zmm23, 23, r9) - UPDATE_C_ROW_SCATTERED(zmm24, 24, r9) - UPDATE_C_ROW_SCATTERED(zmm25, 25, r9) - UPDATE_C_ROW_SCATTERED(zmm26, 26, r9) - UPDATE_C_ROW_SCATTERED(zmm27, 27, r9) - UPDATE_C_ROW_SCATTERED(zmm28, 28, r9) - UPDATE_C_ROW_SCATTERED(zmm29, 29, r9) END: #ifdef MONITORS @@ -567,6 +524,8 @@ void bli_sgemm_knc_asm_30x16 #endif } + GEMM_UKR_FLUSH_CT( s ); + #ifdef LOOPMON printf("looptime = \t%d\n", bloopl - tloopl); #endif diff --git a/kernels/knl/1m/bli_dpackm_knl_asm_24x8.c b/kernels/knl/1m/bli_dpackm_knl_asm_24x8.c index 89b712a091..91fe1989f0 100644 --- a/kernels/knl/1m/bli_dpackm_knl_asm_24x8.c +++ b/kernels/knl/1m/bli_dpackm_knl_asm_24x8.c @@ -109,12 +109,13 @@ static int32_t offsets[32] __attribute__((aligned(64))) = void bli_dpackm_knl_asm_8xk ( conj_t conja, + pack_t schema, dim_t cdim_, dim_t n_, dim_t n_max_, - void* restrict kappa_, - void* restrict a_, inc_t inca_, inc_t lda_, - void* restrict p_, inc_t ldp_, + double* restrict kappa_, + double* restrict a_, inc_t inca_, inc_t lda_, + double* restrict p_, inc_t ldp_, cntx_t* restrict cntx ) { @@ -359,12 +360,13 @@ void bli_dpackm_knl_asm_8xk void bli_dpackm_knl_asm_24xk ( conj_t conja, + pack_t schema, dim_t cdim_, dim_t n_, dim_t n_max_, - void* restrict kappa_, - void* restrict a_, inc_t inca_, inc_t lda_, - void* restrict p_, inc_t ldp_, + double* restrict kappa_, + double* restrict a_, inc_t inca_, inc_t lda_, + double* restrict p_, inc_t ldp_, cntx_t* restrict cntx ) { diff --git a/kernels/knl/1m/bli_spackm_knl_asm_24x16.c b/kernels/knl/1m/bli_spackm_knl_asm_24x16.c index b1918289b6..8c4bdfe6be 100644 --- a/kernels/knl/1m/bli_spackm_knl_asm_24x16.c +++ b/kernels/knl/1m/bli_spackm_knl_asm_24x16.c @@ -111,12 +111,13 @@ static int32_t offsets[32] __attribute__((aligned(64))) = void bli_spackm_knl_asm_16xk ( conj_t conja, + pack_t schema, dim_t cdim_, dim_t n_, dim_t n_max_, - void* restrict kappa_, - void* restrict a_, inc_t inca_, inc_t lda_, - void* restrict p_, inc_t ldp_, + float* restrict kappa_, + float* restrict a_, inc_t inca_, inc_t lda_, + float* restrict p_, inc_t ldp_, cntx_t* restrict cntx ) { @@ -377,12 +378,13 @@ void bli_spackm_knl_asm_16xk void bli_spackm_knl_asm_24xk ( conj_t conja, + pack_t schema, dim_t cdim_, dim_t n_, dim_t n_max_, - void* restrict kappa_, - void* restrict a_, inc_t inca_, inc_t lda_, - void* restrict p_, inc_t ldp_, + float* restrict kappa_, + float* restrict a_, inc_t inca_, inc_t lda_, + float* restrict p_, inc_t ldp_, cntx_t* restrict cntx ) { diff --git a/kernels/knl/3/bli_dgemm_knl_asm_24x8.c b/kernels/knl/3/bli_dgemm_knl_asm_24x8.c index b794e7c059..a7f860ae02 100644 --- a/kernels/knl/3/bli_dgemm_knl_asm_24x8.c +++ b/kernels/knl/3/bli_dgemm_knl_asm_24x8.c @@ -185,6 +185,8 @@ static int32_t offsets[32] __attribute__((aligned(64))) = //#define LOOPMON void bli_dgemm_knl_asm_24x8 ( + dim_t m, + dim_t n, dim_t k_, double* restrict alpha, double* restrict a, @@ -201,10 +203,12 @@ void bli_dgemm_knl_asm_24x8 const double * a_next = bli_auxinfo_next_a( data ); const double * b_next = bli_auxinfo_next_b( data ); - const int32_t * offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_; - const int64_t cs_c = cs_c_; + int32_t * offsetPtr = &offsets[0]; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( d, 24, 8, true ); #ifdef MONITORS int toph, topl, both, botl, midl, midh, mid2l, mid2h; @@ -565,10 +569,7 @@ void bli_dgemm_knl_asm_24x8 // Check if C is row stride. If not, jump to the slow scattered update MOV(RAX, VAR(rs_c)) LEA(RAX, MEM(,RAX,8)) - MOV(RBX, VAR(cs_c)) LEA(RDI, MEM(RAX,RAX,2)) - CMP(RBX, IMM(1)) - JNE(SCATTEREDUPDATE) VMOVQ(RDX, XMM(1)) SAL(RDX) //shift out sign bit @@ -592,74 +593,6 @@ void bli_dgemm_knl_asm_24x8 UPDATE_C_BZ_FOUR_ROWS(24,25,26,27) UPDATE_C_BZ_FOUR_ROWS(28,29,30,31) - JMP(END) - - LABEL(SCATTEREDUPDATE) - - MOV(RDI, VAR(offsetPtr)) - VMOVAPS(ZMM(2), MEM(RDI)) - /* Note that this ignores the upper 32 bits in cs_c */ - VPBROADCASTD(ZMM(3), EBX) - VPMULLD(ZMM(2), ZMM(3), ZMM(2)) - - VMOVQ(RDX, XMM(1)) - SAL(RDX) //shift out sign bit - JZ(SCATTERBZ) - - UPDATE_C_ROW_SCATTERED( 8) - UPDATE_C_ROW_SCATTERED( 9) - UPDATE_C_ROW_SCATTERED(10) - UPDATE_C_ROW_SCATTERED(11) - UPDATE_C_ROW_SCATTERED(12) - UPDATE_C_ROW_SCATTERED(13) - UPDATE_C_ROW_SCATTERED(14) - UPDATE_C_ROW_SCATTERED(15) - UPDATE_C_ROW_SCATTERED(16) - UPDATE_C_ROW_SCATTERED(17) - UPDATE_C_ROW_SCATTERED(18) - UPDATE_C_ROW_SCATTERED(19) - UPDATE_C_ROW_SCATTERED(20) - UPDATE_C_ROW_SCATTERED(21) - UPDATE_C_ROW_SCATTERED(22) - UPDATE_C_ROW_SCATTERED(23) - UPDATE_C_ROW_SCATTERED(24) - UPDATE_C_ROW_SCATTERED(25) - UPDATE_C_ROW_SCATTERED(26) - UPDATE_C_ROW_SCATTERED(27) - UPDATE_C_ROW_SCATTERED(28) - UPDATE_C_ROW_SCATTERED(29) - UPDATE_C_ROW_SCATTERED(30) - UPDATE_C_ROW_SCATTERED(31) - - JMP(END) - - LABEL(SCATTERBZ) - - UPDATE_C_BZ_ROW_SCATTERED( 8) - UPDATE_C_BZ_ROW_SCATTERED( 9) - UPDATE_C_BZ_ROW_SCATTERED(10) - UPDATE_C_BZ_ROW_SCATTERED(11) - UPDATE_C_BZ_ROW_SCATTERED(12) - UPDATE_C_BZ_ROW_SCATTERED(13) - UPDATE_C_BZ_ROW_SCATTERED(14) - UPDATE_C_BZ_ROW_SCATTERED(15) - UPDATE_C_BZ_ROW_SCATTERED(16) - UPDATE_C_BZ_ROW_SCATTERED(17) - UPDATE_C_BZ_ROW_SCATTERED(18) - UPDATE_C_BZ_ROW_SCATTERED(19) - UPDATE_C_BZ_ROW_SCATTERED(20) - UPDATE_C_BZ_ROW_SCATTERED(21) - UPDATE_C_BZ_ROW_SCATTERED(22) - UPDATE_C_BZ_ROW_SCATTERED(23) - UPDATE_C_BZ_ROW_SCATTERED(24) - UPDATE_C_BZ_ROW_SCATTERED(25) - UPDATE_C_BZ_ROW_SCATTERED(26) - UPDATE_C_BZ_ROW_SCATTERED(27) - UPDATE_C_BZ_ROW_SCATTERED(28) - UPDATE_C_BZ_ROW_SCATTERED(29) - UPDATE_C_BZ_ROW_SCATTERED(30) - UPDATE_C_BZ_ROW_SCATTERED(31) - LABEL(END) #ifdef MONITORS @@ -701,6 +634,8 @@ void bli_dgemm_knl_asm_24x8 "zmm30", "zmm31", "memory" ) + GEMM_UKR_FLUSH_CT( d ); + #ifdef LOOPMON printf("looptime = \t%d\n", bloopl - tloopl); #endif diff --git a/kernels/knl/3/bli_sgemm_knl_asm_24x16.c b/kernels/knl/3/bli_sgemm_knl_asm_24x16.c index 8657d24ef5..64feba09f1 100644 --- a/kernels/knl/3/bli_sgemm_knl_asm_24x16.c +++ b/kernels/knl/3/bli_sgemm_knl_asm_24x16.c @@ -182,13 +182,15 @@ static int32_t offsets[32] __attribute__((aligned(64))) = //#define LOOPMON void bli_sgemm_knl_asm_24x16 ( + dim_t m, + dim_t n, dim_t k_, - double* restrict alpha, - double* restrict a, - double* restrict b, - double* restrict beta, - double* restrict c, inc_t rs_c_, inc_t cs_c_, - auxinfo_t* data, + float* restrict alpha, + float* restrict a, + float* restrict b, + float* restrict beta, + float* restrict c, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* restrict data, cntx_t* restrict cntx ) { @@ -198,10 +200,12 @@ void bli_sgemm_knl_asm_24x16 const double * a_next = bli_auxinfo_next_a( data ); const double * b_next = bli_auxinfo_next_b( data ); - const int32_t * offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_; - const int64_t cs_c = cs_c_; + int32_t * offsetPtr = &offsets[0]; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( s, 24, 16, true ); #ifdef MONITORS int toph, topl, both, botl, midl, midh, mid2l, mid2h; @@ -562,10 +566,7 @@ void bli_sgemm_knl_asm_24x16 // Check if C is row stride. If not, jump to the slow scattered update MOV(RAX, VAR(rs_c)) LEA(RAX, MEM(,RAX,4)) - MOV(RBX, VAR(cs_c)) LEA(RDI, MEM(RAX,RAX,2)) - CMP(RBX, IMM(1)) - JNE(SCATTEREDUPDATE) VMOVD(EDX, XMM(1)) SAL(EDX) //shift out sign bit @@ -589,74 +590,6 @@ void bli_sgemm_knl_asm_24x16 UPDATE_C_BZ_FOUR_ROWS(24,25,26,27) UPDATE_C_BZ_FOUR_ROWS(28,29,30,31) - JMP(END) - - LABEL(SCATTEREDUPDATE) - - MOV(RDI, VAR(offsetPtr)) - VMOVAPS(ZMM(2), MEM(RDI)) - /* Note that this ignores the upper 32 bits in cs_c */ - VPBROADCASTD(ZMM(3), EBX) - VPMULLD(ZMM(2), ZMM(3), ZMM(2)) - - VMOVD(EDX, XMM(1)) - SAL(EDX) //shift out sign bit - JZ(SCATTERBZ) - - UPDATE_C_ROW_SCATTERED( 8) - UPDATE_C_ROW_SCATTERED( 9) - UPDATE_C_ROW_SCATTERED(10) - UPDATE_C_ROW_SCATTERED(11) - UPDATE_C_ROW_SCATTERED(12) - UPDATE_C_ROW_SCATTERED(13) - UPDATE_C_ROW_SCATTERED(14) - UPDATE_C_ROW_SCATTERED(15) - UPDATE_C_ROW_SCATTERED(16) - UPDATE_C_ROW_SCATTERED(17) - UPDATE_C_ROW_SCATTERED(18) - UPDATE_C_ROW_SCATTERED(19) - UPDATE_C_ROW_SCATTERED(20) - UPDATE_C_ROW_SCATTERED(21) - UPDATE_C_ROW_SCATTERED(22) - UPDATE_C_ROW_SCATTERED(23) - UPDATE_C_ROW_SCATTERED(24) - UPDATE_C_ROW_SCATTERED(25) - UPDATE_C_ROW_SCATTERED(26) - UPDATE_C_ROW_SCATTERED(27) - UPDATE_C_ROW_SCATTERED(28) - UPDATE_C_ROW_SCATTERED(29) - UPDATE_C_ROW_SCATTERED(30) - UPDATE_C_ROW_SCATTERED(31) - - JMP(END) - - LABEL(SCATTERBZ) - - UPDATE_C_BZ_ROW_SCATTERED( 8) - UPDATE_C_BZ_ROW_SCATTERED( 9) - UPDATE_C_BZ_ROW_SCATTERED(10) - UPDATE_C_BZ_ROW_SCATTERED(11) - UPDATE_C_BZ_ROW_SCATTERED(12) - UPDATE_C_BZ_ROW_SCATTERED(13) - UPDATE_C_BZ_ROW_SCATTERED(14) - UPDATE_C_BZ_ROW_SCATTERED(15) - UPDATE_C_BZ_ROW_SCATTERED(16) - UPDATE_C_BZ_ROW_SCATTERED(17) - UPDATE_C_BZ_ROW_SCATTERED(18) - UPDATE_C_BZ_ROW_SCATTERED(19) - UPDATE_C_BZ_ROW_SCATTERED(20) - UPDATE_C_BZ_ROW_SCATTERED(21) - UPDATE_C_BZ_ROW_SCATTERED(22) - UPDATE_C_BZ_ROW_SCATTERED(23) - UPDATE_C_BZ_ROW_SCATTERED(24) - UPDATE_C_BZ_ROW_SCATTERED(25) - UPDATE_C_BZ_ROW_SCATTERED(26) - UPDATE_C_BZ_ROW_SCATTERED(27) - UPDATE_C_BZ_ROW_SCATTERED(28) - UPDATE_C_BZ_ROW_SCATTERED(29) - UPDATE_C_BZ_ROW_SCATTERED(30) - UPDATE_C_BZ_ROW_SCATTERED(31) - LABEL(END) #ifdef MONITORS @@ -698,6 +631,8 @@ void bli_sgemm_knl_asm_24x16 "zmm30", "zmm31", "memory" ) + GEMM_UKR_FLUSH_CT( s ); + #ifdef LOOPMON printf("looptime = \t%d\n", bloopl - tloopl); #endif diff --git a/kernels/knl/bli_kernels_knl.h b/kernels/knl/bli_kernels_knl.h index 652793cde9..f0b17c49d0 100644 --- a/kernels/knl/bli_kernels_knl.h +++ b/kernels/knl/bli_kernels_knl.h @@ -32,11 +32,12 @@ */ -GEMM_UKR_PROT( double, s, gemm_knl_asm_24x16 ) +GEMM_UKR_PROT( float, s, gemm_knl_asm_24x16 ) + GEMM_UKR_PROT( double, d, gemm_knl_asm_24x8 ) -PACKM_KER_PROT( double, s, packm_knl_asm_24xk ) -PACKM_KER_PROT( double, s, packm_knl_asm_16xk ) +PACKM_KER_PROT( float, s, packm_knl_asm_24xk ) +PACKM_KER_PROT( float, s, packm_knl_asm_16xk ) PACKM_KER_PROT( double, d, packm_knl_asm_24xk ) PACKM_KER_PROT( double, d, packm_knl_asm_8xk ) diff --git a/kernels/penryn/1/bli_axpyv_penryn_int.c b/kernels/penryn/1/bli_axpyv_penryn_int.c index 5031e0a329..53904b6452 100644 --- a/kernels/penryn/1/bli_axpyv_penryn_int.c +++ b/kernels/penryn/1/bli_axpyv_penryn_int.c @@ -73,7 +73,7 @@ void bli_daxpyv_penryn_int v2df_t x1v, x2v, x3v, x4v; v2df_t y1v, y2v, y3v, y4v; - bool_t use_ref = FALSE; + bool use_ref = FALSE; if ( bli_zero_dim1( n ) ) return; diff --git a/kernels/penryn/1/bli_dotv_penryn_int.c b/kernels/penryn/1/bli_dotv_penryn_int.c index c5f0a434a7..4d39b3641d 100644 --- a/kernels/penryn/1/bli_dotv_penryn_int.c +++ b/kernels/penryn/1/bli_dotv_penryn_int.c @@ -71,7 +71,7 @@ void bli_ddotv_penryn_int v2df_t rho1v; v2df_t x1v, y1v; - bool_t use_ref = FALSE; + bool use_ref = FALSE; // If the vector lengths are zero, set rho to zero and return. if ( bli_zero_dim1( n ) ) diff --git a/kernels/penryn/1f/bli_axpy2v_penryn_int.c b/kernels/penryn/1f/bli_axpy2v_penryn_int.c index a76ed29159..5e8a2a9a1f 100644 --- a/kernels/penryn/1f/bli_axpy2v_penryn_int.c +++ b/kernels/penryn/1f/bli_axpy2v_penryn_int.c @@ -79,7 +79,7 @@ void bli_daxpy2v_penryn_int v2df_t x1v, y1v, z1v; v2df_t x2v, y2v, z2v; - bool_t use_ref = FALSE; + bool use_ref = FALSE; if ( bli_zero_dim1( n ) ) return; diff --git a/kernels/penryn/1f/bli_axpyf_penryn_int.c b/kernels/penryn/1f/bli_axpyf_penryn_int.c index 7faac4494d..66bb88ec6f 100644 --- a/kernels/penryn/1f/bli_axpyf_penryn_int.c +++ b/kernels/penryn/1f/bli_axpyf_penryn_int.c @@ -81,7 +81,7 @@ void bli_daxpyf_penryn_int v2df_t a10v, a11v, a12v, a13v, y1v; v2df_t chi0v, chi1v, chi2v, chi3v; - bool_t use_ref = FALSE; + bool use_ref = FALSE; if ( bli_zero_dim2( m, b_n ) ) return; diff --git a/kernels/penryn/1f/bli_dotaxpyv_penryn_int.c b/kernels/penryn/1f/bli_dotaxpyv_penryn_int.c index c72b5447bf..7602a7f282 100644 --- a/kernels/penryn/1f/bli_dotaxpyv_penryn_int.c +++ b/kernels/penryn/1f/bli_dotaxpyv_penryn_int.c @@ -77,7 +77,7 @@ void bli_ddotaxpyv_penryn_int v2df_t alphav, rhov; v2df_t x1v, y1v, z1v; - bool_t use_ref = FALSE; + bool use_ref = FALSE; // If the vector lengths are zero, set rho to zero and return. if ( bli_zero_dim1( n ) ) diff --git a/kernels/penryn/1f/bli_dotxaxpyf_penryn_int.c b/kernels/penryn/1f/bli_dotxaxpyf_penryn_int.c index 1a86a4e566..2deb4a4574 100644 --- a/kernels/penryn/1f/bli_dotxaxpyf_penryn_int.c +++ b/kernels/penryn/1f/bli_dotxaxpyf_penryn_int.c @@ -96,7 +96,7 @@ void bli_ddotxaxpyf_penryn_int v2df_t w2v, z2v; v2df_t psi0v, psi1v, betav, alphav; - bool_t use_ref = FALSE; + bool use_ref = FALSE; if ( bli_zero_dim1( b_n ) ) return; diff --git a/kernels/penryn/1f/bli_dotxf_penryn_int.c b/kernels/penryn/1f/bli_dotxf_penryn_int.c index 0df230d8fd..ad9dc5fbd1 100644 --- a/kernels/penryn/1f/bli_dotxf_penryn_int.c +++ b/kernels/penryn/1f/bli_dotxf_penryn_int.c @@ -82,7 +82,7 @@ void bli_ddotxf_penryn_int v2df_t rho0v, rho1v, rho2v, rho3v; v2df_t x0v, x1v, x2v, x3v, y0v, betav, alphav; - bool_t use_ref = FALSE; + bool use_ref = FALSE; if ( bli_zero_dim1( b_n ) ) return; diff --git a/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c b/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c index e52cc9e0e0..a3e39c3ac1 100644 --- a/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c +++ b/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c @@ -39,7 +39,9 @@ void bli_sgemm_penryn_asm_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -54,38 +56,40 @@ void bli_sgemm_penryn_asm_8x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( s, 8, 4, false, 16 ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r9) // load address of b_next. - + sub(imm(0-8*16), rax) // increment pointers to allow byte sub(imm(0-8*16), rbx) // offsets in the unrolled iterations. - + movaps(mem(rax, -8*16), xmm0) // initialize loop by pre-loading elements movaps(mem(rax, -7*16), xmm1) // of a and b. movaps(mem(rbx, -8*16), xmm2) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) mov(rdi, r12) // make a copy of cs_c (in bytes) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(2, mem(r9, 0*4)) // prefetch b_next - + xorps(xmm3, xmm3) xorps(xmm4, xmm4) xorps(xmm5, xmm5) xorps(xmm6, xmm6) - + prefetch(2, mem(rcx, 6*4)) // prefetch c + 0*cs_c xorps(xmm8, xmm8) xorps(xmm9, xmm9) @@ -98,33 +102,33 @@ void bli_sgemm_penryn_asm_8x4 prefetch(2, mem(r10, rdi, 1, 6*4)) // prefetch c + 3*cs_c xorps(xmm14, xmm14) xorps(xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.SLOOPKITER) // MAIN LOOP - + prefetch(0, mem(rax, (4*35+1)*8)) - + addps(xmm6, xmm10) // iteration 0 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + addps(xmm2, xmm8) movaps(mem(rbx, -7*16), xmm2) addps(xmm3, xmm12) @@ -132,7 +136,7 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -140,22 +144,22 @@ void bli_sgemm_penryn_asm_8x4 movaps(mem(rax, -6*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -5*16), xmm1) - - + + addps(xmm6, xmm10) // iteration 1 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + addps(xmm2, xmm8) movaps(mem(rbx, -6*16), xmm2) addps(xmm3, xmm12) @@ -163,7 +167,7 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -171,22 +175,22 @@ void bli_sgemm_penryn_asm_8x4 movaps(mem(rax, -4*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -3*16), xmm1) - - + + addps(xmm6, xmm10) // iteration 2 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + addps(xmm2, xmm8) movaps(mem(rbx, -5*16), xmm2) addps(xmm3, xmm12) @@ -194,7 +198,7 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -202,26 +206,26 @@ void bli_sgemm_penryn_asm_8x4 movaps(mem(rax, -2*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -1*16), xmm1) - - + + addps(xmm6, xmm10) // iteration 3 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + sub(imm(0-4*8*4), rax) // a += 4*8 (unroll x mr) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + sub(imm(0-4*4*4), r9) // b_next += 4*4 (unroll x nr) - + addps(xmm2, xmm8) movaps(mem(rbx, -4*16), xmm2) addps(xmm3, xmm12) @@ -229,9 +233,9 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + sub(imm(0-4*4*4), rbx) // b += 4*4 (unroll x nr) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -239,40 +243,40 @@ void bli_sgemm_penryn_asm_8x4 movaps(mem(rax, -8*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -7*16), xmm1) - + prefetch(2, mem(r9, 0*4)) // prefetch b_next[0] prefetch(2, mem(r9, 16*4)) // prefetch b_next[16] - - + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.SLOOPKLEFT) // EDGE LOOP - + addps(xmm6, xmm10) // iteration 0 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + addps(xmm2, xmm8) movaps(mem(rbx, -7*16), xmm2) addps(xmm3, xmm12) @@ -280,7 +284,7 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -288,40 +292,40 @@ void bli_sgemm_penryn_asm_8x4 movaps(mem(rax, -6*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -5*16), xmm1) - + sub(imm(0-1*8*4), rax) // a += 8 (1 x mr) sub(imm(0-1*4*4), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.SPOSTACCUM) - + addps(xmm6, xmm10) addps(xmm3, xmm14) addps(xmm4, xmm11) addps(xmm5, xmm15) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta movss(mem(rax), xmm6) // load alpha to bottom 4 bytes of xmm6 movss(mem(rbx), xmm7) // load beta to bottom 4 bytes of xmm7 pshufd(imm(0x00), xmm6, xmm6) // populate xmm6 with four alphas pshufd(imm(0x00), xmm7, xmm7) // populate xmm7 with four betas - - + + mov(var(rs_c), rsi) // load rs_c mov(rsi, r8) // make a copy of rs_c - + lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) lea(mem(rsi, rsi, 2), r11) // r11 = 3*(rs_c * sizeof(float)) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + // xmm8: xmm9: xmm10: xmm11: // ( ab00 ( ab01 ( ab02 ( ab03 // ab11 ab12 ab13 ab10 @@ -338,20 +342,20 @@ void bli_sgemm_penryn_asm_8x4 shufps(imm(0xd8), xmm11, xmm8) shufps(imm(0xd8), xmm10, xmm11) shufps(imm(0xd8), xmm4, xmm10) - + movaps(xmm8, xmm4) shufps(imm(0xd8), xmm10, xmm8) shufps(imm(0xd8), xmm4, xmm10) movaps(xmm9, xmm5) shufps(imm(0xd8), xmm11, xmm9) shufps(imm(0xd8), xmm5, xmm11) - + movaps(xmm13, xmm4) shufps(imm(0xd8), xmm12, xmm13) shufps(imm(0xd8), xmm15, xmm12) shufps(imm(0xd8), xmm14, xmm15) shufps(imm(0xd8), xmm4, xmm14) - + movaps(xmm12, xmm4) shufps(imm(0xd8), xmm14, xmm12) shufps(imm(0xd8), xmm4, xmm14) @@ -369,471 +373,133 @@ void bli_sgemm_penryn_asm_8x4 // ab50 ab51 ab52 ab53 // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - - - - // determine if - // c % 16 == 0, AND - // 8*cs_c % 16 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(1), r8) // set ZF if rs_c == 1. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(15), rcx) // set ZF if c & 16 is zero. - setz(bh) // bh = ( ZF == 1 ? 1 : 0 ); - test(imm(15), r12) // set ZF if (4*cs_c) & 16 is zero. - setz(al) // al = ( ZF == 1 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + // now avoid loading C if beta == 0 - + xorpd(xmm0, xmm0) // set xmm0 to zero. ucomisd(xmm0, xmm7) // check if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - movlps(mem(rcx), xmm0) // load c00 ~ c30 - movhps(mem(rcx, rsi, 1), xmm0) - movlps(mem(rcx, rsi, 2), xmm1) - movhps(mem(rcx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm8) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm8, xmm0) // add the gemm result, - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - movlps(mem(rdx), xmm0) // load c40 ~ c70 - movhps(mem(rdx, rsi, 1), xmm0) - movlps(mem(rdx, rsi, 2), xmm1) - movhps(mem(rdx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm12) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm12, xmm0) // add the gemm result, - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - movlps(mem(rcx), xmm0) // load c01 ~ c31 - movhps(mem(rcx, rsi, 1), xmm0) - movlps(mem(rcx, rsi, 2), xmm1) - movhps(mem(rcx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm9) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm9, xmm0) // add the gemm result, - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - movlps(mem(rdx), xmm0) // load c41 ~ c71 - movhps(mem(rdx, rsi, 1), xmm0) - movlps(mem(rdx, rsi, 2), xmm1) - movhps(mem(rdx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm13) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm13, xmm0) // add the gemm result, - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - movlps(mem(rcx), xmm0) // load c02 ~ c32 - movhps(mem(rcx, rsi, 1), xmm0) - movlps(mem(rcx, rsi, 2), xmm1) - movhps(mem(rcx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm10) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm10, xmm0) // add the gemm result, - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - movlps(mem(rdx), xmm0) // load c42 ~ c72 - movhps(mem(rdx, rsi, 1), xmm0) - movlps(mem(rdx, rsi, 2), xmm1) - movhps(mem(rdx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm14) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm14, xmm0) // add the gemm result, - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - movlps(mem(rcx), xmm0) // load c03 ~ c33 - movhps(mem(rcx, rsi, 1), xmm0) - movlps(mem(rcx, rsi, 2), xmm1) - movhps(mem(rcx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm11) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm11, xmm0) // add the gemm result, - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - - - - movlps(mem(rdx), xmm0) // load c43 ~ c73 - movhps(mem(rdx, rsi, 1), xmm0) - movlps(mem(rdx, rsi, 2), xmm1) - movhps(mem(rdx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm15) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm15, xmm0) // add the gemm result, - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORED) - - movaps(mem(rcx), xmm0) // load c00 ~ c30, - mulps(xmm6, xmm8) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm8, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c40 ~ c70, - mulps(xmm6, xmm12) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm12, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c01 ~ c31, - mulps(xmm6, xmm9) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm9, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c41 ~ c71, - mulps(xmm6, xmm13) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm13, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c02 ~ c32, - mulps(xmm6, xmm10) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm10, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c42 ~ c72, - mulps(xmm6, xmm14) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm14, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c03 ~ c33, - mulps(xmm6, xmm11) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm11, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - - - movaps(mem(rdx), xmm1) // load c43 ~ c73, - mulps(xmm6, xmm15) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm15, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - - jmp(.SDONE) // jump to end. - - - - + + movaps(mem(rcx), xmm0) // load c00 ~ c30, + mulps(xmm6, xmm8) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm8, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c40 ~ c70, + mulps(xmm6, xmm12) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm12, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c01 ~ c31, + mulps(xmm6, xmm9) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm9, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c41 ~ c71, + mulps(xmm6, xmm13) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm13, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c02 ~ c32, + mulps(xmm6, xmm10) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm10, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c42 ~ c72, + mulps(xmm6, xmm14) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm14, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c03 ~ c33, + mulps(xmm6, xmm11) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm11, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + + + movaps(mem(rdx), xmm1) // load c43 ~ c73, + mulps(xmm6, xmm15) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm15, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + + jmp(.SDONE) // jump to end. + label(.SBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - mulps(xmm6, xmm8) // scale by alpha, - movaps(xmm8, xmm0) - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - mulps(xmm6, xmm12) // scale by alpha, - movaps(xmm12, xmm0) - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - mulps(xmm6, xmm9) // scale by alpha, - movaps(xmm9, xmm0) - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - mulps(xmm6, xmm13) // scale by alpha, - movaps(xmm13, xmm0) - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - mulps(xmm6, xmm10) // scale by alpha, - movaps(xmm10, xmm0) - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - mulps(xmm6, xmm14) // scale by alpha, - movaps(xmm14, xmm0) - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - mulps(xmm6, xmm11) // scale by alpha, - movaps(xmm11, xmm0) - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - - - - mulps(xmm6, xmm15) // scale by alpha, - movaps(xmm15, xmm0) - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORBZ) - - // skip loading c00 ~ c30, - mulps(xmm6, xmm8) // scale by alpha, - movaps(xmm8, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c40 ~ c70, - mulps(xmm6, xmm12) // scale by alpha, - movaps(xmm12, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c01 ~ c31, - mulps(xmm6, xmm9) // scale by alpha, - movaps(xmm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c41 ~ c71, - mulps(xmm6, xmm13) // scale by alpha, - movaps(xmm13, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c02 ~ c32, - mulps(xmm6, xmm10) // scale by alpha, - movaps(xmm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c42 ~ c72, - mulps(xmm6, xmm14) // scale by alpha, - movaps(xmm14, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c03 ~ c33, - mulps(xmm6, xmm11) // scale by alpha, - movaps(xmm11, mem(rcx)) // and store back to memory. - - // skip loading c43 ~ c73, - mulps(xmm6, xmm15) // scale by alpha, - movaps(xmm15, mem(rdx)) // and store back to memory. - - - - - - - - + + // skip loading c00 ~ c30, + mulps(xmm6, xmm8) // scale by alpha, + movaps(xmm8, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c40 ~ c70, + mulps(xmm6, xmm12) // scale by alpha, + movaps(xmm12, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c01 ~ c31, + mulps(xmm6, xmm9) // scale by alpha, + movaps(xmm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c41 ~ c71, + mulps(xmm6, xmm13) // scale by alpha, + movaps(xmm13, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c02 ~ c32, + mulps(xmm6, xmm10) // scale by alpha, + movaps(xmm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c42 ~ c72, + mulps(xmm6, xmm14) // scale by alpha, + movaps(xmm14, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c03 ~ c33, + mulps(xmm6, xmm11) // scale by alpha, + movaps(xmm11, mem(rcx)) // and store back to memory. + + // skip loading c43 ~ c73, + mulps(xmm6, xmm15) // scale by alpha, + movaps(xmm15, mem(rdx)) // and store back to memory. + label(.SDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next)/*, // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next)/*, // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "xmm0", "xmm1", "xmm2", "xmm3", @@ -842,11 +508,15 @@ void bli_sgemm_penryn_asm_8x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_penryn_asm_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -861,39 +531,41 @@ void bli_dgemm_penryn_asm_4x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( d, 4, 4, false, 16 ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r9) // load address of b_next. mov(var(a_next), r11) // load address of a_next. - + sub(imm(0-8*16), rax) // increment pointers to allow byte sub(imm(0-8*16), rbx) // offsets in the unrolled iterations. - + movaps(mem(rax, -8*16), xmm0) // initialize loop by pre-loading elements movaps(mem(rax, -7*16), xmm1) // of a and b. movaps(mem(rbx, -8*16), xmm2) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) mov(rdi, r12) // make a copy of cs_c (in bytes) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(2, mem(r9, 0*8)) // prefetch b_next - + xorpd(xmm3, xmm3) xorpd(xmm4, xmm4) xorpd(xmm5, xmm5) xorpd(xmm6, xmm6) - + prefetch(2, mem(rcx, 3*8)) // prefetch c + 0*cs_c xorpd(xmm8, xmm8) xorpd(xmm9, xmm9) @@ -906,22 +578,22 @@ void bli_dgemm_penryn_asm_4x4 prefetch(2, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c xorpd(xmm14, xmm14) xorpd(xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - + prefetch(0, mem(rax, (4*35+1)*8)) //prefetch(0, mem(rax, (8*97+4)*8)) - + //prefetch(0, mem(r11, 67*4*8)) // prefetch a_next[0] - + addpd(xmm3, xmm11) // iteration 0 movaps(mem(rbx, -7*16), xmm3) addpd(xmm4, xmm15) @@ -929,13 +601,13 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + addpd(xmm2, xmm9) movaps(mem(rbx, -6*16), xmm2) addpd(xmm4, xmm13) @@ -943,7 +615,7 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -951,9 +623,9 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -6*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -5*16), xmm1) - - - + + + addpd(xmm3, xmm11) // iteration 1 movaps(mem(rbx, -5*16), xmm3) addpd(xmm4, xmm15) @@ -961,13 +633,13 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + addpd(xmm2, xmm9) movaps(mem(rbx, -4*16), xmm2) addpd(xmm4, xmm13) @@ -975,7 +647,7 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -983,16 +655,16 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -4*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -3*16), xmm1) - - + + prefetch(0, mem(rax, (4*37+1)*8)) //prefetch(0, mem(rax, (8*97+12)*8)) - + //prefetch(0, mem(r11, 69*4*8)) // prefetch a_next[8] //sub(imm(-4*4*8), r11) // a_next += 4*4 (unroll x mr) - - - + + + addpd(xmm3, xmm11) // iteration 2 movaps(mem(rbx, -3*16), xmm3) addpd(xmm4, xmm15) @@ -1000,13 +672,13 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + addpd(xmm2, xmm9) movaps(mem(rbx, -2*16), xmm2) addpd(xmm4, xmm13) @@ -1014,8 +686,8 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - - + + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -1023,9 +695,9 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -2*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -1*16), xmm1) - - - + + + addpd(xmm3, xmm11) // iteration 3 movaps(mem(rbx, -1*16), xmm3) addpd(xmm4, xmm15) @@ -1033,17 +705,17 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + sub(imm(0-4*4*8), rax) // a += 4*4 (unroll x mr) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + sub(imm(0-4*4*8), r9) // b_next += 4*4 (unroll x nr) - + addpd(xmm2, xmm9) movaps(mem(rbx, 0*16), xmm2) addpd(xmm4, xmm13) @@ -1051,9 +723,9 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - + sub(imm(0-4*4*8), rbx) // b += 4*4 (unroll x nr) - + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -1061,29 +733,29 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -8*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -7*16), xmm1) - + prefetch(2, mem(r9, 0*8)) // prefetch b_next[0] prefetch(2, mem(r9, 8*8)) // prefetch b_next[8] - + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - + + + //prefetch(2, mem(r9, -8*8)) // prefetch b_next[-8] - - - + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + addpd(xmm3, xmm11) // iteration 0 movaps(mem(rbx, -7*16), xmm3) addpd(xmm4, xmm15) @@ -1091,13 +763,13 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + addpd(xmm2, xmm9) movaps(mem(rbx, -6*16), xmm2) addpd(xmm4, xmm13) @@ -1105,7 +777,7 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -1113,38 +785,38 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -6*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -5*16), xmm1) - - + + sub(imm(0-4*1*8), rax) // a += 4 (1 x mr) sub(imm(0-4*1*8), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + addpd(xmm3, xmm11) addpd(xmm4, xmm15) addpd(xmm5, xmm10) addpd(xmm6, xmm14) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta movddup(mem(rax), xmm6) // load alpha and duplicate movddup(mem(rbx), xmm7) // load beta and duplicate - - + + mov(var(rs_c), rsi) // load rs_c mov(rsi, r8) // make a copy of rs_c - + lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) - + lea(mem(rcx, rsi, 2), rdx) // load address of c + 2*rs_c; - + // xmm8: xmm9: xmm10: xmm11: // ( ab01 ( ab00 ( ab03 ( ab02 // ab10 ) ab11 ) ab12 ) ab13 ) @@ -1155,15 +827,15 @@ void bli_dgemm_penryn_asm_4x4 movaps(xmm8, xmm0) movsd(xmm9, xmm8) movsd(xmm0, xmm9) - + movaps(xmm10, xmm0) movsd(xmm11, xmm10) movsd(xmm0, xmm11) - + movaps(xmm12, xmm0) movsd(xmm13, xmm12) movsd(xmm0, xmm13) - + movaps(xmm14, xmm0) movsd(xmm15, xmm14) movsd(xmm0, xmm15) @@ -1174,313 +846,133 @@ void bli_dgemm_penryn_asm_4x4 // xmm12: xmm13: xmm14: xmm15: // ( ab20 ( ab21 ( ab22 ( ab23 // ab30 ) ab31 ) ab32 ) ab33 ) - - - - // determine if - // c % 16 == 0, AND - // 8*cs_c % 16 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(1), r8) // set ZF if rs_c == 1. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(15), rcx) // set ZF if c & 16 is zero. - setz(bh) // bh = ( ZF == 1 ? 1 : 0 ); - test(imm(15), r12) // set ZF if (8*cs_c) & 16 is zero. - setz(al) // al = ( ZF == 1 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + // now avoid loading C if beta == 0 - + xorpd(xmm0, xmm0) // set xmm0 to zero. ucomisd(xmm0, xmm7) // check if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.DCOLSTORED) // jump to column storage case - - - - label(.DGENSTORED) - - movlpd(mem(rcx), xmm0) // load c00 and c10, - movhpd(mem(rcx, rsi, 1), xmm0) - mulpd(xmm6, xmm8) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm8, xmm0) // add the gemm result, - movlpd(xmm0, mem(rcx)) // and store back to memory. - movhpd(xmm0, mem(rcx, rsi, 1)) - add(rdi, rcx) - - movlpd(mem(rdx), xmm1) // load c20 and c30, - movhpd(mem(rdx, rsi, 1), xmm1) - mulpd(xmm6, xmm12) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm12, xmm1) // add the gemm result, - movlpd(xmm1, mem(rdx)) // and store back to memory. - movhpd(xmm1, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - - movlpd(mem(rcx), xmm0) // load c01 and c11, - movhpd(mem(rcx, rsi, 1), xmm0) - mulpd(xmm6, xmm9) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm9, xmm0) // add the gemm result, - movlpd(xmm0, mem(rcx)) // and store back to memory. - movhpd(xmm0, mem(rcx, rsi, 1)) - add(rdi, rcx) - - movlpd(mem(rdx), xmm1) // load c21 and c31, - movhpd(mem(rdx, rsi, 1), xmm1) - mulpd(xmm6, xmm13) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm13, xmm1) // add the gemm result, - movlpd(xmm1, mem(rdx)) // and store back to memory. - movhpd(xmm1, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - - movlpd(mem(rcx), xmm0) // load c02 and c12, - movhpd(mem(rcx, rsi, 1), xmm0) - mulpd(xmm6, xmm10) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm10, xmm0) // add the gemm result, - movlpd(xmm0, mem(rcx)) // and store back to memory. - movhpd(xmm0, mem(rcx, rsi, 1)) - add(rdi, rcx) - - movlpd(mem(rdx), xmm1) // load c22 and c32, - movhpd(mem(rdx, rsi, 1), xmm1) - mulpd(xmm6, xmm14) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm14, xmm1) // add the gemm result, - movlpd(xmm1, mem(rdx)) // and store back to memory. - movhpd(xmm1, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - - movlpd(mem(rcx), xmm0) // load c03 and c13, - movhpd(mem(rcx, rsi, 1), xmm0) - mulpd(xmm6, xmm11) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm11, xmm0) // add the gemm result, - movlpd(xmm0, mem(rcx)) // and store back to memory. - movhpd(xmm0, mem(rcx, rsi, 1)) - - - movlpd(mem(rdx), xmm1) // load c23 and c33, - movhpd(mem(rdx, rsi, 1), xmm1) - mulpd(xmm6, xmm15) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm15, xmm1) // add the gemm result, - movlpd(xmm1, mem(rdx)) // and store back to memory. - movhpd(xmm1, mem(rdx, rsi, 1)) - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORED) - - movaps(mem(rcx), xmm0) // load c00 and c10, - mulpd(xmm6, xmm8) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm8, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c20 and c30, - mulpd(xmm6, xmm12) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm12, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c01 and c11, - mulpd(xmm6, xmm9) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm9, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c21 and c31, - mulpd(xmm6, xmm13) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm13, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c02 and c12, - mulpd(xmm6, xmm10) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm10, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c22 and c32, - mulpd(xmm6, xmm14) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm14, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c03 and c13, - mulpd(xmm6, xmm11) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm11, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - - - movaps(mem(rdx), xmm1) // load c23 and c33, - mulpd(xmm6, xmm15) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm15, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - - jmp(.DDONE) // jump to end. - - - - + + movaps(mem(rcx), xmm0) // load c00 and c10, + mulpd(xmm6, xmm8) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm8, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c20 and c30, + mulpd(xmm6, xmm12) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm12, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c01 and c11, + mulpd(xmm6, xmm9) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm9, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c21 and c31, + mulpd(xmm6, xmm13) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm13, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c02 and c12, + mulpd(xmm6, xmm10) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm10, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c22 and c32, + mulpd(xmm6, xmm14) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm14, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c03 and c13, + mulpd(xmm6, xmm11) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm11, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + + + movaps(mem(rdx), xmm1) // load c23 and c33, + mulpd(xmm6, xmm15) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm15, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + + jmp(.DDONE) // jump to end. + label(.DBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - // skip loading c00 and c10, - mulpd(xmm6, xmm8) // scale by alpha, - movlpd(xmm8, mem(rcx)) // and store back to memory. - movhpd(xmm8, mem(rcx, rsi, 1)) - add(rdi, rcx) - // skip loading c20 and c30, - mulpd(xmm6, xmm12) // scale by alpha, - movlpd(xmm12, mem(rdx)) // and store back to memory. - movhpd(xmm12, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - // skip loading c01 and c11, - mulpd(xmm6, xmm9) // scale by alpha, - movlpd(xmm9, mem(rcx)) // and store back to memory. - movhpd(xmm9, mem(rcx, rsi, 1)) - add(rdi, rcx) - // skip loading c21 and c31, - mulpd(xmm6, xmm13) // scale by alpha, - movlpd(xmm13, mem(rdx)) // and store back to memory. - movhpd(xmm13, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - // skip loading c02 and c12, - mulpd(xmm6, xmm10) // scale by alpha, - movlpd(xmm10, mem(rcx)) // and store back to memory. - movhpd(xmm10, mem(rcx, rsi, 1)) - add(rdi, rcx) - // skip loading c22 and c32, - mulpd(xmm6, xmm14) // scale by alpha, - movlpd(xmm14, mem(rdx)) // and store back to memory. - movhpd(xmm14, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - // skip loading c03 and c13, - mulpd(xmm6, xmm11) // scale by alpha, - movlpd(xmm11, mem(rcx)) // and store back to memory. - movhpd(xmm11, mem(rcx, rsi, 1)) - - // skip loading c23 and c33, - mulpd(xmm6, xmm15) // scale by alpha, - movlpd(xmm15, mem(rdx)) // and store back to memory. - movhpd(xmm15, mem(rdx, rsi, 1)) - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORBZ) - - // skip loading c00 and c10, - mulpd(xmm6, xmm8) // scale by alpha, - movaps(xmm8, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c20 and c30, - mulpd(xmm6, xmm12) // scale by alpha, - movaps(xmm12, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c01 and c11, - mulpd(xmm6, xmm9) // scale by alpha, - movaps(xmm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c21 and c31, - mulpd(xmm6, xmm13) // scale by alpha, - movaps(xmm13, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c02 and c12, - mulpd(xmm6, xmm10) // scale by alpha, - movaps(xmm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c22 and c32, - mulpd(xmm6, xmm14) // scale by alpha, - movaps(xmm14, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c03 and c13, - mulpd(xmm6, xmm11) // scale by alpha, - movaps(xmm11, mem(rcx)) // and store back to memory. - - // skip loading c23 and c33, - mulpd(xmm6, xmm15) // scale by alpha, - movaps(xmm15, mem(rdx)) // and store back to memory. - - - - - - - - + + // skip loading c00 and c10, + mulpd(xmm6, xmm8) // scale by alpha, + movaps(xmm8, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c20 and c30, + mulpd(xmm6, xmm12) // scale by alpha, + movaps(xmm12, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c01 and c11, + mulpd(xmm6, xmm9) // scale by alpha, + movaps(xmm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c21 and c31, + mulpd(xmm6, xmm13) // scale by alpha, + movaps(xmm13, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c02 and c12, + mulpd(xmm6, xmm10) // scale by alpha, + movaps(xmm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c22 and c32, + mulpd(xmm6, xmm14) // scale by alpha, + movaps(xmm14, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c03 and c13, + mulpd(xmm6, xmm11) // scale by alpha, + movaps(xmm11, mem(rcx)) // and store back to memory. + + // skip loading c23 and c33, + mulpd(xmm6, xmm15) // scale by alpha, + movaps(xmm15, mem(rdx)) // and store back to memory. + label(.DDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "xmm0", "xmm1", "xmm2", "xmm3", @@ -1489,6 +981,8 @@ void bli_dgemm_penryn_asm_4x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/penryn/3/bli_gemmtrsm_l_penryn_asm_d4x4.c b/kernels/penryn/3/bli_gemmtrsm_l_penryn_asm_d4x4.c index 6739e262f7..7bef618faf 100644 --- a/kernels/penryn/3/bli_gemmtrsm_l_penryn_asm_d4x4.c +++ b/kernels/penryn/3/bli_gemmtrsm_l_penryn_asm_d4x4.c @@ -56,6 +56,8 @@ void bli_sgemmtrsm_l_penryn_asm_8x4 void bli_dgemmtrsm_l_penryn_asm_4x4 ( + dim_t m, + dim_t n, dim_t k0, double* restrict alpha, double* restrict a10, @@ -76,6 +78,8 @@ void bli_dgemmtrsm_l_penryn_asm_4x4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMMTRSM_UKR_SETUP_CT( d, 4, 4, false ); + begin_asm() mov(var(a10), rax) // load address of a10. @@ -415,8 +419,13 @@ void bli_dgemmtrsm_l_penryn_asm_4x4 movddup(mem(0+0*4)*8(rax), xmm0) // load xmm0 = (1/alpha00) - mulpd(xmm0, xmm8) // xmm8 *= (1/alpha00); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + mulpd(xmm0, xmm8) // xmm8 *= (1/alpha00); mulpd(xmm0, xmm12) // xmm12 *= (1/alpha00); +#else + divpd(xmm0, xmm8) // xmm8 /= alpha00; + divpd(xmm0, xmm12) // xmm12 /= alpha00; +#endif movaps(xmm8, mem(rbx, 0*16)) // store ( beta00 beta01 ) = xmm8 movaps(xmm12, mem(rbx, 1*16)) // store ( beta02 beta03 ) = xmm12 @@ -439,8 +448,13 @@ void bli_dgemmtrsm_l_penryn_asm_4x4 mulpd(xmm12, xmm4) // xmm4 = alpha10 * ( beta02 beta03 ) subpd(xmm0, xmm9) // xmm9 -= xmm0 subpd(xmm4, xmm13) // xmm13 -= xmm4 - mulpd(xmm1, xmm9) // xmm9 *= (1/alpha11); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + mulpd(xmm1, xmm9) // xmm9 *= (1/alpha11); mulpd(xmm1, xmm13) // xmm13 *= (1/alpha11); +#else + divpd(xmm1, xmm9) // xmm9 /= alpha11; + divpd(xmm1, xmm13) // xmm13 /= alpha11; +#endif movaps(xmm9, mem(rbx, 2*16)) // store ( beta10 beta11 ) = xmm9 movaps(xmm13, mem(rbx, 3*16)) // store ( beta12 beta13 ) = xmm13 @@ -469,8 +483,13 @@ void bli_dgemmtrsm_l_penryn_asm_4x4 addpd(xmm5, xmm4) // xmm4 += xmm5; subpd(xmm0, xmm10) // xmm10 -= xmm0 subpd(xmm4, xmm14) // xmm14 -= xmm4 +#ifdef BLIS_ENABLE_TRSM_PREINVERSION mulpd(xmm2, xmm10) // xmm10 *= (1/alpha22); mulpd(xmm2, xmm14) // xmm14 *= (1/alpha22); +#else + divpd(xmm2, xmm10) // xmm10 /= alpha22; + divpd(xmm2, xmm14) // xmm14 /= alpha22; +#endif movaps(xmm10, mem(rbx, 4*16)) // store ( beta20 beta21 ) = xmm10 movaps(xmm14, mem(rbx, 5*16)) // store ( beta22 beta23 ) = xmm14 @@ -505,8 +524,13 @@ void bli_dgemmtrsm_l_penryn_asm_4x4 addpd(xmm6, xmm4) // xmm4 += xmm6; subpd(xmm0, xmm11) // xmm11 -= xmm0 subpd(xmm4, xmm15) // xmm15 -= xmm4 +#ifdef BLIS_ENABLE_TRSM_PREINVERSION mulpd(xmm3, xmm11) // xmm11 *= (1/alpha33); mulpd(xmm3, xmm15) // xmm15 *= (1/alpha33); +#else + divpd(xmm3, xmm11) // xmm11 /= alpha33; + divpd(xmm3, xmm15) // xmm15 /= alpha33; +#endif movaps(xmm11, mem(rbx, 6*16)) // store ( beta30 beta31 ) = xmm11 movaps(xmm15, mem(rbx, 7*16)) // store ( beta32 beta33 ) = xmm15 @@ -541,6 +565,7 @@ void bli_dgemmtrsm_l_penryn_asm_4x4 "memory" ) + GEMMTRSM_UKR_FLUSH_CT( d ); } diff --git a/kernels/penryn/3/bli_gemmtrsm_u_penryn_asm_d4x4.c b/kernels/penryn/3/bli_gemmtrsm_u_penryn_asm_d4x4.c index 5c355aac8d..add12ea244 100644 --- a/kernels/penryn/3/bli_gemmtrsm_u_penryn_asm_d4x4.c +++ b/kernels/penryn/3/bli_gemmtrsm_u_penryn_asm_d4x4.c @@ -56,6 +56,8 @@ void bli_sgemmtrsm_u_penryn_asm_8x4 void bli_dgemmtrsm_u_penryn_asm_4x4 ( + dim_t m, + dim_t n, dim_t k0, double* restrict alpha, double* restrict a12, @@ -76,6 +78,8 @@ void bli_dgemmtrsm_u_penryn_asm_4x4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMMTRSM_UKR_SETUP_CT( d, 4, 4, false ); + begin_asm() mov(var(a12), rax) // load address of a12. @@ -401,8 +405,13 @@ void bli_dgemmtrsm_u_penryn_asm_4x4 movddup(mem(3+3*4)*8(rax), xmm3) // load xmm3 = (1/alpha33) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION mulpd(xmm3, xmm11) // xmm11 *= (1/alpha33); mulpd(xmm3, xmm15) // xmm15 *= (1/alpha33); +#else + divpd(xmm3, xmm11) // xmm11 /= alpha33; + divpd(xmm3, xmm15) // xmm15 /= alpha33; +#endif movaps(xmm11, mem(rbx, 6*16)) // store ( beta30 beta31 ) = xmm11 movaps(xmm15, mem(rbx, 7*16)) // store ( beta32 beta33 ) = xmm15 @@ -425,8 +434,13 @@ void bli_dgemmtrsm_u_penryn_asm_4x4 mulpd(xmm15, xmm7) // xmm7 = alpha23 * ( beta32 beta33 ) subpd(xmm3, xmm10) // xmm10 -= xmm3 subpd(xmm7, xmm14) // xmm14 -= xmm7 +#ifdef BLIS_ENABLE_TRSM_PREINVERSION mulpd(xmm2, xmm10) // xmm10 *= (1/alpha22); mulpd(xmm2, xmm14) // xmm14 *= (1/alpha22); +#else + divpd(xmm2, xmm10) // xmm10 /= alpha22; + divpd(xmm2, xmm14) // xmm14 /= alpha22; +#endif movaps(xmm10, mem(rbx, 4*16)) // store ( beta20 beta21 ) = xmm10 movaps(xmm14, mem(rbx, 5*16)) // store ( beta22 beta23 ) = xmm14 @@ -455,8 +469,13 @@ void bli_dgemmtrsm_u_penryn_asm_4x4 addpd(xmm7, xmm6) // xmm6 += xmm7; subpd(xmm2, xmm9) // xmm9 -= xmm2 subpd(xmm6, xmm13) // xmm13 -= xmm6 - mulpd(xmm1, xmm9) // xmm9 *= (1/alpha11); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + mulpd(xmm1, xmm9) // xmm9 *= (1/alpha11); mulpd(xmm1, xmm13) // xmm13 *= (1/alpha11); +#else + divpd(xmm1, xmm9) // xmm9 /= alpha11; + divpd(xmm1, xmm13) // xmm13 /= alpha11; +#endif movaps(xmm9, mem(rbx, 2*16)) // store ( beta10 beta11 ) = xmm9 movaps(xmm13, mem(rbx, 3*16)) // store ( beta12 beta13 ) = xmm13 @@ -491,8 +510,13 @@ void bli_dgemmtrsm_u_penryn_asm_4x4 addpd(xmm7, xmm5) // xmm5 += xmm7; subpd(xmm1, xmm8) // xmm8 -= xmm1 subpd(xmm5, xmm12) // xmm12 -= xmm5 - mulpd(xmm0, xmm8) // xmm8 *= (1/alpha00); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + mulpd(xmm0, xmm8) // xmm8 *= (1/alpha00); mulpd(xmm0, xmm12) // xmm12 *= (1/alpha00); +#else + divpd(xmm0, xmm8) // xmm8 /= alpha00; + divpd(xmm0, xmm12) // xmm12 /= alpha00; +#endif movaps(xmm8, mem(rbx, 0*16)) // store ( beta00 beta01 ) = xmm8 movaps(xmm12, mem(rbx, 1*16)) // store ( beta02 beta03 ) = xmm12 @@ -526,6 +550,7 @@ void bli_dgemmtrsm_u_penryn_asm_4x4 "memory" ) + GEMMTRSM_UKR_FLUSH_CT( d ); } diff --git a/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c b/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c index 5963dabee6..e65ce7178a 100644 --- a/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c +++ b/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c @@ -42,7 +42,9 @@ void bli_sgemm_piledriver_asm_16x3 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -57,36 +59,38 @@ void bli_sgemm_piledriver_asm_16x3 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 8; - uint64_t k_left = k0 % 8; + uint64_t k_iter = k / 8; + uint64_t k_left = k % 8; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( s, 16, 3, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. mov(var(a_next), r14) // load address of a_next. - + prefetch(0, mem(rbx, 128)) // prefetch b prefetch(0, mem(rbx, 64+128)) // prefetch b prefetch(0, mem(rbx, 128+128)) // prefetch b - + add(imm(32*4), rax) add(imm(12*4), rbx) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) lea(mem(rcx, rdi, 1), r10) // load address of c + 1*cs_c; lea(mem(rcx, rdi, 2), r11) // load address of c + 2*cs_c; - + vbroadcastss(mem(rbx, -12*4), xmm1) vbroadcastss(mem(rbx, -11*4), xmm2) vbroadcastss(mem(rbx, -10*4), xmm3) - + vxorps(xmm4, xmm4, xmm4) vxorps(xmm5, xmm5, xmm5) vxorps(xmm6, xmm6, xmm6) @@ -99,23 +103,23 @@ void bli_sgemm_piledriver_asm_16x3 vxorps(xmm13, xmm13, xmm13) vxorps(xmm14, xmm14, xmm14) vxorps(xmm15, xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.SLOOPKITER) // MAIN LOOP - - + + je(.SCONSIDKLEFT) // if i == 0, jump to k_left code. - - + + prefetch(0, mem(rbx, 16+192)) // prefetch b - + // iteration 0 vmovaps(mem(rax, -32*4), xmm0) prefetch(0, mem(rax, 384)) @@ -136,7 +140,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, -8*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 1 vmovaps(mem(rax, -16*4), xmm0) vbroadcastss(mem(rbx, -7*4), xmm3) @@ -158,7 +162,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, -5*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 2 vmovaps(mem(rax, 0*4), xmm0) vbroadcastss(mem(rbx, -4*4), xmm3) @@ -180,7 +184,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, -2*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 3 vmovaps(mem(rax, 16*4), xmm0) vbroadcastss(mem(rbx, -1*4), xmm3) @@ -202,10 +206,10 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, 1*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - - + + add(imm(4*16*4), rax) // a += 4*16 (unroll x mr) - + // iteration 4 vmovaps(mem(rax, -32*4), xmm0) vbroadcastss(mem(rbx, 2*4), xmm3) @@ -227,9 +231,9 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, 4*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + prefetch(0, mem(rbx, 80+192)) // prefetch b - + // iteration 5 vmovaps(mem(rax, -16*4), xmm0) vbroadcastss(mem(rbx, 5*4), xmm3) @@ -251,7 +255,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, 7*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 6 vmovaps(mem(rax, 0*4), xmm0) vbroadcastss(mem(rbx, 8*4), xmm3) @@ -273,7 +277,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, 10*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 7 vmovaps(mem(rax, 16*4), xmm0) vbroadcastss(mem(rbx, 11*4), xmm3) @@ -298,34 +302,34 @@ void bli_sgemm_piledriver_asm_16x3 vbroadcastss(mem(rbx, -11*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) vbroadcastss(mem(rbx, -10*4), xmm3) - - - - + + + + dec(rsi) // i -= 1; jmp(.SLOOPKITER) // jump to beginning of loop. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.SLOOPKLEFT) // EDGE LOOP - - + + je(.SPOSTACCUM) // if i == 0, we're done. - - + + prefetch(0, mem(rbx, 16+192)) // prefetch b - + // iteration 0 vmovaps(mem(rax, -32*4), xmm0) prefetch(0, mem(rax, 384)) @@ -347,56 +351,56 @@ void bli_sgemm_piledriver_asm_16x3 vbroadcastss(mem(rbx, -8*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) vbroadcastss(mem(rbx, -7*4), xmm3) - - + + add(imm(1*16*4), rax) // a += 4*16 (unroll x mr) add(imm(1*3*4), rbx) // a += 4*3 (unroll x nr) - - + + dec(rsi) // i -= 1; jmp(.SLOOPKLEFT) // jump to beginning of loop. - - - + + + label(.SPOSTACCUM) - - + + prefetchw0(mem(rcx, 0*8)) // prefetch c + 0*cs_c prefetchw0(mem(r10, 0*8)) // prefetch c + 1*cs_c prefetchw0(mem(r11, 0*8)) // prefetch c + 2*cs_c - - - // xmm4: xmm5: xmm6: + + + // xmm4: xmm5: xmm6: // ( ab00 ( ab01 ( ab02 - // ab10 ab11 ab12 + // ab10 ab11 ab12 // ab20 ab21 ab22 // ab30 ) ab31 ) ab32 ) - - // xmm7: xmm8: xmm9: + + // xmm7: xmm8: xmm9: // ( ab40 ( ab41 ( ab42 - // ab50 ab51 ab52 + // ab50 ab51 ab52 // ab60 ab61 ab62 // ab70 ) ab71 ) ab72 ) - + // xmm10: xmm11: xmm12: // ( ab80 ( ab01 ( ab02 - // ab90 ab11 ab12 + // ab90 ab11 ab12 // abA0 abA1 abA2 // abB0 ) abB1 ) abB2 ) - + // xmm13: xmm14: xmm15: // ( abC0 ( abC1 ( abC2 - // abD0 abD1 abD2 + // abD0 abD1 abD2 // abE0 abE1 abE2 // abF0 ) abF1 ) abF2 ) - - - + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), xmm0) // load alpha and duplicate vbroadcastss(mem(rbx), xmm2) // load beta and duplicate - + vmulps(xmm0, xmm4, xmm4) // scale by alpha vmulps(xmm0, xmm5, xmm5) vmulps(xmm0, xmm6, xmm6) @@ -409,32 +413,32 @@ void bli_sgemm_piledriver_asm_16x3 vmulps(xmm0, xmm13, xmm13) vmulps(xmm0, xmm14, xmm14) vmulps(xmm0, xmm15, xmm15) - - - + + + prefetch(0, mem(r14)) // prefetch a_next prefetch(0, mem(r14, 64)) // prefetch a_next - - - - + + + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - - - + + + // determine if // c % 32 == 0, AND // 4*cs_c % 32 == 0, AND // rs_c == 1 // ie: aligned, ldim aligned, and // column-stored - + cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4. sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); test(imm(31), rcx) // set ZF if c & 32 is zero. @@ -443,465 +447,69 @@ void bli_sgemm_piledriver_asm_16x3 setz(al) // al = ( ZF == 0 ? 1 : 0 ); // and(bl,bh) followed by // and(bh,al) will reveal result - + prefetch(0, mem(r15)) // prefetch b_next prefetch(0, mem(r15, 64)) // prefetch b_next - + // now avoid loading C if beta == 0 - + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. vucomiss(xmm0, xmm2) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - - vmovlps(mem(rcx), xmm0, xmm0) // load c00:c30 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm4, xmm0, xmm0) - vmovss(xmm0, mem(rcx)) // store c00:c30 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovlps(mem(rcx), xmm0, xmm0) // load c40:c70 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm7, xmm0, xmm0) - vmovss(xmm0, mem(rcx)) // store c40:c70 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovlps(mem(rcx), xmm0, xmm0) // load c80:cB0 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm10, xmm0, xmm0) - vmovss(xmm0, mem(rcx)) // store c80:cB0 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovlps(mem(rcx), xmm0, xmm0) // load cC0:cF0 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm13, xmm0, xmm0) - vmovss(xmm0, mem(rcx)) // store cC0:cF0 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovlps(mem(r10), xmm0, xmm0) // load c01:c31 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm1, xmm1) - vmovhps(mem(r10, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm5, xmm0, xmm0) - vmovss(xmm0, mem(r10)) // store c01:c31 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovlps(mem(r10), xmm0, xmm0) // load c41:c71 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm1, xmm1) - vmovhps(mem(r10, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm8, xmm0, xmm0) - vmovss(xmm0, mem(r10)) // store c41:c71 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovlps(mem(r10), xmm0, xmm0) // load c81:cB1 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm1, xmm1) - vmovhps(mem(r10, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm11, xmm0, xmm0) - vmovss(xmm0, mem(r10)) // store c81:cB1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovlps(mem(r10), xmm0, xmm0) // load cC1:cF1 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm1, xmm1) - vmovhps(mem(r10, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm14, xmm0, xmm0) - vmovss(xmm0, mem(r10)) // store cC1:cF1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovlps(mem(r11), xmm0, xmm0) // load c02:c32 - vmovhps(mem(r11, rsi, 1), xmm0, xmm0) - vmovlps(mem(r11, r12, 1), xmm1, xmm1) - vmovhps(mem(r11, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm6, xmm0, xmm0) - vmovss(xmm0, mem(r11)) // store c02:c32 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovlps(mem(r11), xmm0, xmm0) // load c42:c72 - vmovhps(mem(r11, rsi, 1), xmm0, xmm0) - vmovlps(mem(r11, r12, 1), xmm1, xmm1) - vmovhps(mem(r11, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm9, xmm0, xmm0) - vmovss(xmm0, mem(r11)) // store c42:c72 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovlps(mem(r11), xmm0, xmm0) // load c82:cB2 - vmovhps(mem(r11, rsi, 1), xmm0, xmm0) - vmovlps(mem(r11, r12, 1), xmm1, xmm1) - vmovhps(mem(r11, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm12, xmm0, xmm0) - vmovss(xmm0, mem(r11)) // store c82:cB2 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovlps(mem(r11), xmm0, xmm0) // load cC2:cF2 - vmovhps(mem(r11, rsi, 1), xmm0, xmm0) - vmovlps(mem(r11, r12, 1), xmm1, xmm1) - vmovhps(mem(r11, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm15, xmm0, xmm0) - vmovss(xmm0, mem(r11)) // store cC2:cF1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORED) - - - vfmadd231ps(mem(rcx, 0*16), xmm2, xmm4) - vfmadd231ps(mem(rcx, 1*16), xmm2, xmm7) - vfmadd231ps(mem(rcx, 2*16), xmm2, xmm10) - vfmadd231ps(mem(rcx, 3*16), xmm2, xmm13) - - vmovups(xmm4, mem(rcx, 0*16)) - vmovups(xmm7, mem(rcx, 1*16)) - vmovups(xmm10, mem(rcx, 2*16)) - vmovups(xmm13, mem(rcx, 3*16)) - - vfmadd231ps(mem(r10, 0*16), xmm2, xmm5) - vfmadd231ps(mem(r10, 1*16), xmm2, xmm8) - vfmadd231ps(mem(r10, 2*16), xmm2, xmm11) - vfmadd231ps(mem(r10, 3*16), xmm2, xmm14) - - vmovups(xmm5, mem(r10, 0*16)) - vmovups(xmm8, mem(r10, 1*16)) - vmovups(xmm11, mem(r10, 2*16)) - vmovups(xmm14, mem(r10, 3*16)) - - vfmadd231ps(mem(r11, 0*16), xmm2, xmm6) - vfmadd231ps(mem(r11, 1*16), xmm2, xmm9) - vfmadd231ps(mem(r11, 2*16), xmm2, xmm12) - vfmadd231ps(mem(r11, 3*16), xmm2, xmm15) - - vmovups(xmm6, mem(r11, 0*16)) - vmovups(xmm9, mem(r11, 1*16)) - vmovups(xmm12, mem(r11, 2*16)) - vmovups(xmm15, mem(r11, 3*16)) - - - - jmp(.SDONE) // jump to end. - - - + + vfmadd231ps(mem(rcx, 0*16), xmm2, xmm4) + vfmadd231ps(mem(rcx, 1*16), xmm2, xmm7) + vfmadd231ps(mem(rcx, 2*16), xmm2, xmm10) + vfmadd231ps(mem(rcx, 3*16), xmm2, xmm13) + + vfmadd231ps(mem(r10, 0*16), xmm2, xmm5) + vfmadd231ps(mem(r10, 1*16), xmm2, xmm8) + vfmadd231ps(mem(r10, 2*16), xmm2, xmm11) + vfmadd231ps(mem(r10, 3*16), xmm2, xmm14) + + vfmadd231ps(mem(r11, 0*16), xmm2, xmm6) + vfmadd231ps(mem(r11, 1*16), xmm2, xmm9) + vfmadd231ps(mem(r11, 2*16), xmm2, xmm12) + vfmadd231ps(mem(r11, 3*16), xmm2, xmm15) + + // fall through + label(.SBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - - vmovaps(xmm4, xmm0) - vmovss(xmm0, mem(rcx)) // store c00:c30 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovaps(xmm7, xmm0) - vmovss(xmm0, mem(rcx)) // store c40:c70 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovaps(xmm10, xmm0) - vmovss(xmm0, mem(rcx)) // store c80:cB0 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovaps(xmm13, xmm0) - vmovss(xmm0, mem(rcx)) // store cC0:cF0 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovaps(xmm5, xmm0) - vmovss(xmm0, mem(r10)) // store c01:c31 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovaps(xmm8, xmm0) - vmovss(xmm0, mem(r10)) // store c41:c71 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovaps(xmm11, xmm0) - vmovss(xmm0, mem(r10)) // store c81:cB1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovaps(xmm14, xmm0) - vmovss(xmm0, mem(r10)) // store cC1:cF1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovaps(xmm6, xmm0) - vmovss(xmm0, mem(r11)) // store c02:c32 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovaps(xmm9, xmm0) - vmovss(xmm0, mem(r11)) // store c42:c72 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovaps(xmm12, xmm0) - vmovss(xmm0, mem(r11)) // store c82:cB2 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovaps(xmm15, xmm0) - vmovss(xmm0, mem(r11)) // store cC2:cF1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORBZ) - - - vmovups(xmm4, mem(rcx, 0*16)) - vmovups(xmm7, mem(rcx, 1*16)) - vmovups(xmm10, mem(rcx, 2*16)) - vmovups(xmm13, mem(rcx, 3*16)) - - vmovups(xmm5, mem(r10, 0*16)) - vmovups(xmm8, mem(r10, 1*16)) - vmovups(xmm11, mem(r10, 2*16)) - vmovups(xmm14, mem(r10, 3*16)) - - vmovups(xmm6, mem(r11, 0*16)) - vmovups(xmm9, mem(r11, 1*16)) - vmovups(xmm12, mem(r11, 2*16)) - vmovups(xmm15, mem(r11, 3*16)) - - - - - - + + vmovups(xmm4, mem(rcx, 0*16)) + vmovups(xmm7, mem(rcx, 1*16)) + vmovups(xmm10, mem(rcx, 2*16)) + vmovups(xmm13, mem(rcx, 3*16)) + + vmovups(xmm5, mem(r10, 0*16)) + vmovups(xmm8, mem(r10, 1*16)) + vmovups(xmm11, mem(r10, 2*16)) + vmovups(xmm14, mem(r10, 3*16)) + + vmovups(xmm6, mem(r11, 0*16)) + vmovups(xmm9, mem(r11, 1*16)) + vmovups(xmm12, mem(r11, 2*16)) + vmovups(xmm15, mem(r11, 3*16)) + label(.SDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -909,11 +517,15 @@ void bli_sgemm_piledriver_asm_16x3 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_piledriver_asm_8x3 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -928,36 +540,38 @@ void bli_dgemm_piledriver_asm_8x3 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 8; - uint64_t k_left = k0 % 8; + uint64_t k_iter = k / 8; + uint64_t k_left = k % 8; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( d, 8, 3, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. mov(var(a_next), r14) // load address of a_next. - + prefetch(0, mem(rbx, 128)) // prefetch b prefetch(0, mem(rbx, 64+128)) // prefetch b prefetch(0, mem(rbx, 128+128)) // prefetch b - + add(imm(16*8), rax) add(imm(12*8), rbx) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) lea(mem(rcx, rdi, 1), r10) // load address of c + 1*cs_c; lea(mem(rcx, rdi, 2), r11) // load address of c + 2*cs_c; - + vmovddup(mem(rbx, -12*8), xmm1) vmovddup(mem(rbx, -11*8), xmm2) vmovddup(mem(rbx, -10*8), xmm3) - + vxorpd(xmm4, xmm4, xmm4) vxorpd(xmm5, xmm5, xmm5) vxorpd(xmm6, xmm6, xmm6) @@ -970,24 +584,24 @@ void bli_dgemm_piledriver_asm_8x3 vxorpd(xmm13, xmm13, xmm13) vxorpd(xmm14, xmm14, xmm14) vxorpd(xmm15, xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + je(.DCONSIDKLEFT) // if i == 0, jump to k_left code. - - + + prefetch(0, mem(rbx, -32+256)) // prefetch b prefetch(0, mem(rbx, 32+256)) // prefetch b - + // iteration 0 vmovaps(mem(rax, -8*16), xmm0) prefetch(0, mem(rax, 384)) // prefetch a @@ -1008,7 +622,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, -8*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 1 vmovaps(mem(rax, -4*16), xmm0) prefetch(0, mem(rax, 64+384)) // prefetch a @@ -1030,7 +644,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, -5*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 2 vmovaps(mem(rax, 0*16), xmm0) prefetch(0, mem(rax, 128+384)) // prefetch a @@ -1052,7 +666,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, -2*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 3 vmovaps(mem(rax, 4*16), xmm0) prefetch(0, mem(rax, 192+384)) // prefetch a @@ -1075,7 +689,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, 1*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 4 vmovaps(mem(rax, -8*16), xmm0) prefetch(0, mem(rax, 384)) // prefetch a @@ -1097,9 +711,9 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, 4*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + prefetch(0, mem(rbx, 96+256)) // prefetch b - + // iteration 5 vmovaps(mem(rax, -4*16), xmm0) prefetch(0, mem(rax, 64+384)) // prefetch a @@ -1121,8 +735,8 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, 7*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - - + + // iteration 6 vmovaps(mem(rax, 0*16), xmm0) prefetch(0, mem(rax, 128+384)) // prefetch a @@ -1144,7 +758,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, 10*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 7 vmovaps(mem(rax, 4*16), xmm0) prefetch(0, mem(rax, 192+384)) // prefetch a @@ -1169,31 +783,31 @@ void bli_dgemm_piledriver_asm_8x3 vmovddup(mem(rbx, -11*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) vmovddup(mem(rbx, -10*8), xmm3) - - - + + + dec(rsi) // i -= 1; jmp(.DLOOPKITER) // jump to beginning of loop. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done. // else, we prepare to // enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - - + + je(.DPOSTACCUM) // if i == 0, we're done. - + // iteration 0 vmovaps(mem(rax, -8*16), xmm0) prefetch(0, mem(rax, 512)) // prefetch a @@ -1215,48 +829,48 @@ void bli_dgemm_piledriver_asm_8x3 vmovddup(mem(rbx, -8*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) vmovddup(mem(rbx, -7*8), xmm3) - - + + add(imm(1*8*8), rax) // a += 1*8 (1 x mr) add(imm(1*3*8), rbx) // b += 1*3 (1 x nr) - - + + dec(rsi) // i -= 1; jmp(.DLOOPKLEFT) // jump to beginning of loop. - - - + + + label(.DPOSTACCUM) - + prefetchw0(mem(rcx, 0*8)) // prefetch c + 0*cs_c prefetchw0(mem(r10, 0*8)) // prefetch c + 1*cs_c prefetchw0(mem(r11, 0*8)) // prefetch c + 2*cs_c - - - // xmm4: xmm5: xmm6: - // ( ab00 ( ab01 ( ab02 + + + // xmm4: xmm5: xmm6: + // ( ab00 ( ab01 ( ab02 // ab10 ) ab11 ) ab12 ) // - // xmm7: xmm8: xmm9: - // ( ab20 ( ab21 ( ab22 + // xmm7: xmm8: xmm9: + // ( ab20 ( ab21 ( ab22 // ab30 ) ab31 ) ab32 ) // - // xmm10: xmm11: xmm12: - // ( ab40 ( ab41 ( ab42 + // xmm10: xmm11: xmm12: + // ( ab40 ( ab41 ( ab42 // ab50 ) ab51 ) ab52 ) // - // xmm13: xmm14: xmm15: - // ( ab60 ( ab61 ( ab62 + // xmm13: xmm14: xmm15: + // ( ab60 ( ab61 ( ab62 // ab70 ) ab71 ) ab72 ) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vmovddup(mem(rax), xmm0) // load alpha and duplicate vmovddup(mem(rbx), xmm2) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm5, xmm5) vmulpd(xmm0, xmm6, xmm6) @@ -1269,358 +883,89 @@ void bli_dgemm_piledriver_asm_8x3 vmulpd(xmm0, xmm13, xmm13) vmulpd(xmm0, xmm14, xmm14) vmulpd(xmm0, xmm15, xmm15) - - + + prefetch(0, mem(r14)) // prefetch a_next prefetch(0, mem(r14, 64)) // prefetch a_next - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) - - lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - - lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - - - - // determine if - // c % 32 == 0, AND - // 8*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (8*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + prefetch(0, mem(r15)) // prefetch b_next prefetch(0, mem(r15, 64)) // prefetch b_next - + // now avoid loading C if beta == 0 - + vxorpd(xmm0, xmm0, xmm0) // set xmm0 to zero. vucomisd(xmm0, xmm2) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - je(.DGENSTORED) // jump to column storage case - - - - label(.DCOLSTORED) - - // xmm4: xmm5: xmm6: - // ( ab00 ( ab01 ( ab02 - // ab10 ) ab11 ) ab12 ) - // - // xmm7: xmm8: xmm9: - // ( ab20 ( ab21 ( ab22 - // ab30 ) ab31 ) ab32 ) - // - // xmm10: xmm11: xmm12: - // ( ab40 ( ab41 ( ab42 - // ab50 ) ab51 ) ab52 ) - // - // xmm13: xmm14: xmm15: - // ( ab60 ( ab61 ( ab62 - // ab70 ) ab71 ) ab72 ) - - - vfmadd231pd(mem(rcx, 0*16), xmm2, xmm4) - vfmadd231pd(mem(rcx, 1*16), xmm2, xmm7) - vfmadd231pd(mem(rcx, 2*16), xmm2, xmm10) - vfmadd231pd(mem(rcx, 3*16), xmm2, xmm13) - - vfmadd231pd(mem(r10, 0*16), xmm2, xmm5) - vfmadd231pd(mem(r10, 1*16), xmm2, xmm8) - vfmadd231pd(mem(r10, 2*16), xmm2, xmm11) - vfmadd231pd(mem(r10, 3*16), xmm2, xmm14) - - vfmadd231pd(mem(r11, 0*16), xmm2, xmm6) - vfmadd231pd(mem(r11, 1*16), xmm2, xmm9) - vfmadd231pd(mem(r11, 2*16), xmm2, xmm12) - vfmadd231pd(mem(r11, 3*16), xmm2, xmm15) - - - vmovups(xmm4, mem(rcx, 0*16)) - vmovups(xmm7, mem(rcx, 1*16)) - vmovups(xmm10, mem(rcx, 2*16)) - vmovups(xmm13, mem(rcx, 3*16)) - - vmovups(xmm5, mem(r10, 0*16)) - vmovups(xmm8, mem(r10, 1*16)) - vmovups(xmm11, mem(r10, 2*16)) - vmovups(xmm14, mem(r10, 3*16)) - - vmovups(xmm6, mem(r11, 0*16)) - vmovups(xmm9, mem(r11, 1*16)) - vmovups(xmm12, mem(r11, 2*16)) - vmovups(xmm15, mem(r11, 3*16)) - - - - -/* - vmovupd(mem(rcx), xmm0) // load c00:c10 - vmovupd(mem(rcx, r12, 1), xmm1) // load c20:c30 - vfmadd231pd(xmm2, xmm0, xmm4) - vfmadd231pd(xmm2, xmm1, xmm7) - vmovupd(xmm4, mem(rcx)) // store c00:c10 - vmovupd(xmm7, mem(rcx, r12, 1)) // store c20:c30 - add(rdi, rcx) - - vmovupd(mem(rdx), xmm0) // load c40:c50 - vmovupd(mem(rdx, r12, 1), xmm1) // load c60:c70 - vfmadd213pd(xmm10, xmm2, xmm0) - vfmadd213pd(xmm13, xmm2, xmm1) - vmovupd(xmm0, mem(rdx)) // store c40:c50 - vmovupd(xmm1, mem(rdx, r12, 1)) // store c60:c70 - add(rdi, rdx) - - - vmovupd(mem(rcx), xmm0) // load c01:c11 - vmovupd(mem(rcx, r12, 1), xmm1) // load c21:c31 - vfmadd213pd(xmm5, xmm2, xmm0) - vfmadd213pd(xmm8, xmm2, xmm1) - vmovupd(xmm0, mem(rcx)) // store c01:c11 - vmovupd(xmm1, mem(rcx, r12, 1)) // store c21:c31 - add(rdi, rcx) - - vmovupd(mem(rdx), xmm0) // load c41:c51 - vmovupd(mem(rdx, r12, 1), xmm1) // load c61:c71 - vfmadd213pd(xmm11, xmm2, xmm0) - vfmadd213pd(xmm14, xmm2, xmm1) - vmovupd(xmm0, mem(rdx)) // store c41:c51 - vmovupd(xmm1, mem(rdx, r12, 1)) // store c61:c71 - add(rdi, rdx) - - - vmovupd(mem(rcx), xmm0) // load c02:c12 - vmovupd(mem(rcx, r12, 1), xmm1) // load c22:c32 - vfmadd213pd(xmm6, xmm2, xmm0) - vfmadd213pd(xmm9, xmm2, xmm1) - vmovupd(xmm0, mem(rcx)) // store c02:c12 - vmovupd(xmm1, mem(rcx, r12, 1)) // store c22:c32 - - vmovupd(mem(rdx), xmm0) // load c42:c52 - vmovupd(mem(rdx, r12, 1), xmm1) // load c62:c72 - vfmadd213pd(xmm12, xmm2, xmm0) - vfmadd213pd(xmm15, xmm2, xmm1) - vmovupd(xmm0, mem(rdx)) // store c42:c52 - vmovupd(xmm1, mem(rdx, r12, 1)) // store c62:c72 -*/ - - - - jmp(.DDONE) // jump to end. - - - - label(.DGENSTORED) - - - vmovlpd(mem(rcx), xmm0, xmm0) // load c00:c10 - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm4, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx)) // store c00:c10 - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c20:c30 - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm7, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx, r12, 1)) // store c20:c30 - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) - - vmovlpd(mem(rdx), xmm0, xmm0) // load c40:c50 - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm10, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx)) // store c40:c50 - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c60:c70 - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm13, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx, r12, 1)) // store c60:c70 - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) - - - vmovlpd(mem(rcx), xmm0, xmm0) // load c01:c11 - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm5, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx)) // store c01:c11 - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c21:c31 - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm8, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx, r12, 1)) // store c21:c31 - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) - - vmovlpd(mem(rdx), xmm0, xmm0) // load c41:c51 - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm11, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx)) // store c41:c51 - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c61:c71 - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm14, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx, r12, 1)) // store c61:c71 - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) - - - vmovlpd(mem(rcx), xmm0, xmm0) // load c02:c12 - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm6, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx)) // store c02:c12 - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c22:c32 - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm9, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx, r12, 1)) // store c22:c32 - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) - - vmovlpd(mem(rdx), xmm0, xmm0) // load c42:c52 - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm12, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx)) // store c42:c52 - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c62:c72 - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm15, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx, r12, 1)) // store c62:c72 - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) - - - - jmp(.DDONE) // jump to end. - - - + + // xmm4: xmm5: xmm6: + // ( ab00 ( ab01 ( ab02 + // ab10 ) ab11 ) ab12 ) + // + // xmm7: xmm8: xmm9: + // ( ab20 ( ab21 ( ab22 + // ab30 ) ab31 ) ab32 ) + // + // xmm10: xmm11: xmm12: + // ( ab40 ( ab41 ( ab42 + // ab50 ) ab51 ) ab52 ) + // + // xmm13: xmm14: xmm15: + // ( ab60 ( ab61 ( ab62 + // ab70 ) ab71 ) ab72 ) + + vfmadd231pd(mem(rcx, 0*16), xmm2, xmm4) + vfmadd231pd(mem(rcx, 1*16), xmm2, xmm7) + vfmadd231pd(mem(rcx, 2*16), xmm2, xmm10) + vfmadd231pd(mem(rcx, 3*16), xmm2, xmm13) + + vfmadd231pd(mem(r10, 0*16), xmm2, xmm5) + vfmadd231pd(mem(r10, 1*16), xmm2, xmm8) + vfmadd231pd(mem(r10, 2*16), xmm2, xmm11) + vfmadd231pd(mem(r10, 3*16), xmm2, xmm14) + + vfmadd231pd(mem(r11, 0*16), xmm2, xmm6) + vfmadd231pd(mem(r11, 1*16), xmm2, xmm9) + vfmadd231pd(mem(r11, 2*16), xmm2, xmm12) + vfmadd231pd(mem(r11, 3*16), xmm2, xmm15) + + // fall through + label(.DBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - - - vmovlpd(xmm4, mem(rcx)) - vmovhpd(xmm4, mem(rcx, rsi, 1)) - vmovlpd(xmm7, mem(rcx, r12, 1)) - vmovhpd(xmm7, mem(rcx, r13, 1)) - add(rdi, rcx) - vmovlpd(xmm10, mem(rdx)) - vmovhpd(xmm10, mem(rdx, rsi, 1)) - vmovlpd(xmm13, mem(rdx, r12, 1)) - vmovhpd(xmm13, mem(rdx, r13, 1)) - add(rdi, rdx) - - vmovlpd(xmm5, mem(rcx)) - vmovhpd(xmm5, mem(rcx, rsi, 1)) - vmovlpd(xmm8, mem(rcx, r12, 1)) - vmovhpd(xmm8, mem(rcx, r13, 1)) - add(rdi, rcx) - vmovlpd(xmm11, mem(rdx)) - vmovhpd(xmm11, mem(rdx, rsi, 1)) - vmovlpd(xmm14, mem(rdx, r12, 1)) - vmovhpd(xmm14, mem(rdx, r13, 1)) - add(rdi, rdx) - - vmovlpd(xmm6, mem(rcx)) - vmovhpd(xmm6, mem(rcx, rsi, 1)) - vmovlpd(xmm9, mem(rcx, r12, 1)) - vmovhpd(xmm9, mem(rcx, r13, 1)) - add(rdi, rcx) - vmovlpd(xmm12, mem(rdx)) - vmovhpd(xmm12, mem(rdx, rsi, 1)) - vmovlpd(xmm15, mem(rdx, r12, 1)) - vmovhpd(xmm15, mem(rdx, r13, 1)) - add(rdi, rdx) - - - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORBZ) - - - vmovupd(xmm4, mem(rcx)) - vmovupd(xmm7, mem(rcx, r12, 1)) - add(rdi, rcx) - vmovupd(xmm10, mem(rdx)) - vmovupd(xmm13, mem(rdx, r12, 1)) - add(rdi, rdx) - - vmovupd(xmm5, mem(rcx)) - vmovupd(xmm8, mem(rcx, r12, 1)) - add(rdi, rcx) - vmovupd(xmm11, mem(rdx)) - vmovupd(xmm14, mem(rdx, r12, 1)) - add(rdi, rdx) - - vmovupd(xmm6, mem(rcx)) - vmovupd(xmm9, mem(rcx, r12, 1)) - add(rdi, rcx) - vmovupd(xmm12, mem(rdx)) - vmovupd(xmm15, mem(rdx, r12, 1)) - add(rdi, rdx) - - - - - + + vmovups(xmm4, mem(rcx, 0*16)) + vmovups(xmm7, mem(rcx, 1*16)) + vmovups(xmm10, mem(rcx, 2*16)) + vmovups(xmm13, mem(rcx, 3*16)) + + vmovups(xmm5, mem(r10, 0*16)) + vmovups(xmm8, mem(r10, 1*16)) + vmovups(xmm11, mem(r10, 2*16)) + vmovups(xmm14, mem(r10, 3*16)) + + vmovups(xmm6, mem(r11, 0*16)) + vmovups(xmm9, mem(r11, 1*16)) + vmovups(xmm12, mem(r11, 2*16)) + vmovups(xmm15, mem(r11, 3*16)) + label(.DDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1628,11 +973,15 @@ void bli_dgemm_piledriver_asm_8x3 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } void bli_cgemm_piledriver_asm_4x2 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1647,28 +996,30 @@ void bli_cgemm_piledriver_asm_4x2 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 8; - uint64_t k_left = k0 % 8; + uint64_t k_iter = k / 8; + uint64_t k_left = k % 8; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( c, 4, 2, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. mov(var(a_next), r14) // load address of a_next. - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(scomplex) lea(mem(rcx, rdi, 1), r10) // load address of c + 1*cs_c; - + add(imm(32*4), rax) add(imm(16*4), rbx) - - + + vxorps(xmm8, xmm8, xmm8) vxorps(xmm9, xmm9, xmm9) vxorps(xmm10, xmm10, xmm10) @@ -1678,24 +1029,24 @@ void bli_cgemm_piledriver_asm_4x2 vxorps(xmm14, xmm14, xmm14) vxorps(xmm15, xmm15, xmm15) //vzeroall() - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.CLOOPKITER) // MAIN LOOP - - + + je(.CCONSIDKLEFT) // if i == 0, jump to k_left code. - - + + prefetch(0, mem(rbx, 256)) prefetch(0, mem(rax, 512)) - + // iteration 0 vmovaps(mem(rax, -32*4), xmm0) vbroadcastss(mem(rbx, -16*4), xmm4) @@ -1711,7 +1062,7 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -13*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + // iteration 1 vmovaps(mem(rax, -24*4), xmm0) vbroadcastss(mem(rbx, -12*4), xmm4) @@ -1727,10 +1078,10 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -9*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 64+256)) prefetch(0, mem(rax, 64+512)) - + // iteration 2 vmovaps(mem(rax, -16*4), xmm0) vbroadcastss(mem(rbx, -8*4), xmm4) @@ -1746,7 +1097,7 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -5*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + // iteration 3 vmovaps(mem(rax, -8*4), xmm0) vbroadcastss(mem(rbx, -4*4), xmm4) @@ -1762,10 +1113,10 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -1*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 128+256)) prefetch(0, mem(rax, 128+512)) - + // iteration 4 vmovaps(mem(rax, 0*4), xmm0) vbroadcastss(mem(rbx, 0*4), xmm4) @@ -1781,7 +1132,7 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, 3*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + // iteration 5 vmovaps(mem(rax, 8*4), xmm0) vbroadcastss(mem(rbx, 4*4), xmm4) @@ -1797,10 +1148,10 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, 7*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 128+256)) prefetch(0, mem(rax, 128+512)) - + // iteration 6 vmovaps(mem(rax, 16*4), xmm0) vbroadcastss(mem(rbx, 8*4), xmm4) @@ -1816,7 +1167,7 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, 11*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + // iteration 7 vmovaps(mem(rax, 24*4), xmm0) vbroadcastss(mem(rbx, 12*4), xmm4) @@ -1834,33 +1185,33 @@ void bli_cgemm_piledriver_asm_4x2 add(imm(8*2*8), rbx) // b += 8*2 (unroll x nr) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - - - + + + dec(rsi) // i -= 1; jmp(.CLOOPKITER) // jump to beginning of loop. - - - - - - + + + + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.CLOOPKLEFT) // EDGE LOOP - - + + je(.CPOSTACCUM) // if i == 0, we're done. - + prefetch(0, mem(rbx, 256)) prefetch(0, mem(rax, 512)) - + // iteration 0 vmovaps(mem(rax, -32*4), xmm0) vbroadcastss(mem(rbx, -16*4), xmm4) @@ -1876,123 +1227,88 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -13*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - - + + add(imm(1*4*8), rax) // a += 1*2 (1 x mr) add(imm(1*2*8), rbx) // b += 1*2 (1 x nr) - - + + dec(rsi) // i -= 1; jmp(.CLOOPKLEFT) // jump to beginning of loop. - - - + + + label(.CPOSTACCUM) - - + + prefetchw0(mem(rcx, 0*8)) // prefetch c + 0*cs_c prefetchw0(mem(r10, 0*8)) // prefetch c + 1*cs_c - - + + vpermilps(imm(0xb1), xmm9, xmm9) vpermilps(imm(0xb1), xmm11, xmm11) vpermilps(imm(0xb1), xmm13, xmm13) vpermilps(imm(0xb1), xmm15, xmm15) - + vaddsubps(xmm9, xmm8, xmm8) vaddsubps(xmm11, xmm10, xmm10) vaddsubps(xmm13, xmm12, xmm12) vaddsubps(xmm15, xmm14, xmm14) - - + + // xmm8: xmm10: // ( ab00 ( ab01 // ab10 ab11 // ab20 ab21 // ab30 ) ab31 ) - + // xmm12: xmm14: // ( ab40 ( ab41 // ab50 ab51 // ab60 ab61 // ab70 ) ab71 ) - - + + prefetch(0, mem(r14)) // prefetch a_next prefetch(0, mem(r14, 64)) // prefetch a_next - - + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vbroadcastss(mem(rax), xmm0) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), xmm1) // load alpha_i and duplicate - + vpermilps(imm(0xb1), xmm8, xmm9) vpermilps(imm(0xb1), xmm10, xmm11) vpermilps(imm(0xb1), xmm12, xmm13) vpermilps(imm(0xb1), xmm14, xmm15) - + vmulps(xmm8, xmm0, xmm8) vmulps(xmm10, xmm0, xmm10) vmulps(xmm12, xmm0, xmm12) vmulps(xmm14, xmm0, xmm14) - + vmulps(xmm9, xmm1, xmm9) vmulps(xmm11, xmm1, xmm11) vmulps(xmm13, xmm1, xmm13) vmulps(xmm15, xmm1, xmm15) - + vaddsubps(xmm9, xmm8, xmm8) vaddsubps(xmm11, xmm10, xmm10) vaddsubps(xmm13, xmm12, xmm12) vaddsubps(xmm15, xmm14, xmm14) - - - - + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), xmm6) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), xmm7) // load beta_i and duplicate - - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(scomplex) - - - lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - - - + prefetch(0, mem(r15)) // prefetch b_next prefetch(0, mem(r15, 64)) // prefetch b_next - - - - // determine if - // c % 32 == 0, AND - // 8*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (8*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + // now avoid loading C if beta == 0 - + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. vucomiss(xmm0, xmm6) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2000,175 +1316,66 @@ void bli_cgemm_piledriver_asm_4x2 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 0, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.CCOLSTORED) // jump to column storage case - - - - label(.CGENSTORED) - - - vmovlps(mem(rcx), xmm0, xmm0) // load c00:c10 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm2, xmm2) // load c20:c30 - vmovhps(mem(rcx, r13, 1), xmm2, xmm2) - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) - - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm8, xmm0, xmm0) - vmovlps(xmm0, mem(rcx)) // store c00:c10 - vmovhps(xmm0, mem(rcx, rsi, 1)) - - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm12, xmm2, xmm2) - vmovlps(xmm2, mem(rcx, r12, 1)) // store c20:c30 - vmovhps(xmm2, mem(rcx, r13, 1)) - - - - vmovlps(mem(r10), xmm0, xmm0) // load c01:c11 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm2, xmm2) // load c21:c31 - vmovhps(mem(r10, r13, 1), xmm2, xmm2) - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) - - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm10, xmm0, xmm0) - vmovlps(xmm0, mem(r10)) // store c01:c11 - vmovhps(xmm0, mem(r10, rsi, 1)) - - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm14, xmm2, xmm2) - vmovlps(xmm2, mem(r10, r12, 1)) // store c21:c31 - vmovhps(xmm2, mem(r10, r13, 1)) - - - - jmp(.CDONE) // jump to end. - - - - label(.CCOLSTORED) - - - vmovups(mem(rcx), xmm0) // load c00:c10 - vmovups(mem(rcx, 16), xmm2) // load c20:c30 - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) - - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm8, xmm0, xmm0) - vmovups(xmm0, mem(rcx)) // store c00:c10 - - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm12, xmm2, xmm2) - vmovups(xmm2, mem(rcx, 16)) // store c20:c30 - - - - vmovups(mem(r10), xmm0) // load c01:c11 - vmovups(mem(r10, 16), xmm2) // load c21:c31 - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) - - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm10, xmm0, xmm0) - vmovups(xmm0, mem(r10)) // store c01:c11 - - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm14, xmm2, xmm2) - vmovups(xmm2, mem(r10, 16)) // store c21:c31 - - - - jmp(.CDONE) // jump to end. - - - + + vmovups(mem(rcx), xmm0) // load c00:c10 + vmovups(mem(rcx, 16), xmm2) // load c20:c30 + vpermilps(imm(0xb1), xmm0, xmm1) + vpermilps(imm(0xb1), xmm2, xmm3) + + vmulps(xmm6, xmm0, xmm0) + vmulps(xmm7, xmm1, xmm1) + vaddsubps(xmm1, xmm0, xmm0) + vaddps(xmm8, xmm0, xmm0) + + vmulps(xmm6, xmm2, xmm2) + vmulps(xmm7, xmm3, xmm3) + vaddsubps(xmm3, xmm2, xmm2) + vaddps(xmm12, xmm2, xmm2) + + vmovups(mem(r10), xmm0) // load c01:c11 + vmovups(mem(r10, 16), xmm2) // load c21:c31 + vpermilps(imm(0xb1), xmm0, xmm1) + vpermilps(imm(0xb1), xmm2, xmm3) + + vmulps(xmm6, xmm0, xmm0) + vmulps(xmm7, xmm1, xmm1) + vaddsubps(xmm1, xmm0, xmm0) + vaddps(xmm10, xmm0, xmm0) + + vmulps(xmm6, xmm2, xmm2) + vmulps(xmm7, xmm3, xmm3) + vaddsubps(xmm3, xmm2, xmm2) + vaddps(xmm14, xmm2, xmm2) + + // fall through + label(.CBETAZERO) - // check if aligned/column-stored - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.CCOLSTORBZ) // jump to column storage case - - - - label(.CGENSTORBZ) - - - vmovlps(xmm8, mem(rcx)) // store c00:c10 - vmovhps(xmm8, mem(rcx, rsi, 1)) - - vmovlps(xmm12, mem(rcx, r12, 1)) // store c20:c30 - vmovhps(xmm12, mem(rcx, r13, 1)) - - vmovlps(xmm10, mem(r10)) // store c01:c11 - vmovhps(xmm10, mem(r10, rsi, 1)) - - vmovlps(xmm14, mem(r10, r12, 1)) // store c21:c31 - vmovhps(xmm14, mem(r10, r13, 1)) - - - - jmp(.CDONE) // jump to end. - - - - label(.CCOLSTORBZ) - - - vmovups(xmm8, mem(rcx)) // store c00:c10 - vmovups(xmm12, mem(rcx, 16)) // store c20:c30 - - vmovups(xmm10, mem(r10)) // store c01:c11 - vmovups(xmm14, mem(r10, 16)) // store c21:c31 - - - - - + + vmovups(xmm8, mem(rcx)) // store c00:c10 + vmovups(xmm12, mem(rcx, 16)) // store c20:c30 + + vmovups(xmm10, mem(r10)) // store c01:c11 + vmovups(xmm14, mem(r10, 16)) // store c21:c31 + label(.CDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2176,11 +1383,15 @@ void bli_cgemm_piledriver_asm_4x2 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } void bli_zgemm_piledriver_asm_2x2 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -2195,28 +1406,30 @@ void bli_zgemm_piledriver_asm_2x2 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 8; - uint64_t k_left = k0 % 8; + uint64_t k_iter = k / 8; + uint64_t k_left = k % 8; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( z, 2, 2, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. mov(var(a_next), r14) // load address of a_next. - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) lea(mem(rcx, rdi, 1), r10) // load address of c + 1*cs_c; - + add(imm(16*8), rax) add(imm(16*8), rbx) - + vxorpd(xmm8, xmm8, xmm8) vxorpd(xmm9, xmm9, xmm9) vxorpd(xmm10, xmm10, xmm10) @@ -2225,25 +1438,25 @@ void bli_zgemm_piledriver_asm_2x2 vxorpd(xmm13, xmm13, xmm13) vxorpd(xmm14, xmm14, xmm14) vxorpd(xmm15, xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.ZLOOPKITER) // MAIN LOOP - - + + je(.ZCONSIDKLEFT) // if i == 0, jump to k_left code. - - + + prefetch(0, mem(rbx, 256)) - + prefetch(0, mem(rax, 512)) - + // iteration 0 vmovaps(mem(rax, -16*8), xmm0) vmovddup(mem(rbx, -16*8), xmm4) @@ -2261,7 +1474,7 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, -12*8), xmm0) vmovddup(mem(rbx, -12*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + // iteration 1 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, -10*8), xmm1) @@ -2277,11 +1490,11 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, -8*8), xmm0) vmovddup(mem(rbx, -8*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 64+256)) - + prefetch(0, mem(rax, 64+512)) - + // iteration 2 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, -6*8), xmm1) @@ -2297,7 +1510,7 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, -4*8), xmm0) vmovddup(mem(rbx, -4*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + // iteration 3 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, -2*8), xmm1) @@ -2313,11 +1526,11 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, 0*8), xmm0) vmovddup(mem(rbx, 0*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 128+256)) - + prefetch(0, mem(rax, 128+512)) - + // iteration 4 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, 2*8), xmm1) @@ -2333,7 +1546,7 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, 4*8), xmm0) vmovddup(mem(rbx, 4*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + // iteration 5 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, 6*8), xmm1) @@ -2349,11 +1562,11 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, 8*8), xmm0) vmovddup(mem(rbx, 8*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 128+256)) - + prefetch(0, mem(rax, 128+512)) - + // iteration 6 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, 10*8), xmm1) @@ -2369,7 +1582,7 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, 12*8), xmm0) vmovddup(mem(rbx, 12*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + // iteration 7 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, 14*8), xmm1) @@ -2385,34 +1598,34 @@ void bli_zgemm_piledriver_asm_2x2 add(imm(8*2*16), rbx) // b += 8*2 (unroll x nr) vfmadd231pd(xmm0, xmm7, xmm11) vfmadd231pd(xmm1, xmm7, xmm15) - - - + + + dec(rsi) // i -= 1; jmp(.ZLOOPKITER) // jump to beginning of loop. - - - - - - + + + + + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.ZLOOPKLEFT) // EDGE LOOP - - + + je(.ZPOSTACCUM) // if i == 0, we're done. - + prefetch(0, mem(rbx, 256)) - + prefetch(0, mem(rax, 512)) - + // iteration 0 vmovaps(mem(rax, -16*8), xmm0) vmovddup(mem(rbx, -16*8), xmm4) @@ -2428,119 +1641,86 @@ void bli_zgemm_piledriver_asm_2x2 vmovddup(mem(rbx, -13*8), xmm7) vfmadd231pd(xmm0, xmm7, xmm11) vfmadd231pd(xmm1, xmm7, xmm15) - - + + add(imm(1*2*16), rax) // a += 1*2 (1 x mr) add(imm(1*2*16), rbx) // b += 1*2 (1 x nr) - - + + dec(rsi) // i -= 1; jmp(.ZLOOPKLEFT) // jump to beginning of loop. - - - + + + label(.ZPOSTACCUM) - - + + prefetchw0(mem(rcx, 0*8)) // prefetch c + 0*cs_c prefetchw0(mem(r10, 0*8)) // prefetch c + 1*cs_c - - + + vpermilpd(imm(0x1), xmm9, xmm9) vpermilpd(imm(0x1), xmm11, xmm11) vpermilpd(imm(0x1), xmm13, xmm13) vpermilpd(imm(0x1), xmm15, xmm15) - + vaddsubpd(xmm9, xmm8, xmm8) vaddsubpd(xmm11, xmm10, xmm10) vaddsubpd(xmm13, xmm12, xmm12) vaddsubpd(xmm15, xmm14, xmm14) - - + + // xmm8: xmm10: // ( ab00 ( ab01 // ab10 ) ab11 ) - + // xmm12: xmm14: // ( ab20 ( ab21 // ab30 ) ab31 ) - - + + prefetch(0, mem(r14)) // prefetch a_next prefetch(0, mem(r14, 64)) // prefetch a_next - - + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vmovddup(mem(rax), xmm0) // load alpha_r and duplicate vmovddup(mem(rax, 8), xmm1) // load alpha_i and duplicate - + vpermilpd(imm(0x1), xmm8, xmm9) vpermilpd(imm(0x1), xmm10, xmm11) vpermilpd(imm(0x1), xmm12, xmm13) vpermilpd(imm(0x1), xmm14, xmm15) - + vmulpd(xmm8, xmm0, xmm8) vmulpd(xmm10, xmm0, xmm10) vmulpd(xmm12, xmm0, xmm12) vmulpd(xmm14, xmm0, xmm14) - + vmulpd(xmm9, xmm1, xmm9) vmulpd(xmm11, xmm1, xmm11) vmulpd(xmm13, xmm1, xmm13) vmulpd(xmm15, xmm1, xmm15) - + vaddsubpd(xmm9, xmm8, xmm8) vaddsubpd(xmm11, xmm10, xmm10) vaddsubpd(xmm13, xmm12, xmm12) vaddsubpd(xmm15, xmm14, xmm14) - - - - + + + + mov(var(beta), rbx) // load address of beta vmovddup(mem(rbx), xmm6) // load beta_r and duplicate vmovddup(mem(rbx, 8), xmm7) // load beta_i and duplicate - - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(dcomplex) - lea(mem(, rsi, 2), rsi) - //lea(mem(rcx, rsi, 2), rdx) // load address of c + 2*rs_c; - - - - - + prefetch(0, mem(r15)) // prefetch b_next prefetch(0, mem(r15, 64)) // prefetch b_next - - - - // determine if - // c % 32 == 0, AND - // 16*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (16*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + // now avoid loading C if beta == 0 - + vxorpd(xmm0, xmm0, xmm0) // set xmm0 to zero. vucomisd(xmm0, xmm6) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2548,161 +1728,66 @@ void bli_zgemm_piledriver_asm_2x2 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 0, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.ZCOLSTORED) // jump to column storage case - - - - label(.ZGENSTORED) - - - vmovups(mem(rcx), xmm0) // load c00 - vmovups(mem(rcx, rsi, 1), xmm2) // load c10 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) - - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm8, xmm0, xmm0) - vmovups(xmm0, mem(rcx)) // store c00 - - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm12, xmm2, xmm2) - vmovups(xmm2, mem(rcx, rsi, 1)) // store c10 - - - - vmovups(mem(r10), xmm0) // load c01 - vmovups(mem(r10, rsi, 1), xmm2) // load c11 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) - - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm10, xmm0, xmm0) - vmovups(xmm0, mem(r10)) // store c01 - - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm14, xmm2, xmm2) - vmovups(xmm2, mem(r10, rsi, 1)) // store c11 - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORED) - - - vmovups(mem(rcx), xmm0) // load c00 - vmovups(mem(rcx, 16), xmm2) // load c10 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) - - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm8, xmm0, xmm0) - vmovups(xmm0, mem(rcx)) // store c00 - - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm12, xmm2, xmm2) - vmovups(xmm2, mem(rcx, 16)) // store c10 - - - - vmovups(mem(r10), xmm0) // load c01 - vmovups(mem(r10, 16), xmm2) // load c11 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) - - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm10, xmm0, xmm0) - vmovups(xmm0, mem(r10)) // store c01 - - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm14, xmm2, xmm2) - vmovups(xmm2, mem(r10, 16)) // store c11 - - - - jmp(.ZDONE) // jump to end. - - - + + vmovups(mem(rcx), xmm0) // load c00 + vmovups(mem(rcx, 16), xmm2) // load c10 + vpermilpd(imm(0x1), xmm0, xmm1) + vpermilpd(imm(0x1), xmm2, xmm3) + + vmulpd(xmm6, xmm0, xmm0) + vmulpd(xmm7, xmm1, xmm1) + vaddsubpd(xmm1, xmm0, xmm0) + vaddpd(xmm8, xmm0, xmm0) + + vmulpd(xmm6, xmm2, xmm2) + vmulpd(xmm7, xmm3, xmm3) + vaddsubpd(xmm3, xmm2, xmm2) + vaddpd(xmm12, xmm2, xmm2) + + vmovups(mem(r10), xmm0) // load c01 + vmovups(mem(r10, 16), xmm2) // load c11 + vpermilpd(imm(0x1), xmm0, xmm1) + vpermilpd(imm(0x1), xmm2, xmm3) + + vmulpd(xmm6, xmm0, xmm0) + vmulpd(xmm7, xmm1, xmm1) + vaddsubpd(xmm1, xmm0, xmm0) + vaddpd(xmm10, xmm0, xmm0) + + vmulpd(xmm6, xmm2, xmm2) + vmulpd(xmm7, xmm3, xmm3) + vaddsubpd(xmm3, xmm2, xmm2) + vaddpd(xmm14, xmm2, xmm2) + + // fall through + label(.ZBETAZERO) - // check if aligned/column-stored - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.ZCOLSTORBZ) // jump to column storage case - - - - label(.ZGENSTORBZ) - - - vmovups(xmm8, mem(rcx)) // store c00 - vmovups(xmm12, mem(rcx, rsi, 1)) // store c10 - - vmovups(xmm10, mem(r10)) // store c01 - vmovups(xmm14, mem(r10, rsi, 1)) // store c11 - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORBZ) - - - vmovups(xmm8, mem(rcx)) // store c00 - vmovups(xmm12, mem(rcx, 16)) // store c10 - - vmovups(xmm10, mem(r10)) // store c01 - vmovups(xmm14, mem(r10, 16)) // store c11 - - - - - + + vmovups(xmm8, mem(rcx)) // store c00 + vmovups(xmm12, mem(rcx, 16)) // store c10 + + vmovups(xmm10, mem(r10)) // store c01 + vmovups(xmm14, mem(r10, 16)) // store c11 + label(.ZDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2710,6 +1795,8 @@ void bli_zgemm_piledriver_asm_2x2 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/power10/3/bli_dgemm_power10_mma.c b/kernels/power10/3/bli_dgemm_power10_mma.c new file mode 100644 index 0000000000..84e7d16d34 --- /dev/null +++ b/kernels/power10/3/bli_dgemm_power10_mma.c @@ -0,0 +1,197 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +#include "vector_int_macros.h" + +#define D_ASSEMBLE_VEC_PAIR \ + __builtin_mma_assemble_pair (&colA_1, ca[1], ca[0]); \ + __builtin_mma_assemble_pair (&colA_2, ca[3], ca[2]); + +#define D_ACCUMULATE \ + __builtin_mma_xvf64gerpp (&acc0, colA_1, rb[0]); \ + __builtin_mma_xvf64gerpp (&acc1, colA_1, rb[1]); \ + __builtin_mma_xvf64gerpp (&acc2, colA_1, rb[2]); \ + __builtin_mma_xvf64gerpp (&acc3, colA_1, rb[3]); \ + __builtin_mma_xvf64gerpp (&acc4, colA_2, rb[0]); \ + __builtin_mma_xvf64gerpp (&acc5, colA_2, rb[1]); \ + __builtin_mma_xvf64gerpp (&acc6, colA_2, rb[2]); \ + __builtin_mma_xvf64gerpp (&acc7, colA_2, rb[3]); + +#define D_INCREMENT \ + A0+=8; \ + B0+=8; + +#define D_AB_PRODUCT \ + LOAD_VECTORS \ + D_ASSEMBLE_VEC_PAIR \ + D_INCREMENT \ + D_ACCUMULATE + + +void bli_dgemm_power10_mma_8x8 + ( + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + // (1 is subtracted from k0 because 1 iteration of the k loop is pulled out) + uint64_t k_iter = (k-1) / 4; + uint64_t k_left = (k-1) % 4; + + uint64_t rs_c = rs_c0; + + GEMM_UKR_SETUP_CT( d, 8, 8, true ); + + double* restrict A0 = a; + double* restrict B0 = b; + double* restrict C0 = c; + + double alpha_ = *alpha, + beta_ = *beta; + + dv4sf_t result[4]; + dv4sf_t *rowC; + + /* 8 accumulator registers that will be used to store the result. + + Each accumulator register is mapped to 4 vector registers. + Illustration: + + acc0 = [ vs0 + vs1 + vs3 + vs4 ] + + These registers are used to store the result of an outer product + instruction (general outer product instruction syntax: xv???ger??). */ + __vector_quad acc0, acc1, acc2, acc3, + acc4, acc5, acc6, acc7; + + /* 2 vector pairs are necessary for a double precision outer product + instruction. */ + __vector_pair colA_1, + colA_2; + + /* Prefetch C so that it stays in cache */ + PREFETCH1 (C0, 0); + PREFETCH1 (C0 + rs_c, 0); + PREFETCH1 (C0 + rs_c + rs_c, 0); + PREFETCH1 (C0 + rs_c + rs_c + rs_c, 0); + PREFETCH1 (C0, 128); + PREFETCH1 (C0 + rs_c, 128); + PREFETCH1 (C0 + rs_c + rs_c, 128); + PREFETCH1 (C0 + rs_c + rs_c + rs_c, 128); + + /* Load elements into vector registers */ + vec_t *ca = (vec_t *) A0; + vec_t *rb = (vec_t *) B0; + + /* Each accumulator represents a matrix of size + 4 x ( 16 / (datatype size in bytes) ) (vector register size = 16B) + + Thus in the case of double, the accumulate registers represent a 4x2 + matrix. However, a vector register can hold at most 2 doubles. Thus, if + we performed an outer product using 2 vector register, we can only get a + 2x2 matrix. Therefore, we must create a vector register pair in order + to get the desired 4x2 matrix. + + */ + D_ASSEMBLE_VEC_PAIR + + /* Compute accumulate outer products and override accumulators with result */ + __builtin_mma_xvf64ger (&acc0, colA_1, rb[0]); + __builtin_mma_xvf64ger (&acc1, colA_1, rb[1]); + __builtin_mma_xvf64ger (&acc2, colA_1, rb[2]); + __builtin_mma_xvf64ger (&acc3, colA_1, rb[3]); + __builtin_mma_xvf64ger (&acc4, colA_2, rb[0]); + __builtin_mma_xvf64ger (&acc5, colA_2, rb[1]); + __builtin_mma_xvf64ger (&acc6, colA_2, rb[2]); + __builtin_mma_xvf64ger (&acc7, colA_2, rb[3]); + + /* Move A and B pointers */ + D_INCREMENT + + // k loop (unrolled by 4) + for (int k = 0; k 0; kk--) { + vector double va00 = vec_splats( *(double *)( pa+0 ) ); + vector double va10 = vec_splats( *(double *)( pa+d1 ) ); + vector double va20 = vec_splats( *(double *)( pa+d2 ) ); + vector double va30 = vec_splats( *(double *)( pa+d3 ) ); + vector double va40 = vec_splats( *(double *)( pa+d4 ) ); + vector double va50 = vec_splats( *(double *)( pa+d5 ) ); + vector double va60 = vec_splats( *(double *)( pa+d6 ) ); + vector double va70 = vec_splats( *(double *)( pa+d7 ) ); + pa += 8*sizeof(double); + + vector double vb00_01 = *(vector double *)( pb+0 ); + vector double vb02_03 = *(vector double *)( pb+d2 ); + pb += 4*sizeof(double); + + vc00_01 = vec_madd(va00, vb00_01, vc00_01); + vc02_03 = vec_madd(va00, vb02_03, vc02_03); + vc10_11 = vec_madd(va10, vb00_01, vc10_11); + vc12_13 = vec_madd(va10, vb02_03, vc12_13); + vc20_21 = vec_madd(va20, vb00_01, vc20_21); + vc22_23 = vec_madd(va20, vb02_03, vc22_23); + vc30_31 = vec_madd(va30, vb00_01, vc30_31); + vc32_33 = vec_madd(va30, vb02_03, vc32_33); + vc40_41 = vec_madd(va40, vb00_01, vc40_41); + vc42_43 = vec_madd(va40, vb02_03, vc42_43); + vc50_51 = vec_madd(va50, vb00_01, vc50_51); + vc52_53 = vec_madd(va50, vb02_03, vc52_53); + vc60_61 = vec_madd(va60, vb00_01, vc60_61); + vc62_63 = vec_madd(va60, vb02_03, vc62_63); + vc70_71 = vec_madd(va70, vb00_01, vc70_71); + vc72_73 = vec_madd(va70, vb02_03, vc72_73); + } + + vector double valpha = vec_splats( *alpha ); + vector double vbeta = (vector double) { *beta, *beta }; + + vector double *pc = (vector double *)c; + + vc00_01 = vec_mul(valpha, vc00_01); + vc02_03 = vec_mul(valpha, vc02_03); + pc[0] = vec_madd( pc[0], vbeta, vc00_01); + pc[1] = vec_madd( pc[1], vbeta, vc02_03); + pc += rs_c/2; + + vc10_11 = vec_mul(valpha, vc10_11); + vc12_13 = vec_mul(valpha, vc12_13); + pc[0] = vec_madd( pc[0], vbeta, vc10_11); + pc[1] = vec_madd( pc[1], vbeta, vc12_13); + pc += rs_c/2; + + vc20_21 = vec_mul(valpha, vc20_21); + vc22_23 = vec_mul(valpha, vc22_23); + pc[0] = vec_madd( pc[0], vbeta, vc20_21); + pc[1] = vec_madd( pc[1], vbeta, vc22_23); + pc += rs_c/2; + + vc30_31 = vec_mul(valpha, vc30_31); + vc32_33 = vec_mul(valpha, vc32_33); + pc[0] = vec_madd( pc[0], vbeta, vc30_31); + pc[1] = vec_madd( pc[1], vbeta, vc32_33); + pc += rs_c/2; + + vc40_41 = vec_mul(valpha, vc40_41); + vc42_43 = vec_mul(valpha, vc42_43); + pc[0] = vec_madd( pc[0], vbeta, vc40_41); + pc[1] = vec_madd( pc[1], vbeta, vc42_43); + pc += rs_c/2; + + vc50_51 = vec_mul(valpha, vc50_51); + vc52_53 = vec_mul(valpha, vc52_53); + pc[0] = vec_madd( pc[0], vbeta, vc50_51); + pc[1] = vec_madd( pc[1], vbeta, vc52_53); + pc += rs_c/2; + + vc60_61 = vec_mul(valpha, vc60_61); + vc62_63 = vec_mul(valpha, vc62_63); + pc[0] = vec_madd( pc[0], vbeta, vc60_61); + pc[1] = vec_madd( pc[1], vbeta, vc62_63); + pc += rs_c/2; + + vc70_71 = vec_mul(valpha, vc70_71); + vc72_73 = vec_mul(valpha, vc72_73); + pc[0] = vec_madd( pc[0], vbeta, vc70_71); + pc[1] = vec_madd( pc[1], vbeta, vc72_73); + pc += rs_c/2; + } + else + { + GEMM_UKR_SETUP_CT( d, 8, 4, false ); + // Optimized code for case where C columns are contiguous (column-major C) vector double vzero = vec_splats( 0.0 ); @@ -301,168 +433,8 @@ void bli_dgemm_power7_int_8x4 pc[1] = vec_madd( pc[1], vbeta, vc23_33); pc[2] = vec_madd( pc[2], vbeta, vc43_53); pc[3] = vec_madd( pc[3], vbeta, vc63_73); - } - else -#endif -#if 1 - if ( cs_c == 1 ) { - // Optimized code for case where C rows are contiguous (i.e. C is row-major) - - vector double vzero = vec_splats( 0.0 ); - - vector double vc00_01 = vzero; - vector double vc02_03 = vzero; - vector double vc10_11 = vzero; - vector double vc12_13 = vzero; - vector double vc20_21 = vzero; - vector double vc22_23 = vzero; - vector double vc30_31 = vzero; - vector double vc32_33 = vzero; - vector double vc40_41 = vzero; - vector double vc42_43 = vzero; - vector double vc50_51 = vzero; - vector double vc52_53 = vzero; - vector double vc60_61 = vzero; - vector double vc62_63 = vzero; - vector double vc70_71 = vzero; - vector double vc72_73 = vzero; - - unsigned long long pa = (unsigned long long)a; - unsigned long long pb = (unsigned long long)b; - -#if 0 - unsigned long long d1 = 1*sizeof(double); - unsigned long long d2 = 2*sizeof(double); - unsigned long long d3 = 3*sizeof(double); - unsigned long long d4 = 4*sizeof(double); - unsigned long long d6 = 6*sizeof(double); -#else - // ppc64 linux abi: r14-r31 Nonvolatile registers used for local variables - register unsigned long long d1 __asm ("r21") = 1*sizeof(double); - register unsigned long long d2 __asm ("r22") = 2*sizeof(double); - register unsigned long long d3 __asm ("r23") = 3*sizeof(double); - register unsigned long long d4 __asm ("r24") = 4*sizeof(double); - register unsigned long long d5 __asm ("r25") = 5*sizeof(double); - register unsigned long long d6 __asm ("r26") = 6*sizeof(double); - register unsigned long long d7 __asm ("r27") = 7*sizeof(double); - - __asm__ volatile (";" : "=r" (d1) : "r" (d1) ); - __asm__ volatile (";" : "=r" (d2) : "r" (d2) ); - __asm__ volatile (";" : "=r" (d3) : "r" (d3) ); - __asm__ volatile (";" : "=r" (d4) : "r" (d4) ); - __asm__ volatile (";" : "=r" (d5) : "r" (d5) ); - __asm__ volatile (";" : "=r" (d6) : "r" (d6) ); - __asm__ volatile (";" : "=r" (d7) : "r" (d7) ); -#endif - - int kk; - for (kk=k; kk > 0; kk--) { - vector double va00 = vec_splats( *(double *)( pa+0 ) ); - vector double va10 = vec_splats( *(double *)( pa+d1 ) ); - vector double va20 = vec_splats( *(double *)( pa+d2 ) ); - vector double va30 = vec_splats( *(double *)( pa+d3 ) ); - vector double va40 = vec_splats( *(double *)( pa+d4 ) ); - vector double va50 = vec_splats( *(double *)( pa+d5 ) ); - vector double va60 = vec_splats( *(double *)( pa+d6 ) ); - vector double va70 = vec_splats( *(double *)( pa+d7 ) ); - pa += 8*sizeof(double); - - vector double vb00_01 = *(vector double *)( pb+0 ); - vector double vb02_03 = *(vector double *)( pb+d2 ); - pb += 4*sizeof(double); - - vc00_01 = vec_madd(va00, vb00_01, vc00_01); - vc02_03 = vec_madd(va00, vb02_03, vc02_03); - vc10_11 = vec_madd(va10, vb00_01, vc10_11); - vc12_13 = vec_madd(va10, vb02_03, vc12_13); - vc20_21 = vec_madd(va20, vb00_01, vc20_21); - vc22_23 = vec_madd(va20, vb02_03, vc22_23); - vc30_31 = vec_madd(va30, vb00_01, vc30_31); - vc32_33 = vec_madd(va30, vb02_03, vc32_33); - vc40_41 = vec_madd(va40, vb00_01, vc40_41); - vc42_43 = vec_madd(va40, vb02_03, vc42_43); - vc50_51 = vec_madd(va50, vb00_01, vc50_51); - vc52_53 = vec_madd(va50, vb02_03, vc52_53); - vc60_61 = vec_madd(va60, vb00_01, vc60_61); - vc62_63 = vec_madd(va60, vb02_03, vc62_63); - vc70_71 = vec_madd(va70, vb00_01, vc70_71); - vc72_73 = vec_madd(va70, vb02_03, vc72_73); - } - - vector double valpha = vec_splats( *alpha ); - vector double vbeta = (vector double) { *beta, *beta }; - - vector double *pc = (vector double *)c; - - vc00_01 = vec_mul(valpha, vc00_01); - vc02_03 = vec_mul(valpha, vc02_03); - pc[0] = vec_madd( pc[0], vbeta, vc00_01); - pc[1] = vec_madd( pc[1], vbeta, vc02_03); - pc += rs_c/2; - - vc10_11 = vec_mul(valpha, vc10_11); - vc12_13 = vec_mul(valpha, vc12_13); - pc[0] = vec_madd( pc[0], vbeta, vc10_11); - pc[1] = vec_madd( pc[1], vbeta, vc12_13); - pc += rs_c/2; - - vc20_21 = vec_mul(valpha, vc20_21); - vc22_23 = vec_mul(valpha, vc22_23); - pc[0] = vec_madd( pc[0], vbeta, vc20_21); - pc[1] = vec_madd( pc[1], vbeta, vc22_23); - pc += rs_c/2; - - vc30_31 = vec_mul(valpha, vc30_31); - vc32_33 = vec_mul(valpha, vc32_33); - pc[0] = vec_madd( pc[0], vbeta, vc30_31); - pc[1] = vec_madd( pc[1], vbeta, vc32_33); - pc += rs_c/2; - - vc40_41 = vec_mul(valpha, vc40_41); - vc42_43 = vec_mul(valpha, vc42_43); - pc[0] = vec_madd( pc[0], vbeta, vc40_41); - pc[1] = vec_madd( pc[1], vbeta, vc42_43); - pc += rs_c/2; - - vc50_51 = vec_mul(valpha, vc50_51); - vc52_53 = vec_mul(valpha, vc52_53); - pc[0] = vec_madd( pc[0], vbeta, vc50_51); - pc[1] = vec_madd( pc[1], vbeta, vc52_53); - pc += rs_c/2; - - vc60_61 = vec_mul(valpha, vc60_61); - vc62_63 = vec_mul(valpha, vc62_63); - pc[0] = vec_madd( pc[0], vbeta, vc60_61); - pc[1] = vec_madd( pc[1], vbeta, vc62_63); - pc += rs_c/2; - - vc70_71 = vec_mul(valpha, vc70_71); - vc72_73 = vec_mul(valpha, vc72_73); - pc[0] = vec_madd( pc[0], vbeta, vc70_71); - pc[1] = vec_madd( pc[1], vbeta, vc72_73); - pc += rs_c/2; - } - else -#endif - { /* General case. Just do it right. */ -#if 1 || defined(UTEST) - const long MR = BLIS_DEFAULT_MR_D, NR = BLIS_DEFAULT_NR_D; - const long LDA = MR, LDB = NR; - int i, j, kk; - double c00; - - for (i=0; i < MR; i++) { - for (j=0; j < NR; j++) { - c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta; - for (kk=0; kk < k; kk++) - c00 += *alpha * (a[COLMAJ_INDEX(i,kk,LDA)] * b[ROWMAJ_INDEX(kk,j,LDB)]); - c[BLIS_INDEX(i,j,rs_c,cs_c)] = c00; - } - } -#else - //BLIS_DGEMM_UKERNEL_REF(k, alpha, a, b, beta, c, rs_c, cs_c, data); -#endif + GEMM_UKR_FLUSH_CT( d ); } } @@ -477,30 +449,26 @@ void bli_dgemm_power7_int_8x4 */ void bli_cgemm_power7_int_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, scomplex* restrict beta, - scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + scomplex* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - uint64_t k = k0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - #if 1 || defined(UTEST) const long MR = BLIS_DEFAULT_MR_C, NR = BLIS_DEFAULT_NR_C; const long LDA = MR, LDB = NR; int i, j, kk; scomplex c00; - for (i=0; i < MR; i++) { - for (j=0; j < NR; j++) { + for (i=0; i < m; i++) { + for (j=0; j < n; j++) { scomplex tmpc, tmpa, tmpb, tmp; //c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta; tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)]; @@ -534,30 +502,26 @@ void bli_cgemm_power7_int_8x4 */ void bli_zgemm_power7_int_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, scomplex* restrict beta, - scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + scomplex* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - uint64_t k = k0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - #if 1 || defined(UTEST) const long MR = BLIS_DEFAULT_MR_Z, NR = BLIS_DEFAULT_NR_Z; const long LDA = MR, LDB = NR; int i, j, kk; dcomplex c00; - for (i=0; i < MR; i++) { - for (j=0; j < NR; j++) { + for (i=0; i < m; i++) { + for (j=0; j < n; j++) { dcomplex tmpc, tmpa, tmpb, tmp; //c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta; tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)]; diff --git a/kernels/power7/3/test/bli_gemm_power7_int_8x4.h b/kernels/power7/3/test/bli_gemm_power7_int_8x4.h index ef1930907e..50984a67df 100644 --- a/kernels/power7/3/test/bli_gemm_power7_int_8x4.h +++ b/kernels/power7/3/test/bli_gemm_power7_int_8x4.h @@ -43,6 +43,8 @@ void bli_sgemm_opt_8x4 ( + dim_t m, + dim_t n, dim_t k, float* restrict alpha, float* restrict a, @@ -55,6 +57,8 @@ void bli_sgemm_opt_8x4 void bli_dgemm_opt_8x4 ( + dim_t m, + dim_t n, dim_t k, double* restrict alpha, double* restrict a, @@ -67,6 +71,8 @@ void bli_dgemm_opt_8x4 void bli_cgemm_opt_8x4 ( + dim_t m, + dim_t n, dim_t k, scomplex* restrict alpha, scomplex* restrict a, @@ -79,6 +85,8 @@ void bli_cgemm_opt_8x4 void bli_zgemm_opt_8x4 ( + dim_t m, + dim_t n, dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, diff --git a/kernels/power9/3/bli_gemm_power9_asm_d12x6.c b/kernels/power9/3/bli_gemm_power9_asm_d12x6.c new file mode 100644 index 0000000000..3e5f0d4164 --- /dev/null +++ b/kernels/power9/3/bli_gemm_power9_asm_d12x6.c @@ -0,0 +1,179 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "bli_pwr9_asm_macros_12x6.h" + +void bli_dgemm_power9_asm_12x6 + ( + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k / 16; + uint64_t k_left = k % 16; + + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + GEMM_UKR_SETUP_CT( d, 12, 6, false ); + + __asm__ volatile + ( + " \n\t" + "ld %%r7, %2 \n\t" // load ptr of A + "ld %%r8, %3 \n\t" // load ptr of B + "ld %%r16, %6 \n\t" // load ptr of C + " \n\t" + "ld %%r28, %4 \n\t" // load ptr for alpha + "ld %%r29, %5 \n\t" // load ptr for beta + " \n\t" + "ld %%r11, %0 \n\t" // load k_iter + "ld %%r12, %1 \n\t" // load k_left + " \n\t" + "ld %%r10, %8 \n\t" // load cs_c + "slwi %%r10, %%r10, 3 \n\t" // mul by size of elem + " \n\t" + "ld %%r9, %7 \n\t" // load rs_c + "slwi %%r9, %%r9, 3 \n\t" // mul by size of elem + " \n\t" + "ld %%r26, 0(%%r29) \n\t" // load val of beta + " \n\t" + "lxvdsx %%vs62, 0, %%r28 \n\t" // splat alpha + "lxvdsx %%vs63, 0, %%r29 \n\t" // splat beta + " \n\t" + "add %%r17, %%r16, %%r10 \n\t" // addr of col 1 of C + "add %%r18, %%r17, %%r10 \n\t" // col 2 of C + "add %%r19, %%r18, %%r10 \n\t" // col 3 of C + "add %%r20, %%r19, %%r10 \n\t" // col 4 of C + "add %%r21, %%r20, %%r10 \n\t" // col 5 of C + " \n\t" + DZERO_OUT_VREG + " \n\t" + DPRELOAD + " \n\t" + "addi %%r8, %%r8, 96 \n\t" // move to next col/row of A/B + "addi %%r7, %%r7, 96 \n\t" + " \n\t" + DPREFETCH + " \n\t" + "cmpwi %%r11, 0 \n\t" // if k_iter == 0, + "beq DCONSIDERKLEFT \n\t" // then jmp to k_left + "mtctr %%r11 \n\t" // else, do k_iter loop + " \n\t" + "DLOOPKITER: \n\t" // k_iter loop + " \n\t" + A_B_PRODUCT_16 // compute A*B + " \n\t" + "bdnz DLOOPKITER \n\t" + " \n\t" + "DCONSIDERKLEFT: \n\t" + " \n\t" + "cmpwi %%r12, 0 \n\t" // if k_left == 0, + "beq DPOSTACCUM \n\t" // then jmp to post accum + "mtctr %%r12 \n\t" // else, do k_left loop + " \n\t" + "DLOOPKLEFT: \n\t" // k_left loop + " \n\t" + A_B_PRODUCT_1 + " \n\t" + "bdnz DLOOPKLEFT \n\t" + " \n\t" + "DPOSTACCUM: \n\t" + " \n\t" + DSCALE_ALPHA + " \n\t" + "cmpdi %%r26, 0 \n\t" // if beta == 0, + "beq DBETAZERO \n\t" // then jmp to BZ + " \n\t" + DCOL_SCALE_BETA + " \n\t" + "DBETAZERO: \n\t" // BZ case + " \n\t" + DCOL_STORE + " \n\t" + "DDONE: \n\t" + " \n\t" + : // output operands (none) + : // input operands + "m" (k_iter), // 0 + "m" (k_left), // 1 + "m" (a), // 2 + "m" (b), // 3 + "m" (alpha), // 4 + "m" (beta), // 5 + "m" (c), // 6 + "m" (rs_c), // 7 + "m" (cs_c)/*, // 8 + "m" (b_next), // 9 + "m" (a_next)*/ // 10 + : // register clobber list + /* unclobberable regs: r2, r3, r4, r5, r6, r13, r14, r15, r30, r31 */ + "r0", "r7", "r8", "r9", + "r10", "r11", "r12", "r16", "r17", "r18", "r19", + "r20", "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29" + + #if XLC + ,"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9" + , "f10", "f11", "f12", "f13", "f14", "f15", "f16", "f17", "f18", "f19" + , "f20" ,"f21", "f22", "f23", "f24", "f25", "f26", "f27", "f28", "f29" + , "f30" ,"f31" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9" + , "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19" + , "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29" + , "v30", "v31" + #else + , "vs0", "vs1", "vs2", "vs3", "vs4", "vs5", "vs6", "vs7", "vs8", "vs9" + , "vs10", "vs11", "vs12", "vs13", "vs14", "vs15", "vs16", "vs17", "vs18", "vs19" + , "vs20", "vs21", "vs22", "vs23", "vs24", "vs25", "vs26", "vs27", "vs28", "vs29" + , "vs30", "vs31", "vs32", "vs33", "vs34", "vs35", "vs36", "vs37", "vs38", "vs39" + , "vs40", "vs41", "vs42", "vs43", "vs44", "vs45", "vs46", "vs47", "vs48", "vs49" + , "vs50", "vs51", "vs52", "vs53" + #endif + + ); + + GEMM_UKR_FLUSH_CT( d ); +} diff --git a/kernels/power9/3/bli_pwr9_asm_macros_12x6.h b/kernels/power9/3/bli_pwr9_asm_macros_12x6.h new file mode 100644 index 0000000000..8c4d256343 --- /dev/null +++ b/kernels/power9/3/bli_pwr9_asm_macros_12x6.h @@ -0,0 +1,1607 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// MACROS for power9_asm_d12x6 + + +// zero out registers used to store result +#define DZERO_OUT_VREG \ +"xxlxor %%vs0, %%vs0, %%vs0 \n\t" \ +"xxlxor %%vs1, %%vs1, %%vs1 \n\t" \ +"xxlxor %%vs2, %%vs2, %%vs2 \n\t" \ +"xxlxor %%vs3, %%vs3, %%vs3 \n\t" \ +"xxlxor %%vs4, %%vs4, %%vs4 \n\t" \ +"xxlxor %%vs5, %%vs5, %%vs5 \n\t" \ +"xxlxor %%vs6, %%vs6, %%vs6 \n\t" \ +"xxlxor %%vs7, %%vs7, %%vs7 \n\t" \ +"xxlxor %%vs8, %%vs8, %%vs8 \n\t" \ +"xxlxor %%vs9, %%vs9, %%vs9 \n\t" \ +"xxlxor %%vs10, %%vs10, %%vs10 \n\t" \ +"xxlxor %%vs11, %%vs11, %%vs11 \n\t" \ +"xxlxor %%vs12, %%vs12, %%vs12 \n\t" \ +"xxlxor %%vs13, %%vs13, %%vs13 \n\t" \ +"xxlxor %%vs14, %%vs14, %%vs14 \n\t" \ +"xxlxor %%vs15, %%vs15, %%vs15 \n\t" \ +"xxlxor %%vs16, %%vs16, %%vs16 \n\t" \ +"xxlxor %%vs17, %%vs17, %%vs17 \n\t" \ +"xxlxor %%vs18, %%vs18, %%vs18 \n\t" \ +"xxlxor %%vs19, %%vs19, %%vs19 \n\t" \ +"xxlxor %%vs20, %%vs20, %%vs20 \n\t" \ +"xxlxor %%vs21, %%vs21, %%vs21 \n\t" \ +"xxlxor %%vs22, %%vs22, %%vs22 \n\t" \ +"xxlxor %%vs23, %%vs23, %%vs23 \n\t" \ +"xxlxor %%vs24, %%vs24, %%vs24 \n\t" \ +"xxlxor %%vs25, %%vs25, %%vs25 \n\t" \ +"xxlxor %%vs26, %%vs26, %%vs26 \n\t" \ +"xxlxor %%vs27, %%vs27, %%vs27 \n\t" \ +"xxlxor %%vs28, %%vs28, %%vs28 \n\t" \ +"xxlxor %%vs29, %%vs29, %%vs29 \n\t" \ +"xxlxor %%vs30, %%vs30, %%vs30 \n\t" \ +"xxlxor %%vs31, %%vs31, %%vs31 \n\t" \ +"xxlxor %%vs32, %%vs32, %%vs32 \n\t" \ +"xxlxor %%vs33, %%vs33, %%vs33 \n\t" \ +"xxlxor %%vs34, %%vs34, %%vs34 \n\t" \ +"xxlxor %%vs35, %%vs35, %%vs35 \n\t" + +#define DPREFETCH \ +"dcbt 0, %%r16 \n\t" \ +"dcbt 0, %%r17 \n\t" \ +"dcbt 0, %%r18 \n\t" \ +"dcbt 0, %%r19 \n\t" \ +"dcbt 0, %%r20 \n\t" \ +"dcbt 0, %%r21 \n\t" + +// preload col/row of A/B +#define DPRELOAD \ +"lxv %%vs36, 0(%%r7) \n\t" \ +"lxv %%vs37, 16(%%r7) \n\t" \ +"lxv %%vs38, 32(%%r7) \n\t" \ +"lxv %%vs39, 48(%%r7) \n\t" \ +"lxv %%vs40, 64(%%r7) \n\t" \ +"lxv %%vs41, 80(%%r7) \n\t" \ +" \n\t" \ +"lxv %%vs48, 0(%%r8) \n\t" \ +"lxv %%vs49, 16(%%r8) \n\t" \ +"lxv %%vs50, 32(%%r8) \n\t" \ +"lxv %%vs51, 48(%%r8) \n\t" \ +"lxv %%vs52, 64(%%r8) \n\t" \ +"lxv %%vs53, 80(%%r8) \n\t" + +// compute AB product +// unrolled by 16 +#define A_B_PRODUCT_16 \ +" \n\t" \ +"lxv %%vs42, 0(%%r7) \n\t" \ +"lxv %%vs43, 16(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs36, %%vs48 \n\t" \ +"xvmaddadp %%vs1, %%vs37, %%vs48 \n\t" \ +"xvmaddadp %%vs2, %%vs38, %%vs48 \n\t" \ +"xvmaddadp %%vs3, %%vs39, %%vs48 \n\t" \ +"xvmaddadp %%vs4, %%vs40, %%vs48 \n\t" \ +"xvmaddadp %%vs5, %%vs41, %%vs48 \n\t" \ +" \n\t" \ +"lxv %%vs54, 0(%%r8) \n\t" \ +"lxv %%vs55, 16(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs36, %%vs49 \n\t" \ +"xvmaddadp %%vs7, %%vs37, %%vs49 \n\t" \ +"xvmaddadp %%vs8, %%vs38, %%vs49 \n\t" \ +"xvmaddadp %%vs9, %%vs39, %%vs49 \n\t" \ +"xvmaddadp %%vs10, %%vs40, %%vs49 \n\t" \ +"xvmaddadp %%vs11, %%vs41, %%vs49 \n\t" \ +" \n\t" \ +"lxv %%vs44, 32(%%r7) \n\t" \ +"lxv %%vs45, 48(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs36, %%vs50 \n\t" \ +"xvmaddadp %%vs13, %%vs37, %%vs50 \n\t" \ +"xvmaddadp %%vs14, %%vs38, %%vs50 \n\t" \ +"xvmaddadp %%vs15, %%vs39, %%vs50 \n\t" \ +"xvmaddadp %%vs16, %%vs40, %%vs50 \n\t" \ +"xvmaddadp %%vs17, %%vs41, %%vs50 \n\t" \ +" \n\t" \ +"lxv %%vs56, 32(%%r8) \n\t" \ +"lxv %%vs57, 48(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs36, %%vs51 \n\t" \ +"xvmaddadp %%vs19, %%vs37, %%vs51 \n\t" \ +"xvmaddadp %%vs20, %%vs38, %%vs51 \n\t" \ +"xvmaddadp %%vs21, %%vs39, %%vs51 \n\t" \ +"xvmaddadp %%vs22, %%vs40, %%vs51 \n\t" \ +"xvmaddadp %%vs23, %%vs41, %%vs51 \n\t" \ +" \n\t" \ +"lxv %%vs46, 64(%%r7) \n\t" \ +"lxv %%vs47, 80(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs36, %%vs52 \n\t" \ +"xvmaddadp %%vs25, %%vs37, %%vs52 \n\t" \ +"xvmaddadp %%vs26, %%vs38, %%vs52 \n\t" \ +"xvmaddadp %%vs27, %%vs39, %%vs52 \n\t" \ +"xvmaddadp %%vs28, %%vs40, %%vs52 \n\t" \ +"xvmaddadp %%vs29, %%vs41, %%vs52 \n\t" \ +" \n\t" \ +"lxv %%vs58, 64(%%r8) \n\t" \ +"lxv %%vs59, 80(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs36, %%vs53 \n\t" \ +"xvmaddadp %%vs31, %%vs37, %%vs53 \n\t" \ +"xvmaddadp %%vs32, %%vs38, %%vs53 \n\t" \ +"xvmaddadp %%vs33, %%vs39, %%vs53 \n\t" \ +"xvmaddadp %%vs34, %%vs40, %%vs53 \n\t" \ +"xvmaddadp %%vs35, %%vs41, %%vs53 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs36, 96(%%r7) \n\t" \ +"lxv %%vs37, 112(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs42, %%vs54 \n\t" \ +"xvmaddadp %%vs1, %%vs43, %%vs54 \n\t" \ +"xvmaddadp %%vs2, %%vs44, %%vs54 \n\t" \ +"xvmaddadp %%vs3, %%vs45, %%vs54 \n\t" \ +"xvmaddadp %%vs4, %%vs46, %%vs54 \n\t" \ +"xvmaddadp %%vs5, %%vs47, %%vs54 \n\t" \ +" \n\t" \ +"lxv %%vs48, 96(%%r8) \n\t" \ +"lxv %%vs49, 112(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs42, %%vs55 \n\t" \ +"xvmaddadp %%vs7, %%vs43, %%vs55 \n\t" \ +"xvmaddadp %%vs8, %%vs44, %%vs55 \n\t" \ +"xvmaddadp %%vs9, %%vs45, %%vs55 \n\t" \ +"xvmaddadp %%vs10, %%vs46, %%vs55 \n\t" \ +"xvmaddadp %%vs11, %%vs47, %%vs55 \n\t" \ +" \n\t" \ +"lxv %%vs38, 128(%%r7) \n\t" \ +"lxv %%vs39, 144(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs42, %%vs56 \n\t" \ +"xvmaddadp %%vs13, %%vs43, %%vs56 \n\t" \ +"xvmaddadp %%vs14, %%vs44, %%vs56 \n\t" \ +"xvmaddadp %%vs15, %%vs45, %%vs56 \n\t" \ +"xvmaddadp %%vs16, %%vs46, %%vs56 \n\t" \ +"xvmaddadp %%vs17, %%vs47, %%vs56 \n\t" \ +" \n\t" \ +"lxv %%vs50, 128(%%r8) \n\t" \ +"lxv %%vs51, 144(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs42, %%vs57 \n\t" \ +"xvmaddadp %%vs19, %%vs43, %%vs57 \n\t" \ +"xvmaddadp %%vs20, %%vs44, %%vs57 \n\t" \ +"xvmaddadp %%vs21, %%vs45, %%vs57 \n\t" \ +"xvmaddadp %%vs22, %%vs46, %%vs57 \n\t" \ +"xvmaddadp %%vs23, %%vs47, %%vs57 \n\t" \ +" \n\t" \ +"lxv %%vs40, 160(%%r7) \n\t" \ +"lxv %%vs41, 176(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs42, %%vs58 \n\t" \ +"xvmaddadp %%vs25, %%vs43, %%vs58 \n\t" \ +"xvmaddadp %%vs26, %%vs44, %%vs58 \n\t" \ +"xvmaddadp %%vs27, %%vs45, %%vs58 \n\t" \ +"xvmaddadp %%vs28, %%vs46, %%vs58 \n\t" \ +"xvmaddadp %%vs29, %%vs47, %%vs58 \n\t" \ +" \n\t" \ +"lxv %%vs52, 160(%%r8) \n\t" \ +"lxv %%vs53, 176(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs42, %%vs59 \n\t" \ +"xvmaddadp %%vs31, %%vs43, %%vs59 \n\t" \ +"xvmaddadp %%vs32, %%vs44, %%vs59 \n\t" \ +"xvmaddadp %%vs33, %%vs45, %%vs59 \n\t" \ +"xvmaddadp %%vs34, %%vs46, %%vs59 \n\t" \ +"xvmaddadp %%vs35, %%vs47, %%vs59 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs42, 192(%%r7) \n\t" \ +"lxv %%vs43, 208(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs36, %%vs48 \n\t" \ +"xvmaddadp %%vs1, %%vs37, %%vs48 \n\t" \ +"xvmaddadp %%vs2, %%vs38, %%vs48 \n\t" \ +"xvmaddadp %%vs3, %%vs39, %%vs48 \n\t" \ +"xvmaddadp %%vs4, %%vs40, %%vs48 \n\t" \ +"xvmaddadp %%vs5, %%vs41, %%vs48 \n\t" \ +" \n\t" \ +"lxv %%vs54, 192(%%r8) \n\t" \ +"lxv %%vs55, 208(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs36, %%vs49 \n\t" \ +"xvmaddadp %%vs7, %%vs37, %%vs49 \n\t" \ +"xvmaddadp %%vs8, %%vs38, %%vs49 \n\t" \ +"xvmaddadp %%vs9, %%vs39, %%vs49 \n\t" \ +"xvmaddadp %%vs10, %%vs40, %%vs49 \n\t" \ +"xvmaddadp %%vs11, %%vs41, %%vs49 \n\t" \ +" \n\t" \ +"lxv %%vs44, 224(%%r7) \n\t" \ +"lxv %%vs45, 240(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs36, %%vs50 \n\t" \ +"xvmaddadp %%vs13, %%vs37, %%vs50 \n\t" \ +"xvmaddadp %%vs14, %%vs38, %%vs50 \n\t" \ +"xvmaddadp %%vs15, %%vs39, %%vs50 \n\t" \ +"xvmaddadp %%vs16, %%vs40, %%vs50 \n\t" \ +"xvmaddadp %%vs17, %%vs41, %%vs50 \n\t" \ +" \n\t" \ +"lxv %%vs56, 224(%%r8) \n\t" \ +"lxv %%vs57, 240(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs36, %%vs51 \n\t" \ +"xvmaddadp %%vs19, %%vs37, %%vs51 \n\t" \ +"xvmaddadp %%vs20, %%vs38, %%vs51 \n\t" \ +"xvmaddadp %%vs21, %%vs39, %%vs51 \n\t" \ +"xvmaddadp %%vs22, %%vs40, %%vs51 \n\t" \ +"xvmaddadp %%vs23, %%vs41, %%vs51 \n\t" \ +" \n\t" \ +"lxv %%vs46, 256(%%r7) \n\t" \ +"lxv %%vs47, 272(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs36, %%vs52 \n\t" \ +"xvmaddadp %%vs25, %%vs37, %%vs52 \n\t" \ +"xvmaddadp %%vs26, %%vs38, %%vs52 \n\t" \ +"xvmaddadp %%vs27, %%vs39, %%vs52 \n\t" \ +"xvmaddadp %%vs28, %%vs40, %%vs52 \n\t" \ +"xvmaddadp %%vs29, %%vs41, %%vs52 \n\t" \ +" \n\t" \ +"lxv %%vs58, 256(%%r8) \n\t" \ +"lxv %%vs59, 272(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs36, %%vs53 \n\t" \ +"xvmaddadp %%vs31, %%vs37, %%vs53 \n\t" \ +"xvmaddadp %%vs32, %%vs38, %%vs53 \n\t" \ +"xvmaddadp %%vs33, %%vs39, %%vs53 \n\t" \ +"xvmaddadp %%vs34, %%vs40, %%vs53 \n\t" \ +"xvmaddadp %%vs35, %%vs41, %%vs53 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs36, 288(%%r7) \n\t" \ +"lxv %%vs37, 304(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs42, %%vs54 \n\t" \ +"xvmaddadp %%vs1, %%vs43, %%vs54 \n\t" \ +"xvmaddadp %%vs2, %%vs44, %%vs54 \n\t" \ +"xvmaddadp %%vs3, %%vs45, %%vs54 \n\t" \ +"xvmaddadp %%vs4, %%vs46, %%vs54 \n\t" \ +"xvmaddadp %%vs5, %%vs47, %%vs54 \n\t" \ +" \n\t" \ +"lxv %%vs48, 288(%%r8) \n\t" \ +"lxv %%vs49, 304(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs42, %%vs55 \n\t" \ +"xvmaddadp %%vs7, %%vs43, %%vs55 \n\t" \ +"xvmaddadp %%vs8, %%vs44, %%vs55 \n\t" \ +"xvmaddadp %%vs9, %%vs45, %%vs55 \n\t" \ +"xvmaddadp %%vs10, %%vs46, %%vs55 \n\t" \ +"xvmaddadp %%vs11, %%vs47, %%vs55 \n\t" \ +" \n\t" \ +"lxv %%vs38, 320(%%r7) \n\t" \ +"lxv %%vs39, 336(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs42, %%vs56 \n\t" \ +"xvmaddadp %%vs13, %%vs43, %%vs56 \n\t" \ +"xvmaddadp %%vs14, %%vs44, %%vs56 \n\t" \ +"xvmaddadp %%vs15, %%vs45, %%vs56 \n\t" \ +"xvmaddadp %%vs16, %%vs46, %%vs56 \n\t" \ +"xvmaddadp %%vs17, %%vs47, %%vs56 \n\t" \ +" \n\t" \ +"lxv %%vs50, 320(%%r8) \n\t" \ +"lxv %%vs51, 336(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs42, %%vs57 \n\t" \ +"xvmaddadp %%vs19, %%vs43, %%vs57 \n\t" \ +"xvmaddadp %%vs20, %%vs44, %%vs57 \n\t" \ +"xvmaddadp %%vs21, %%vs45, %%vs57 \n\t" \ +"xvmaddadp %%vs22, %%vs46, %%vs57 \n\t" \ +"xvmaddadp %%vs23, %%vs47, %%vs57 \n\t" \ +" \n\t" \ +"lxv %%vs40, 352(%%r7) \n\t" \ +"lxv %%vs41, 368(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs42, %%vs58 \n\t" \ +"xvmaddadp %%vs25, %%vs43, %%vs58 \n\t" \ +"xvmaddadp %%vs26, %%vs44, %%vs58 \n\t" \ +"xvmaddadp %%vs27, %%vs45, %%vs58 \n\t" \ +"xvmaddadp %%vs28, %%vs46, %%vs58 \n\t" \ +"xvmaddadp %%vs29, %%vs47, %%vs58 \n\t" \ +" \n\t" \ +"lxv %%vs52, 352(%%r8) \n\t" \ +"lxv %%vs53, 368(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs42, %%vs59 \n\t" \ +"xvmaddadp %%vs31, %%vs43, %%vs59 \n\t" \ +"xvmaddadp %%vs32, %%vs44, %%vs59 \n\t" \ +"xvmaddadp %%vs33, %%vs45, %%vs59 \n\t" \ +"xvmaddadp %%vs34, %%vs46, %%vs59 \n\t" \ +"xvmaddadp %%vs35, %%vs47, %%vs59 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs42, 384(%%r7) \n\t" \ +"lxv %%vs43, 400(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs36, %%vs48 \n\t" \ +"xvmaddadp %%vs1, %%vs37, %%vs48 \n\t" \ +"xvmaddadp %%vs2, %%vs38, %%vs48 \n\t" \ +"xvmaddadp %%vs3, %%vs39, %%vs48 \n\t" \ +"xvmaddadp %%vs4, %%vs40, %%vs48 \n\t" \ +"xvmaddadp %%vs5, %%vs41, %%vs48 \n\t" \ +" \n\t" \ +"lxv %%vs54, 384(%%r8) \n\t" \ +"lxv %%vs55, 400(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs36, %%vs49 \n\t" \ +"xvmaddadp %%vs7, %%vs37, %%vs49 \n\t" \ +"xvmaddadp %%vs8, %%vs38, %%vs49 \n\t" \ +"xvmaddadp %%vs9, %%vs39, %%vs49 \n\t" \ +"xvmaddadp %%vs10, %%vs40, %%vs49 \n\t" \ +"xvmaddadp %%vs11, %%vs41, %%vs49 \n\t" \ +" \n\t" \ +"lxv %%vs44, 416(%%r7) \n\t" \ +"lxv %%vs45, 432(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs36, %%vs50 \n\t" \ +"xvmaddadp %%vs13, %%vs37, %%vs50 \n\t" \ +"xvmaddadp %%vs14, %%vs38, %%vs50 \n\t" \ +"xvmaddadp %%vs15, %%vs39, %%vs50 \n\t" \ +"xvmaddadp %%vs16, %%vs40, %%vs50 \n\t" \ +"xvmaddadp %%vs17, %%vs41, %%vs50 \n\t" \ +" \n\t" \ +"lxv %%vs56, 416(%%r8) \n\t" \ +"lxv %%vs57, 432(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs36, %%vs51 \n\t" \ +"xvmaddadp %%vs19, %%vs37, %%vs51 \n\t" \ +"xvmaddadp %%vs20, %%vs38, %%vs51 \n\t" \ +"xvmaddadp %%vs21, %%vs39, %%vs51 \n\t" \ +"xvmaddadp %%vs22, %%vs40, %%vs51 \n\t" \ +"xvmaddadp %%vs23, %%vs41, %%vs51 \n\t" \ +" \n\t" \ +"lxv %%vs46, 448(%%r7) \n\t" \ +"lxv %%vs47, 464(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs36, %%vs52 \n\t" \ +"xvmaddadp %%vs25, %%vs37, %%vs52 \n\t" \ +"xvmaddadp %%vs26, %%vs38, %%vs52 \n\t" \ +"xvmaddadp %%vs27, %%vs39, %%vs52 \n\t" \ +"xvmaddadp %%vs28, %%vs40, %%vs52 \n\t" \ +"xvmaddadp %%vs29, %%vs41, %%vs52 \n\t" \ +" \n\t" \ +"lxv %%vs58, 448(%%r8) \n\t" \ +"lxv %%vs59, 464(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs36, %%vs53 \n\t" \ +"xvmaddadp %%vs31, %%vs37, %%vs53 \n\t" \ +"xvmaddadp %%vs32, %%vs38, %%vs53 \n\t" \ +"xvmaddadp %%vs33, %%vs39, %%vs53 \n\t" \ +"xvmaddadp %%vs34, %%vs40, %%vs53 \n\t" \ +"xvmaddadp %%vs35, %%vs41, %%vs53 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs36, 480(%%r7) \n\t" \ +"lxv %%vs37, 496(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs42, %%vs54 \n\t" \ +"xvmaddadp %%vs1, %%vs43, %%vs54 \n\t" \ +"xvmaddadp %%vs2, %%vs44, %%vs54 \n\t" \ +"xvmaddadp %%vs3, %%vs45, %%vs54 \n\t" \ +"xvmaddadp %%vs4, %%vs46, %%vs54 \n\t" \ +"xvmaddadp %%vs5, %%vs47, %%vs54 \n\t" \ +" \n\t" \ +"lxv %%vs48, 480(%%r8) \n\t" \ +"lxv %%vs49, 496(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs42, %%vs55 \n\t" \ +"xvmaddadp %%vs7, %%vs43, %%vs55 \n\t" \ +"xvmaddadp %%vs8, %%vs44, %%vs55 \n\t" \ +"xvmaddadp %%vs9, %%vs45, %%vs55 \n\t" \ +"xvmaddadp %%vs10, %%vs46, %%vs55 \n\t" \ +"xvmaddadp %%vs11, %%vs47, %%vs55 \n\t" \ +" \n\t" \ +"lxv %%vs38, 512(%%r7) \n\t" \ +"lxv %%vs39, 528(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs42, %%vs56 \n\t" \ +"xvmaddadp %%vs13, %%vs43, %%vs56 \n\t" \ +"xvmaddadp %%vs14, %%vs44, %%vs56 \n\t" \ +"xvmaddadp %%vs15, %%vs45, %%vs56 \n\t" \ +"xvmaddadp %%vs16, %%vs46, %%vs56 \n\t" \ +"xvmaddadp %%vs17, %%vs47, %%vs56 \n\t" \ +" \n\t" \ +"lxv %%vs50, 512(%%r8) \n\t" \ +"lxv %%vs51, 528(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs42, %%vs57 \n\t" \ +"xvmaddadp %%vs19, %%vs43, %%vs57 \n\t" \ +"xvmaddadp %%vs20, %%vs44, %%vs57 \n\t" \ +"xvmaddadp %%vs21, %%vs45, %%vs57 \n\t" \ +"xvmaddadp %%vs22, %%vs46, %%vs57 \n\t" \ +"xvmaddadp %%vs23, %%vs47, %%vs57 \n\t" \ +" \n\t" \ +"lxv %%vs40, 544(%%r7) \n\t" \ +"lxv %%vs41, 560(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs42, %%vs58 \n\t" \ +"xvmaddadp %%vs25, %%vs43, %%vs58 \n\t" \ +"xvmaddadp %%vs26, %%vs44, %%vs58 \n\t" \ +"xvmaddadp %%vs27, %%vs45, %%vs58 \n\t" \ +"xvmaddadp %%vs28, %%vs46, %%vs58 \n\t" \ +"xvmaddadp %%vs29, %%vs47, %%vs58 \n\t" \ +" \n\t" \ +"lxv %%vs52, 544(%%r8) \n\t" \ +"lxv %%vs53, 560(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs42, %%vs59 \n\t" \ +"xvmaddadp %%vs31, %%vs43, %%vs59 \n\t" \ +"xvmaddadp %%vs32, %%vs44, %%vs59 \n\t" \ +"xvmaddadp %%vs33, %%vs45, %%vs59 \n\t" \ +"xvmaddadp %%vs34, %%vs46, %%vs59 \n\t" \ +"xvmaddadp %%vs35, %%vs47, %%vs59 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs42, 576(%%r7) \n\t" \ +"lxv %%vs43, 592(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs36, %%vs48 \n\t" \ +"xvmaddadp %%vs1, %%vs37, %%vs48 \n\t" \ +"xvmaddadp %%vs2, %%vs38, %%vs48 \n\t" \ +"xvmaddadp %%vs3, %%vs39, %%vs48 \n\t" \ +"xvmaddadp %%vs4, %%vs40, %%vs48 \n\t" \ +"xvmaddadp %%vs5, %%vs41, %%vs48 \n\t" \ +" \n\t" \ +"lxv %%vs54, 576(%%r8) \n\t" \ +"lxv %%vs55, 592(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs36, %%vs49 \n\t" \ +"xvmaddadp %%vs7, %%vs37, %%vs49 \n\t" \ +"xvmaddadp %%vs8, %%vs38, %%vs49 \n\t" \ +"xvmaddadp %%vs9, %%vs39, %%vs49 \n\t" \ +"xvmaddadp %%vs10, %%vs40, %%vs49 \n\t" \ +"xvmaddadp %%vs11, %%vs41, %%vs49 \n\t" \ +" \n\t" \ +"lxv %%vs44, 608(%%r7) \n\t" \ +"lxv %%vs45, 624(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs36, %%vs50 \n\t" \ +"xvmaddadp %%vs13, %%vs37, %%vs50 \n\t" \ +"xvmaddadp %%vs14, %%vs38, %%vs50 \n\t" \ +"xvmaddadp %%vs15, %%vs39, %%vs50 \n\t" \ +"xvmaddadp %%vs16, %%vs40, %%vs50 \n\t" \ +"xvmaddadp %%vs17, %%vs41, %%vs50 \n\t" \ +" \n\t" \ +"lxv %%vs56, 608(%%r8) \n\t" \ +"lxv %%vs57, 624(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs36, %%vs51 \n\t" \ +"xvmaddadp %%vs19, %%vs37, %%vs51 \n\t" \ +"xvmaddadp %%vs20, %%vs38, %%vs51 \n\t" \ +"xvmaddadp %%vs21, %%vs39, %%vs51 \n\t" \ +"xvmaddadp %%vs22, %%vs40, %%vs51 \n\t" \ +"xvmaddadp %%vs23, %%vs41, %%vs51 \n\t" \ +" \n\t" \ +"lxv %%vs46, 640(%%r7) \n\t" \ +"lxv %%vs47, 656(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs36, %%vs52 \n\t" \ +"xvmaddadp %%vs25, %%vs37, %%vs52 \n\t" \ +"xvmaddadp %%vs26, %%vs38, %%vs52 \n\t" \ +"xvmaddadp %%vs27, %%vs39, %%vs52 \n\t" \ +"xvmaddadp %%vs28, %%vs40, %%vs52 \n\t" \ +"xvmaddadp %%vs29, %%vs41, %%vs52 \n\t" \ +" \n\t" \ +"lxv %%vs58, 640(%%r8) \n\t" \ +"lxv %%vs59, 656(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs36, %%vs53 \n\t" \ +"xvmaddadp %%vs31, %%vs37, %%vs53 \n\t" \ +"xvmaddadp %%vs32, %%vs38, %%vs53 \n\t" \ +"xvmaddadp %%vs33, %%vs39, %%vs53 \n\t" \ +"xvmaddadp %%vs34, %%vs40, %%vs53 \n\t" \ +"xvmaddadp %%vs35, %%vs41, %%vs53 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs36, 672(%%r7) \n\t" \ +"lxv %%vs37, 688(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs42, %%vs54 \n\t" \ +"xvmaddadp %%vs1, %%vs43, %%vs54 \n\t" \ +"xvmaddadp %%vs2, %%vs44, %%vs54 \n\t" \ +"xvmaddadp %%vs3, %%vs45, %%vs54 \n\t" \ +"xvmaddadp %%vs4, %%vs46, %%vs54 \n\t" \ +"xvmaddadp %%vs5, %%vs47, %%vs54 \n\t" \ +" \n\t" \ +"lxv %%vs48, 672(%%r8) \n\t" \ +"lxv %%vs49, 688(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs42, %%vs55 \n\t" \ +"xvmaddadp %%vs7, %%vs43, %%vs55 \n\t" \ +"xvmaddadp %%vs8, %%vs44, %%vs55 \n\t" \ +"xvmaddadp %%vs9, %%vs45, %%vs55 \n\t" \ +"xvmaddadp %%vs10, %%vs46, %%vs55 \n\t" \ +"xvmaddadp %%vs11, %%vs47, %%vs55 \n\t" \ +" \n\t" \ +"lxv %%vs38, 704(%%r7) \n\t" \ +"lxv %%vs39, 720(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs42, %%vs56 \n\t" \ +"xvmaddadp %%vs13, %%vs43, %%vs56 \n\t" \ +"xvmaddadp %%vs14, %%vs44, %%vs56 \n\t" \ +"xvmaddadp %%vs15, %%vs45, %%vs56 \n\t" \ +"xvmaddadp %%vs16, %%vs46, %%vs56 \n\t" \ +"xvmaddadp %%vs17, %%vs47, %%vs56 \n\t" \ +" \n\t" \ +"lxv %%vs50, 704(%%r8) \n\t" \ +"lxv %%vs51, 720(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs42, %%vs57 \n\t" \ +"xvmaddadp %%vs19, %%vs43, %%vs57 \n\t" \ +"xvmaddadp %%vs20, %%vs44, %%vs57 \n\t" \ +"xvmaddadp %%vs21, %%vs45, %%vs57 \n\t" \ +"xvmaddadp %%vs22, %%vs46, %%vs57 \n\t" \ +"xvmaddadp %%vs23, %%vs47, %%vs57 \n\t" \ +" \n\t" \ +"lxv %%vs40, 736(%%r7) \n\t" \ +"lxv %%vs41, 752(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs42, %%vs58 \n\t" \ +"xvmaddadp %%vs25, %%vs43, %%vs58 \n\t" \ +"xvmaddadp %%vs26, %%vs44, %%vs58 \n\t" \ +"xvmaddadp %%vs27, %%vs45, %%vs58 \n\t" \ +"xvmaddadp %%vs28, %%vs46, %%vs58 \n\t" \ +"xvmaddadp %%vs29, %%vs47, %%vs58 \n\t" \ +" \n\t" \ +"lxv %%vs52, 736(%%r8) \n\t" \ +"lxv %%vs53, 752(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs42, %%vs59 \n\t" \ +"xvmaddadp %%vs31, %%vs43, %%vs59 \n\t" \ +"xvmaddadp %%vs32, %%vs44, %%vs59 \n\t" \ +"xvmaddadp %%vs33, %%vs45, %%vs59 \n\t" \ +"xvmaddadp %%vs34, %%vs46, %%vs59 \n\t" \ +"xvmaddadp %%vs35, %%vs47, %%vs59 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs42, 768(%%r7) \n\t" \ +"lxv %%vs43, 784(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs36, %%vs48 \n\t" \ +"xvmaddadp %%vs1, %%vs37, %%vs48 \n\t" \ +"xvmaddadp %%vs2, %%vs38, %%vs48 \n\t" \ +"xvmaddadp %%vs3, %%vs39, %%vs48 \n\t" \ +"xvmaddadp %%vs4, %%vs40, %%vs48 \n\t" \ +"xvmaddadp %%vs5, %%vs41, %%vs48 \n\t" \ +" \n\t" \ +"lxv %%vs54, 768(%%r8) \n\t" \ +"lxv %%vs55, 784(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs36, %%vs49 \n\t" \ +"xvmaddadp %%vs7, %%vs37, %%vs49 \n\t" \ +"xvmaddadp %%vs8, %%vs38, %%vs49 \n\t" \ +"xvmaddadp %%vs9, %%vs39, %%vs49 \n\t" \ +"xvmaddadp %%vs10, %%vs40, %%vs49 \n\t" \ +"xvmaddadp %%vs11, %%vs41, %%vs49 \n\t" \ +" \n\t" \ +"lxv %%vs44, 800(%%r7) \n\t" \ +"lxv %%vs45, 816(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs36, %%vs50 \n\t" \ +"xvmaddadp %%vs13, %%vs37, %%vs50 \n\t" \ +"xvmaddadp %%vs14, %%vs38, %%vs50 \n\t" \ +"xvmaddadp %%vs15, %%vs39, %%vs50 \n\t" \ +"xvmaddadp %%vs16, %%vs40, %%vs50 \n\t" \ +"xvmaddadp %%vs17, %%vs41, %%vs50 \n\t" \ +" \n\t" \ +"lxv %%vs56, 800(%%r8) \n\t" \ +"lxv %%vs57, 816(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs36, %%vs51 \n\t" \ +"xvmaddadp %%vs19, %%vs37, %%vs51 \n\t" \ +"xvmaddadp %%vs20, %%vs38, %%vs51 \n\t" \ +"xvmaddadp %%vs21, %%vs39, %%vs51 \n\t" \ +"xvmaddadp %%vs22, %%vs40, %%vs51 \n\t" \ +"xvmaddadp %%vs23, %%vs41, %%vs51 \n\t" \ +" \n\t" \ +"lxv %%vs46, 832(%%r7) \n\t" \ +"lxv %%vs47, 848(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs36, %%vs52 \n\t" \ +"xvmaddadp %%vs25, %%vs37, %%vs52 \n\t" \ +"xvmaddadp %%vs26, %%vs38, %%vs52 \n\t" \ +"xvmaddadp %%vs27, %%vs39, %%vs52 \n\t" \ +"xvmaddadp %%vs28, %%vs40, %%vs52 \n\t" \ +"xvmaddadp %%vs29, %%vs41, %%vs52 \n\t" \ +" \n\t" \ +"lxv %%vs58, 832(%%r8) \n\t" \ +"lxv %%vs59, 848(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs36, %%vs53 \n\t" \ +"xvmaddadp %%vs31, %%vs37, %%vs53 \n\t" \ +"xvmaddadp %%vs32, %%vs38, %%vs53 \n\t" \ +"xvmaddadp %%vs33, %%vs39, %%vs53 \n\t" \ +"xvmaddadp %%vs34, %%vs40, %%vs53 \n\t" \ +"xvmaddadp %%vs35, %%vs41, %%vs53 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs36, 864(%%r7) \n\t" \ +"lxv %%vs37, 880(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs42, %%vs54 \n\t" \ +"xvmaddadp %%vs1, %%vs43, %%vs54 \n\t" \ +"xvmaddadp %%vs2, %%vs44, %%vs54 \n\t" \ +"xvmaddadp %%vs3, %%vs45, %%vs54 \n\t" \ +"xvmaddadp %%vs4, %%vs46, %%vs54 \n\t" \ +"xvmaddadp %%vs5, %%vs47, %%vs54 \n\t" \ +" \n\t" \ +"lxv %%vs48, 864(%%r8) \n\t" \ +"lxv %%vs49, 880(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs42, %%vs55 \n\t" \ +"xvmaddadp %%vs7, %%vs43, %%vs55 \n\t" \ +"xvmaddadp %%vs8, %%vs44, %%vs55 \n\t" \ +"xvmaddadp %%vs9, %%vs45, %%vs55 \n\t" \ +"xvmaddadp %%vs10, %%vs46, %%vs55 \n\t" \ +"xvmaddadp %%vs11, %%vs47, %%vs55 \n\t" \ +" \n\t" \ +"lxv %%vs38, 896(%%r7) \n\t" \ +"lxv %%vs39, 912(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs42, %%vs56 \n\t" \ +"xvmaddadp %%vs13, %%vs43, %%vs56 \n\t" \ +"xvmaddadp %%vs14, %%vs44, %%vs56 \n\t" \ +"xvmaddadp %%vs15, %%vs45, %%vs56 \n\t" \ +"xvmaddadp %%vs16, %%vs46, %%vs56 \n\t" \ +"xvmaddadp %%vs17, %%vs47, %%vs56 \n\t" \ +" \n\t" \ +"lxv %%vs50, 896(%%r8) \n\t" \ +"lxv %%vs51, 912(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs42, %%vs57 \n\t" \ +"xvmaddadp %%vs19, %%vs43, %%vs57 \n\t" \ +"xvmaddadp %%vs20, %%vs44, %%vs57 \n\t" \ +"xvmaddadp %%vs21, %%vs45, %%vs57 \n\t" \ +"xvmaddadp %%vs22, %%vs46, %%vs57 \n\t" \ +"xvmaddadp %%vs23, %%vs47, %%vs57 \n\t" \ +" \n\t" \ +"lxv %%vs40, 928(%%r7) \n\t" \ +"lxv %%vs41, 944(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs42, %%vs58 \n\t" \ +"xvmaddadp %%vs25, %%vs43, %%vs58 \n\t" \ +"xvmaddadp %%vs26, %%vs44, %%vs58 \n\t" \ +"xvmaddadp %%vs27, %%vs45, %%vs58 \n\t" \ +"xvmaddadp %%vs28, %%vs46, %%vs58 \n\t" \ +"xvmaddadp %%vs29, %%vs47, %%vs58 \n\t" \ +" \n\t" \ +"lxv %%vs52, 928(%%r8) \n\t" \ +"lxv %%vs53, 944(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs42, %%vs59 \n\t" \ +"xvmaddadp %%vs31, %%vs43, %%vs59 \n\t" \ +"xvmaddadp %%vs32, %%vs44, %%vs59 \n\t" \ +"xvmaddadp %%vs33, %%vs45, %%vs59 \n\t" \ +"xvmaddadp %%vs34, %%vs46, %%vs59 \n\t" \ +"xvmaddadp %%vs35, %%vs47, %%vs59 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs42, 960(%%r7) \n\t" \ +"lxv %%vs43, 976(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs36, %%vs48 \n\t" \ +"xvmaddadp %%vs1, %%vs37, %%vs48 \n\t" \ +"xvmaddadp %%vs2, %%vs38, %%vs48 \n\t" \ +"xvmaddadp %%vs3, %%vs39, %%vs48 \n\t" \ +"xvmaddadp %%vs4, %%vs40, %%vs48 \n\t" \ +"xvmaddadp %%vs5, %%vs41, %%vs48 \n\t" \ +" \n\t" \ +"lxv %%vs54, 960(%%r8) \n\t" \ +"lxv %%vs55, 976(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs36, %%vs49 \n\t" \ +"xvmaddadp %%vs7, %%vs37, %%vs49 \n\t" \ +"xvmaddadp %%vs8, %%vs38, %%vs49 \n\t" \ +"xvmaddadp %%vs9, %%vs39, %%vs49 \n\t" \ +"xvmaddadp %%vs10, %%vs40, %%vs49 \n\t" \ +"xvmaddadp %%vs11, %%vs41, %%vs49 \n\t" \ +" \n\t" \ +"lxv %%vs44, 992(%%r7) \n\t" \ +"lxv %%vs45, 1008(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs36, %%vs50 \n\t" \ +"xvmaddadp %%vs13, %%vs37, %%vs50 \n\t" \ +"xvmaddadp %%vs14, %%vs38, %%vs50 \n\t" \ +"xvmaddadp %%vs15, %%vs39, %%vs50 \n\t" \ +"xvmaddadp %%vs16, %%vs40, %%vs50 \n\t" \ +"xvmaddadp %%vs17, %%vs41, %%vs50 \n\t" \ +" \n\t" \ +"lxv %%vs56, 992(%%r8) \n\t" \ +"lxv %%vs57, 1008(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs36, %%vs51 \n\t" \ +"xvmaddadp %%vs19, %%vs37, %%vs51 \n\t" \ +"xvmaddadp %%vs20, %%vs38, %%vs51 \n\t" \ +"xvmaddadp %%vs21, %%vs39, %%vs51 \n\t" \ +"xvmaddadp %%vs22, %%vs40, %%vs51 \n\t" \ +"xvmaddadp %%vs23, %%vs41, %%vs51 \n\t" \ +" \n\t" \ +"lxv %%vs46, 1024(%%r7) \n\t" \ +"lxv %%vs47, 1040(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs36, %%vs52 \n\t" \ +"xvmaddadp %%vs25, %%vs37, %%vs52 \n\t" \ +"xvmaddadp %%vs26, %%vs38, %%vs52 \n\t" \ +"xvmaddadp %%vs27, %%vs39, %%vs52 \n\t" \ +"xvmaddadp %%vs28, %%vs40, %%vs52 \n\t" \ +"xvmaddadp %%vs29, %%vs41, %%vs52 \n\t" \ +" \n\t" \ +"lxv %%vs58, 1024(%%r8) \n\t" \ +"lxv %%vs59, 1040(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs36, %%vs53 \n\t" \ +"xvmaddadp %%vs31, %%vs37, %%vs53 \n\t" \ +"xvmaddadp %%vs32, %%vs38, %%vs53 \n\t" \ +"xvmaddadp %%vs33, %%vs39, %%vs53 \n\t" \ +"xvmaddadp %%vs34, %%vs40, %%vs53 \n\t" \ +"xvmaddadp %%vs35, %%vs41, %%vs53 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs36, 1056(%%r7) \n\t" \ +"lxv %%vs37, 1072(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs42, %%vs54 \n\t" \ +"xvmaddadp %%vs1, %%vs43, %%vs54 \n\t" \ +"xvmaddadp %%vs2, %%vs44, %%vs54 \n\t" \ +"xvmaddadp %%vs3, %%vs45, %%vs54 \n\t" \ +"xvmaddadp %%vs4, %%vs46, %%vs54 \n\t" \ +"xvmaddadp %%vs5, %%vs47, %%vs54 \n\t" \ +" \n\t" \ +"lxv %%vs48, 1056(%%r8) \n\t" \ +"lxv %%vs49, 1072(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs42, %%vs55 \n\t" \ +"xvmaddadp %%vs7, %%vs43, %%vs55 \n\t" \ +"xvmaddadp %%vs8, %%vs44, %%vs55 \n\t" \ +"xvmaddadp %%vs9, %%vs45, %%vs55 \n\t" \ +"xvmaddadp %%vs10, %%vs46, %%vs55 \n\t" \ +"xvmaddadp %%vs11, %%vs47, %%vs55 \n\t" \ +" \n\t" \ +"lxv %%vs38, 1088(%%r7) \n\t" \ +"lxv %%vs39, 1104(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs42, %%vs56 \n\t" \ +"xvmaddadp %%vs13, %%vs43, %%vs56 \n\t" \ +"xvmaddadp %%vs14, %%vs44, %%vs56 \n\t" \ +"xvmaddadp %%vs15, %%vs45, %%vs56 \n\t" \ +"xvmaddadp %%vs16, %%vs46, %%vs56 \n\t" \ +"xvmaddadp %%vs17, %%vs47, %%vs56 \n\t" \ +" \n\t" \ +"lxv %%vs50, 1088(%%r8) \n\t" \ +"lxv %%vs51, 1104(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs42, %%vs57 \n\t" \ +"xvmaddadp %%vs19, %%vs43, %%vs57 \n\t" \ +"xvmaddadp %%vs20, %%vs44, %%vs57 \n\t" \ +"xvmaddadp %%vs21, %%vs45, %%vs57 \n\t" \ +"xvmaddadp %%vs22, %%vs46, %%vs57 \n\t" \ +"xvmaddadp %%vs23, %%vs47, %%vs57 \n\t" \ +" \n\t" \ +"lxv %%vs40, 1120(%%r7) \n\t" \ +"lxv %%vs41, 1136(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs42, %%vs58 \n\t" \ +"xvmaddadp %%vs25, %%vs43, %%vs58 \n\t" \ +"xvmaddadp %%vs26, %%vs44, %%vs58 \n\t" \ +"xvmaddadp %%vs27, %%vs45, %%vs58 \n\t" \ +"xvmaddadp %%vs28, %%vs46, %%vs58 \n\t" \ +"xvmaddadp %%vs29, %%vs47, %%vs58 \n\t" \ +" \n\t" \ +"lxv %%vs52, 1120(%%r8) \n\t" \ +"lxv %%vs53, 1136(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs42, %%vs59 \n\t" \ +"xvmaddadp %%vs31, %%vs43, %%vs59 \n\t" \ +"xvmaddadp %%vs32, %%vs44, %%vs59 \n\t" \ +"xvmaddadp %%vs33, %%vs45, %%vs59 \n\t" \ +"xvmaddadp %%vs34, %%vs46, %%vs59 \n\t" \ +"xvmaddadp %%vs35, %%vs47, %%vs59 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs42, 1152(%%r7) \n\t" \ +"lxv %%vs43, 1168(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs36, %%vs48 \n\t" \ +"xvmaddadp %%vs1, %%vs37, %%vs48 \n\t" \ +"xvmaddadp %%vs2, %%vs38, %%vs48 \n\t" \ +"xvmaddadp %%vs3, %%vs39, %%vs48 \n\t" \ +"xvmaddadp %%vs4, %%vs40, %%vs48 \n\t" \ +"xvmaddadp %%vs5, %%vs41, %%vs48 \n\t" \ +" \n\t" \ +"lxv %%vs54, 1152(%%r8) \n\t" \ +"lxv %%vs55, 1168(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs36, %%vs49 \n\t" \ +"xvmaddadp %%vs7, %%vs37, %%vs49 \n\t" \ +"xvmaddadp %%vs8, %%vs38, %%vs49 \n\t" \ +"xvmaddadp %%vs9, %%vs39, %%vs49 \n\t" \ +"xvmaddadp %%vs10, %%vs40, %%vs49 \n\t" \ +"xvmaddadp %%vs11, %%vs41, %%vs49 \n\t" \ +" \n\t" \ +"lxv %%vs44, 1184(%%r7) \n\t" \ +"lxv %%vs45, 1200(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs36, %%vs50 \n\t" \ +"xvmaddadp %%vs13, %%vs37, %%vs50 \n\t" \ +"xvmaddadp %%vs14, %%vs38, %%vs50 \n\t" \ +"xvmaddadp %%vs15, %%vs39, %%vs50 \n\t" \ +"xvmaddadp %%vs16, %%vs40, %%vs50 \n\t" \ +"xvmaddadp %%vs17, %%vs41, %%vs50 \n\t" \ +" \n\t" \ +"lxv %%vs56, 1184(%%r8) \n\t" \ +"lxv %%vs57, 1200(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs36, %%vs51 \n\t" \ +"xvmaddadp %%vs19, %%vs37, %%vs51 \n\t" \ +"xvmaddadp %%vs20, %%vs38, %%vs51 \n\t" \ +"xvmaddadp %%vs21, %%vs39, %%vs51 \n\t" \ +"xvmaddadp %%vs22, %%vs40, %%vs51 \n\t" \ +"xvmaddadp %%vs23, %%vs41, %%vs51 \n\t" \ +" \n\t" \ +"lxv %%vs46, 1216(%%r7) \n\t" \ +"lxv %%vs47, 1232(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs36, %%vs52 \n\t" \ +"xvmaddadp %%vs25, %%vs37, %%vs52 \n\t" \ +"xvmaddadp %%vs26, %%vs38, %%vs52 \n\t" \ +"xvmaddadp %%vs27, %%vs39, %%vs52 \n\t" \ +"xvmaddadp %%vs28, %%vs40, %%vs52 \n\t" \ +"xvmaddadp %%vs29, %%vs41, %%vs52 \n\t" \ +" \n\t" \ +"lxv %%vs58, 1216(%%r8) \n\t" \ +"lxv %%vs59, 1232(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs36, %%vs53 \n\t" \ +"xvmaddadp %%vs31, %%vs37, %%vs53 \n\t" \ +"xvmaddadp %%vs32, %%vs38, %%vs53 \n\t" \ +"xvmaddadp %%vs33, %%vs39, %%vs53 \n\t" \ +"xvmaddadp %%vs34, %%vs40, %%vs53 \n\t" \ +"xvmaddadp %%vs35, %%vs41, %%vs53 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs36, 1248(%%r7) \n\t" \ +"lxv %%vs37, 1264(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs42, %%vs54 \n\t" \ +"xvmaddadp %%vs1, %%vs43, %%vs54 \n\t" \ +"xvmaddadp %%vs2, %%vs44, %%vs54 \n\t" \ +"xvmaddadp %%vs3, %%vs45, %%vs54 \n\t" \ +"xvmaddadp %%vs4, %%vs46, %%vs54 \n\t" \ +"xvmaddadp %%vs5, %%vs47, %%vs54 \n\t" \ +" \n\t" \ +"lxv %%vs48, 1248(%%r8) \n\t" \ +"lxv %%vs49, 1264(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs42, %%vs55 \n\t" \ +"xvmaddadp %%vs7, %%vs43, %%vs55 \n\t" \ +"xvmaddadp %%vs8, %%vs44, %%vs55 \n\t" \ +"xvmaddadp %%vs9, %%vs45, %%vs55 \n\t" \ +"xvmaddadp %%vs10, %%vs46, %%vs55 \n\t" \ +"xvmaddadp %%vs11, %%vs47, %%vs55 \n\t" \ +" \n\t" \ +"lxv %%vs38, 1280(%%r7) \n\t" \ +"lxv %%vs39, 1296(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs42, %%vs56 \n\t" \ +"xvmaddadp %%vs13, %%vs43, %%vs56 \n\t" \ +"xvmaddadp %%vs14, %%vs44, %%vs56 \n\t" \ +"xvmaddadp %%vs15, %%vs45, %%vs56 \n\t" \ +"xvmaddadp %%vs16, %%vs46, %%vs56 \n\t" \ +"xvmaddadp %%vs17, %%vs47, %%vs56 \n\t" \ +" \n\t" \ +"lxv %%vs50, 1280(%%r8) \n\t" \ +"lxv %%vs51, 1296(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs42, %%vs57 \n\t" \ +"xvmaddadp %%vs19, %%vs43, %%vs57 \n\t" \ +"xvmaddadp %%vs20, %%vs44, %%vs57 \n\t" \ +"xvmaddadp %%vs21, %%vs45, %%vs57 \n\t" \ +"xvmaddadp %%vs22, %%vs46, %%vs57 \n\t" \ +"xvmaddadp %%vs23, %%vs47, %%vs57 \n\t" \ +" \n\t" \ +"lxv %%vs40, 1312(%%r7) \n\t" \ +"lxv %%vs41, 1328(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs42, %%vs58 \n\t" \ +"xvmaddadp %%vs25, %%vs43, %%vs58 \n\t" \ +"xvmaddadp %%vs26, %%vs44, %%vs58 \n\t" \ +"xvmaddadp %%vs27, %%vs45, %%vs58 \n\t" \ +"xvmaddadp %%vs28, %%vs46, %%vs58 \n\t" \ +"xvmaddadp %%vs29, %%vs47, %%vs58 \n\t" \ +" \n\t" \ +"lxv %%vs52, 1312(%%r8) \n\t" \ +"lxv %%vs53, 1328(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs42, %%vs59 \n\t" \ +"xvmaddadp %%vs31, %%vs43, %%vs59 \n\t" \ +"xvmaddadp %%vs32, %%vs44, %%vs59 \n\t" \ +"xvmaddadp %%vs33, %%vs45, %%vs59 \n\t" \ +"xvmaddadp %%vs34, %%vs46, %%vs59 \n\t" \ +"xvmaddadp %%vs35, %%vs47, %%vs59 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"lxv %%vs42, 1344(%%r7) \n\t" \ +"lxv %%vs43, 1360(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs36, %%vs48 \n\t" \ +"xvmaddadp %%vs1, %%vs37, %%vs48 \n\t" \ +"xvmaddadp %%vs2, %%vs38, %%vs48 \n\t" \ +"xvmaddadp %%vs3, %%vs39, %%vs48 \n\t" \ +"xvmaddadp %%vs4, %%vs40, %%vs48 \n\t" \ +"xvmaddadp %%vs5, %%vs41, %%vs48 \n\t" \ +" \n\t" \ +"lxv %%vs54, 1344(%%r8) \n\t" \ +"lxv %%vs55, 1360(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs36, %%vs49 \n\t" \ +"xvmaddadp %%vs7, %%vs37, %%vs49 \n\t" \ +"xvmaddadp %%vs8, %%vs38, %%vs49 \n\t" \ +"xvmaddadp %%vs9, %%vs39, %%vs49 \n\t" \ +"xvmaddadp %%vs10, %%vs40, %%vs49 \n\t" \ +"xvmaddadp %%vs11, %%vs41, %%vs49 \n\t" \ +" \n\t" \ +"lxv %%vs44, 1376(%%r7) \n\t" \ +"lxv %%vs45, 1392(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs36, %%vs50 \n\t" \ +"xvmaddadp %%vs13, %%vs37, %%vs50 \n\t" \ +"xvmaddadp %%vs14, %%vs38, %%vs50 \n\t" \ +"xvmaddadp %%vs15, %%vs39, %%vs50 \n\t" \ +"xvmaddadp %%vs16, %%vs40, %%vs50 \n\t" \ +"xvmaddadp %%vs17, %%vs41, %%vs50 \n\t" \ +" \n\t" \ +"lxv %%vs56, 1376(%%r8) \n\t" \ +"lxv %%vs57, 1392(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs36, %%vs51 \n\t" \ +"xvmaddadp %%vs19, %%vs37, %%vs51 \n\t" \ +"xvmaddadp %%vs20, %%vs38, %%vs51 \n\t" \ +"xvmaddadp %%vs21, %%vs39, %%vs51 \n\t" \ +"xvmaddadp %%vs22, %%vs40, %%vs51 \n\t" \ +"xvmaddadp %%vs23, %%vs41, %%vs51 \n\t" \ +" \n\t" \ +"lxv %%vs46, 1408(%%r7) \n\t" \ +"lxv %%vs47, 1424(%%r7) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs36, %%vs52 \n\t" \ +"xvmaddadp %%vs25, %%vs37, %%vs52 \n\t" \ +"xvmaddadp %%vs26, %%vs38, %%vs52 \n\t" \ +"xvmaddadp %%vs27, %%vs39, %%vs52 \n\t" \ +"xvmaddadp %%vs28, %%vs40, %%vs52 \n\t" \ +"xvmaddadp %%vs29, %%vs41, %%vs52 \n\t" \ +" \n\t" \ +"lxv %%vs58, 1408(%%r8) \n\t" \ +"lxv %%vs59, 1424(%%r8) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs36, %%vs53 \n\t" \ +"xvmaddadp %%vs31, %%vs37, %%vs53 \n\t" \ +"xvmaddadp %%vs32, %%vs38, %%vs53 \n\t" \ +"xvmaddadp %%vs33, %%vs39, %%vs53 \n\t" \ +"xvmaddadp %%vs34, %%vs40, %%vs53 \n\t" \ +"xvmaddadp %%vs35, %%vs41, %%vs53 \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs42, %%vs54 \n\t" \ +"xvmaddadp %%vs1, %%vs43, %%vs54 \n\t" \ +"xvmaddadp %%vs2, %%vs44, %%vs54 \n\t" \ +"xvmaddadp %%vs3, %%vs45, %%vs54 \n\t" \ +"xvmaddadp %%vs4, %%vs46, %%vs54 \n\t" \ +"xvmaddadp %%vs5, %%vs47, %%vs54 \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs42, %%vs55 \n\t" \ +"xvmaddadp %%vs7, %%vs43, %%vs55 \n\t" \ +"xvmaddadp %%vs8, %%vs44, %%vs55 \n\t" \ +"xvmaddadp %%vs9, %%vs45, %%vs55 \n\t" \ +"xvmaddadp %%vs10, %%vs46, %%vs55 \n\t" \ +"xvmaddadp %%vs11, %%vs47, %%vs55 \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs42, %%vs56 \n\t" \ +"xvmaddadp %%vs13, %%vs43, %%vs56 \n\t" \ +"xvmaddadp %%vs14, %%vs44, %%vs56 \n\t" \ +"xvmaddadp %%vs15, %%vs45, %%vs56 \n\t" \ +"xvmaddadp %%vs16, %%vs46, %%vs56 \n\t" \ +"xvmaddadp %%vs17, %%vs47, %%vs56 \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs42, %%vs57 \n\t" \ +"xvmaddadp %%vs19, %%vs43, %%vs57 \n\t" \ +"xvmaddadp %%vs20, %%vs44, %%vs57 \n\t" \ +"xvmaddadp %%vs21, %%vs45, %%vs57 \n\t" \ +"xvmaddadp %%vs22, %%vs46, %%vs57 \n\t" \ +"xvmaddadp %%vs23, %%vs47, %%vs57 \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs42, %%vs58 \n\t" \ +"xvmaddadp %%vs25, %%vs43, %%vs58 \n\t" \ +"xvmaddadp %%vs26, %%vs44, %%vs58 \n\t" \ +"xvmaddadp %%vs27, %%vs45, %%vs58 \n\t" \ +"xvmaddadp %%vs28, %%vs46, %%vs58 \n\t" \ +"xvmaddadp %%vs29, %%vs47, %%vs58 \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs42, %%vs59 \n\t" \ +"xvmaddadp %%vs31, %%vs43, %%vs59 \n\t" \ +"xvmaddadp %%vs32, %%vs44, %%vs59 \n\t" \ +"xvmaddadp %%vs33, %%vs45, %%vs59 \n\t" \ +"xvmaddadp %%vs34, %%vs46, %%vs59 \n\t" \ +"xvmaddadp %%vs35, %%vs47, %%vs59 \n\t" \ +" \n\t" \ +"lxv %%vs36, 1440(%%r7) \n\t" \ +"lxv %%vs37, 1456(%%r7) \n\t" \ +" \n\t" \ +"lxv %%vs48, 1440(%%r8) \n\t" \ +"lxv %%vs49, 1456(%%r8) \n\t" \ +" \n\t" \ +"lxv %%vs38, 1472(%%r7) \n\t" \ +"lxv %%vs39, 1488(%%r7) \n\t" \ +" \n\t" \ +"lxv %%vs50, 1472(%%r8) \n\t" \ +"lxv %%vs51, 1488(%%r8) \n\t" \ +" \n\t" \ +"lxv %%vs40, 1504(%%r7) \n\t" \ +"lxv %%vs41, 1520(%%r7) \n\t" \ +" \n\t" \ +"lxv %%vs52, 1504(%%r8) \n\t" \ +"lxv %%vs53, 1520(%%r8) \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +" \n\t" \ +"addi %%r8, %%r8, 1536 \n\t" \ +"addi %%r7, %%r7, 1536 \n\t" + +// compute AB product +// no unrolling +#define A_B_PRODUCT_1 \ +" \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs36, %%vs48 \n\t" \ +"xvmaddadp %%vs1, %%vs37, %%vs48 \n\t" \ +"xvmaddadp %%vs2, %%vs38, %%vs48 \n\t" \ +"xvmaddadp %%vs3, %%vs39, %%vs48 \n\t" \ +"xvmaddadp %%vs4, %%vs40, %%vs48 \n\t" \ +"xvmaddadp %%vs5, %%vs41, %%vs48 \n\t" \ +" \n\t" \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs36, %%vs49 \n\t" \ +"xvmaddadp %%vs7, %%vs37, %%vs49 \n\t" \ +"xvmaddadp %%vs8, %%vs38, %%vs49 \n\t" \ +"xvmaddadp %%vs9, %%vs39, %%vs49 \n\t" \ +"xvmaddadp %%vs10, %%vs40, %%vs49 \n\t" \ +"xvmaddadp %%vs11, %%vs41, %%vs49 \n\t" \ +" \n\t" \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs36, %%vs50 \n\t" \ +"xvmaddadp %%vs13, %%vs37, %%vs50 \n\t" \ +"xvmaddadp %%vs14, %%vs38, %%vs50 \n\t" \ +"xvmaddadp %%vs15, %%vs39, %%vs50 \n\t" \ +"xvmaddadp %%vs16, %%vs40, %%vs50 \n\t" \ +"xvmaddadp %%vs17, %%vs41, %%vs50 \n\t" \ +" \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs36, %%vs51 \n\t" \ +"xvmaddadp %%vs19, %%vs37, %%vs51 \n\t" \ +"xvmaddadp %%vs20, %%vs38, %%vs51 \n\t" \ +"xvmaddadp %%vs21, %%vs39, %%vs51 \n\t" \ +"xvmaddadp %%vs22, %%vs40, %%vs51 \n\t" \ +"xvmaddadp %%vs23, %%vs41, %%vs51 \n\t" \ +" \n\t" \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs36, %%vs52 \n\t" \ +"xvmaddadp %%vs25, %%vs37, %%vs52 \n\t" \ +"xvmaddadp %%vs26, %%vs38, %%vs52 \n\t" \ +"xvmaddadp %%vs27, %%vs39, %%vs52 \n\t" \ +"xvmaddadp %%vs28, %%vs40, %%vs52 \n\t" \ +"xvmaddadp %%vs29, %%vs41, %%vs52 \n\t" \ +" \n\t" \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs36, %%vs53 \n\t" \ +"xvmaddadp %%vs31, %%vs37, %%vs53 \n\t" \ +"xvmaddadp %%vs32, %%vs38, %%vs53 \n\t" \ +"xvmaddadp %%vs33, %%vs39, %%vs53 \n\t" \ +"xvmaddadp %%vs34, %%vs40, %%vs53 \n\t" \ +"xvmaddadp %%vs35, %%vs41, %%vs53 \n\t" \ +" \n\t" \ +"lxv %%vs48, 0(%%r8) \n\t" \ +"lxv %%vs49, 16(%%r8) \n\t" \ +"lxv %%vs50, 32(%%r8) \n\t" \ +"lxv %%vs51, 48(%%r8) \n\t" \ +"lxv %%vs52, 64(%%r8) \n\t" \ +"lxv %%vs53, 80(%%r8) \n\t" \ +" \n\t" \ +"lxv %%vs36, 0(%%r7) \n\t" \ +"lxv %%vs37, 16(%%r7) \n\t" \ +"lxv %%vs38, 32(%%r7) \n\t" \ +"lxv %%vs39, 48(%%r7) \n\t" \ +"lxv %%vs40, 64(%%r7) \n\t" \ +"lxv %%vs41, 80(%%r7) \n\t" \ +" \n\t" \ +"addi %%r8, %%r8, 96 \n\t" \ +"addi %%r7, %%r7, 96 \n\t" \ + +// scale AB product by alpha +#define DSCALE_ALPHA \ +"xvmuldp %%vs0, %%vs0, %%vs62 \n\t" \ +"xvmuldp %%vs1, %%vs1, %%vs62 \n\t" \ +"xvmuldp %%vs2, %%vs2, %%vs62 \n\t" \ +"xvmuldp %%vs3, %%vs3, %%vs62 \n\t" \ +"xvmuldp %%vs4, %%vs4, %%vs62 \n\t" \ +"xvmuldp %%vs5, %%vs5, %%vs62 \n\t" \ +"xvmuldp %%vs6, %%vs6, %%vs62 \n\t" \ +"xvmuldp %%vs7, %%vs7, %%vs62 \n\t" \ +"xvmuldp %%vs8, %%vs8, %%vs62 \n\t" \ +"xvmuldp %%vs9, %%vs9, %%vs62 \n\t" \ +"xvmuldp %%vs10, %%vs10, %%vs62 \n\t" \ +"xvmuldp %%vs11, %%vs11, %%vs62 \n\t" \ +"xvmuldp %%vs12, %%vs12, %%vs62 \n\t" \ +"xvmuldp %%vs13, %%vs13, %%vs62 \n\t" \ +"xvmuldp %%vs14, %%vs14, %%vs62 \n\t" \ +"xvmuldp %%vs15, %%vs15, %%vs62 \n\t" \ +"xvmuldp %%vs16, %%vs16, %%vs62 \n\t" \ +"xvmuldp %%vs17, %%vs17, %%vs62 \n\t" \ +"xvmuldp %%vs18, %%vs18, %%vs62 \n\t" \ +"xvmuldp %%vs19, %%vs19, %%vs62 \n\t" \ +"xvmuldp %%vs20, %%vs20, %%vs62 \n\t" \ +"xvmuldp %%vs21, %%vs21, %%vs62 \n\t" \ +"xvmuldp %%vs22, %%vs22, %%vs62 \n\t" \ +"xvmuldp %%vs23, %%vs23, %%vs62 \n\t" \ +"xvmuldp %%vs24, %%vs24, %%vs62 \n\t" \ +"xvmuldp %%vs25, %%vs25, %%vs62 \n\t" \ +"xvmuldp %%vs26, %%vs26, %%vs62 \n\t" \ +"xvmuldp %%vs27, %%vs27, %%vs62 \n\t" \ +"xvmuldp %%vs28, %%vs28, %%vs62 \n\t" \ +"xvmuldp %%vs29, %%vs29, %%vs62 \n\t" \ +"xvmuldp %%vs30, %%vs30, %%vs62 \n\t" \ +"xvmuldp %%vs31, %%vs31, %%vs62 \n\t" \ +"xvmuldp %%vs32, %%vs32, %%vs62 \n\t" \ +"xvmuldp %%vs33, %%vs33, %%vs62 \n\t" \ +"xvmuldp %%vs34, %%vs34, %%vs62 \n\t" \ +"xvmuldp %%vs35, %%vs35, %%vs62 \n\t" + +// initialize offset registers used for gen stored cases +#define DGEN_LOAD_OFS_C \ +"ld %%r22, %6 \n\t" \ +"slwi %%r12, %%r9, 1 \n\t" \ +"add %%r23, %%r22, %%r12 \n\t" \ +"add %%r24, %%r23, %%r12 \n\t" \ +"add %%r25, %%r24, %%r12 \n\t" \ +"add %%r26, %%r25, %%r12 \n\t" \ +"add %%r27, %%r26, %%r12 \n\t" + +// load C into registers +// assume C is gen stored +#define DGEN_LOAD_C \ +"lxsdx %%vs36, %%r9, %%r22 \n\t" \ +"lxsdx %%vs37, 0, %%r22 \n\t" \ +"xxpermdi %%vs36, %%vs36, %%vs37, 0 \n\t" \ +"lxsdx %%vs37, %%r9, %%r23 \n\t" \ +"lxsdx %%vs38, 0, %%r23 \n\t" \ +"xxpermdi %%vs37, %%vs37, %%vs38, 0 \n\t" \ +"lxsdx %%vs38, %%r9, %%r24 \n\t" \ +"lxsdx %%vs39, 0, %%r24 \n\t" \ +"xxpermdi %%vs38, %%vs38, %%vs39, 0 \n\t" \ +"lxsdx %%vs39, %%r9, %%r25 \n\t" \ +"lxsdx %%vs40, 0, %%r25 \n\t" \ +"xxpermdi %%vs39, %%vs39, %%vs40, 0 \n\t" \ +"lxsdx %%vs40, %%r9, %%r26 \n\t" \ +"lxsdx %%vs41, 0, %%r26 \n\t" \ +"xxpermdi %%vs40, %%vs40, %%vs41, 0 \n\t" \ +"lxsdx %%vs41, %%r9, %%r27 \n\t" \ +"lxsdx %%vs42, 0, %%r27 \n\t" \ +"xxpermdi %%vs41, %%vs41, %%vs42, 0 \n\t" + +// increment offset registers to the next col +#define DGEN_NEXT_COL_CMATRIX \ +"add %%r22, %%r22, %%r10 \n\t" \ +"add %%r23, %%r23, %%r10 \n\t" \ +"add %%r24, %%r24, %%r10 \n\t" \ +"add %%r25, %%r25, %%r10 \n\t" \ +"add %%r26, %%r26, %%r10 \n\t" \ +"add %%r27, %%r27, %%r10 \n\t" + +// load C into registers and move offset registers to next col +#define DGENLOAD_UPDATE \ +DGEN_LOAD_C \ +DGEN_NEXT_COL_CMATRIX + +// scale C by beta and add it to the AB product +// assume C is gen stored +#define DGEN_SCALE_BETA \ +DGENLOAD_UPDATE \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs36, %%vs63 \n\t" \ +"xvmaddadp %%vs1, %%vs37, %%vs63 \n\t" \ +"xvmaddadp %%vs2, %%vs38, %%vs63 \n\t" \ +"xvmaddadp %%vs3, %%vs39, %%vs63 \n\t" \ +"xvmaddadp %%vs4, %%vs40, %%vs63 \n\t" \ +"xvmaddadp %%vs5, %%vs41, %%vs63 \n\t" \ +" \n\t" \ +" \n\t" \ +DGENLOAD_UPDATE \ +" \n\t" \ +"xvmaddadp %%vs6, %%vs36, %%vs63 \n\t" \ +"xvmaddadp %%vs7, %%vs37, %%vs63 \n\t" \ +"xvmaddadp %%vs8, %%vs38, %%vs63 \n\t" \ +"xvmaddadp %%vs9, %%vs39, %%vs63 \n\t" \ +"xvmaddadp %%vs10, %%vs40, %%vs63 \n\t" \ +"xvmaddadp %%vs11, %%vs41, %%vs63 \n\t" \ +" \n\t" \ +" \n\t" \ +DGENLOAD_UPDATE \ +" \n\t" \ +"xvmaddadp %%vs12, %%vs36, %%vs63 \n\t" \ +"xvmaddadp %%vs13, %%vs37, %%vs63 \n\t" \ +"xvmaddadp %%vs14, %%vs38, %%vs63 \n\t" \ +"xvmaddadp %%vs15, %%vs39, %%vs63 \n\t" \ +"xvmaddadp %%vs16, %%vs40, %%vs63 \n\t" \ +"xvmaddadp %%vs17, %%vs41, %%vs63 \n\t" \ +" \n\t" \ +" \n\t" \ +DGENLOAD_UPDATE \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs36, %%vs63 \n\t" \ +"xvmaddadp %%vs19, %%vs37, %%vs63 \n\t" \ +"xvmaddadp %%vs20, %%vs38, %%vs63 \n\t" \ +"xvmaddadp %%vs21, %%vs39, %%vs63 \n\t" \ +"xvmaddadp %%vs22, %%vs40, %%vs63 \n\t" \ +"xvmaddadp %%vs23, %%vs41, %%vs63 \n\t" \ +" \n\t" \ +" \n\t" \ +DGENLOAD_UPDATE \ +" \n\t" \ +"xvmaddadp %%vs24, %%vs36, %%vs63 \n\t" \ +"xvmaddadp %%vs25, %%vs37, %%vs63 \n\t" \ +"xvmaddadp %%vs26, %%vs38, %%vs63 \n\t" \ +"xvmaddadp %%vs27, %%vs39, %%vs63 \n\t" \ +"xvmaddadp %%vs28, %%vs40, %%vs63 \n\t" \ +"xvmaddadp %%vs29, %%vs41, %%vs63 \n\t" \ +" \n\t" \ +" \n\t" \ +DGENLOAD_UPDATE \ +" \n\t" \ +"xvmaddadp %%vs30, %%vs36, %%vs63 \n\t" \ +"xvmaddadp %%vs31, %%vs37, %%vs63 \n\t" \ +"xvmaddadp %%vs32, %%vs38, %%vs63 \n\t" \ +"xvmaddadp %%vs33, %%vs39, %%vs63 \n\t" \ +"xvmaddadp %%vs34, %%vs40, %%vs63 \n\t" \ +"xvmaddadp %%vs35, %%vs41, %%vs63 \n\t" + +// scale C by beta and add it to the AB product +// assume C is col stored +#define DCOL_SCALE_BETA \ +"lxv %%vs36, 0(%%r16) \n\t" \ +"lxv %%vs42, 0(%%r17) \n\t" \ +"lxv %%vs48, 0(%%r18) \n\t" \ +"lxv %%vs41, 80(%%r16) \n\t" \ +"lxv %%vs47, 80(%%r17) \n\t" \ +"lxv %%vs53, 80(%%r18) \n\t" \ +"lxv %%vs37, 16(%%r16) \n\t" \ +"lxv %%vs38, 32(%%r16) \n\t" \ +"lxv %%vs39, 48(%%r16) \n\t" \ +"lxv %%vs40, 64(%%r16) \n\t" \ +"lxv %%vs43, 16(%%r17) \n\t" \ +"lxv %%vs44, 32(%%r17) \n\t" \ +"lxv %%vs45, 48(%%r17) \n\t" \ +"lxv %%vs46, 64(%%r17) \n\t" \ +"lxv %%vs49, 16(%%r18) \n\t" \ +"lxv %%vs50, 32(%%r18) \n\t" \ +"lxv %%vs51, 48(%%r18) \n\t" \ +"lxv %%vs52, 64(%%r18) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs0, %%vs36, %%vs63 \n\t" \ +"xvmaddadp %%vs6, %%vs42, %%vs63 \n\t" \ +"xvmaddadp %%vs12, %%vs48, %%vs63 \n\t" \ +"xvmaddadp %%vs5, %%vs41, %%vs63 \n\t" \ +"xvmaddadp %%vs11, %%vs47, %%vs63 \n\t" \ +"xvmaddadp %%vs17, %%vs53, %%vs63 \n\t" \ +"xvmaddadp %%vs1, %%vs37, %%vs63 \n\t" \ +"xvmaddadp %%vs2, %%vs38, %%vs63 \n\t" \ +"xvmaddadp %%vs3, %%vs39, %%vs63 \n\t" \ +"xvmaddadp %%vs4, %%vs40, %%vs63 \n\t" \ +"xvmaddadp %%vs7, %%vs43, %%vs63 \n\t" \ +"xvmaddadp %%vs8, %%vs44, %%vs63 \n\t" \ +"xvmaddadp %%vs9, %%vs45, %%vs63 \n\t" \ +"xvmaddadp %%vs10, %%vs46, %%vs63 \n\t" \ +"xvmaddadp %%vs13, %%vs49, %%vs63 \n\t" \ +"xvmaddadp %%vs14, %%vs50, %%vs63 \n\t" \ +"xvmaddadp %%vs15, %%vs51, %%vs63 \n\t" \ +"xvmaddadp %%vs16, %%vs52, %%vs63 \n\t" \ +" \n\t" \ +"lxv %%vs36, 0(%%r19) \n\t" \ +"lxv %%vs42, 0(%%r20) \n\t" \ +"lxv %%vs48, 0(%%r21) \n\t" \ +"lxv %%vs41, 80(%%r19) \n\t" \ +"lxv %%vs47, 80(%%r20) \n\t" \ +"lxv %%vs53, 80(%%r21) \n\t" \ +"lxv %%vs37, 16(%%r19) \n\t" \ +"lxv %%vs38, 32(%%r19) \n\t" \ +"lxv %%vs39, 48(%%r19) \n\t" \ +"lxv %%vs40, 64(%%r19) \n\t" \ +"lxv %%vs43, 16(%%r20) \n\t" \ +"lxv %%vs44, 32(%%r20) \n\t" \ +"lxv %%vs45, 48(%%r20) \n\t" \ +"lxv %%vs46, 64(%%r20) \n\t" \ +"lxv %%vs49, 16(%%r21) \n\t" \ +"lxv %%vs50, 32(%%r21) \n\t" \ +"lxv %%vs51, 48(%%r21) \n\t" \ +"lxv %%vs52, 64(%%r21) \n\t" \ +" \n\t" \ +"xvmaddadp %%vs18, %%vs36, %%vs63 \n\t" \ +"xvmaddadp %%vs24, %%vs42, %%vs63 \n\t" \ +"xvmaddadp %%vs30, %%vs48, %%vs63 \n\t" \ +"xvmaddadp %%vs23, %%vs41, %%vs63 \n\t" \ +"xvmaddadp %%vs29, %%vs47, %%vs63 \n\t" \ +"xvmaddadp %%vs35, %%vs53, %%vs63 \n\t" \ +"xvmaddadp %%vs19, %%vs37, %%vs63 \n\t" \ +"xvmaddadp %%vs20, %%vs38, %%vs63 \n\t" \ +"xvmaddadp %%vs21, %%vs39, %%vs63 \n\t" \ +"xvmaddadp %%vs22, %%vs40, %%vs63 \n\t" \ +"xvmaddadp %%vs25, %%vs43, %%vs63 \n\t" \ +"xvmaddadp %%vs26, %%vs44, %%vs63 \n\t" \ +"xvmaddadp %%vs27, %%vs45, %%vs63 \n\t" \ +"xvmaddadp %%vs28, %%vs46, %%vs63 \n\t" \ +"xvmaddadp %%vs31, %%vs49, %%vs63 \n\t" \ +"xvmaddadp %%vs32, %%vs50, %%vs63 \n\t" \ +"xvmaddadp %%vs33, %%vs51, %%vs63 \n\t" \ +"xvmaddadp %%vs34, %%vs52, %%vs63 \n\t" + +// store result into C's memory location +// assume C is gen stored +#define DGEN_STORE \ +" \n\t" \ +"stxsdx %%vs0, %%r9, %%r22 \n\t" \ +"xxswapd %%vs0, %%vs0 \n\t" \ +"stxsdx %%vs0, 0, %%r22 \n\t" \ +"stxsdx %%vs1, %%r9, %%r23 \n\t" \ +"xxswapd %%vs1, %%vs1 \n\t" \ +"stxsdx %%vs1, 0, %%r23 \n\t" \ +"stxsdx %%vs2, %%r9, %%r24 \n\t" \ +"xxswapd %%vs2, %%vs2 \n\t" \ +"stxsdx %%vs2, 0, %%r24 \n\t" \ +"stxsdx %%vs3, %%r9, %%r25 \n\t" \ +"xxswapd %%vs3, %%vs3 \n\t" \ +"stxsdx %%vs3, 0, %%r25 \n\t" \ +"stxsdx %%vs4, %%r9, %%r26 \n\t" \ +"xxswapd %%vs4, %%vs4 \n\t" \ +"stxsdx %%vs4, 0, %%r26 \n\t" \ +"stxsdx %%vs5, %%r9, %%r27 \n\t" \ +"xxswapd %%vs5, %%vs5 \n\t" \ +"stxsdx %%vs5, 0, %%r27 \n\t" \ +" \n\t" \ +DGEN_NEXT_COL_CMATRIX \ +" \n\t" \ +"stxsdx %%vs6, %%r9, %%r22 \n\t" \ +"xxswapd %%vs6, %%vs6 \n\t" \ +"stxsdx %%vs6, 0, %%r22 \n\t" \ +"stxsdx %%vs7, %%r9, %%r23 \n\t" \ +"xxswapd %%vs7, %%vs7 \n\t" \ +"stxsdx %%vs7, 0, %%r23 \n\t" \ +"stxsdx %%vs8, %%r9, %%r24 \n\t" \ +"xxswapd %%vs8, %%vs8 \n\t" \ +"stxsdx %%vs8, 0, %%r24 \n\t" \ +"stxsdx %%vs9, %%r9, %%r25 \n\t" \ +"xxswapd %%vs9, %%vs9 \n\t" \ +"stxsdx %%vs9, 0, %%r25 \n\t" \ +"stxsdx %%vs10, %%r9, %%r26 \n\t" \ +"xxswapd %%vs10, %%vs10 \n\t" \ +"stxsdx %%vs10, 0, %%r26 \n\t" \ +"stxsdx %%vs11, %%r9, %%r27 \n\t" \ +"xxswapd %%vs11, %%vs11 \n\t" \ +"stxsdx %%vs11, 0, %%r27 \n\t" \ +" \n\t" \ +DGEN_NEXT_COL_CMATRIX \ +" \n\t" \ +"stxsdx %%vs12, %%r9, %%r22 \n\t" \ +"xxswapd %%vs12, %%vs12 \n\t" \ +"stxsdx %%vs12, 0, %%r22 \n\t" \ +"stxsdx %%vs13, %%r9, %%r23 \n\t" \ +"xxswapd %%vs13, %%vs13 \n\t" \ +"stxsdx %%vs13, 0, %%r23 \n\t" \ +"stxsdx %%vs14, %%r9, %%r24 \n\t" \ +"xxswapd %%vs14, %%vs14 \n\t" \ +"stxsdx %%vs14, 0, %%r24 \n\t" \ +"stxsdx %%vs15, %%r9, %%r25 \n\t" \ +"xxswapd %%vs15, %%vs15 \n\t" \ +"stxsdx %%vs15, 0, %%r25 \n\t" \ +"stxsdx %%vs16, %%r9, %%r26 \n\t" \ +"xxswapd %%vs16, %%vs16 \n\t" \ +"stxsdx %%vs16, 0, %%r26 \n\t" \ +"stxsdx %%vs17, %%r9, %%r27 \n\t" \ +"xxswapd %%vs17, %%vs17 \n\t" \ +"stxsdx %%vs17, 0, %%r27 \n\t" \ +" \n\t" \ +DGEN_NEXT_COL_CMATRIX \ +" \n\t" \ +"stxsdx %%vs18, %%r9, %%r22 \n\t" \ +"xxswapd %%vs18, %%vs18 \n\t" \ +"stxsdx %%vs18, 0, %%r22 \n\t" \ +"stxsdx %%vs19, %%r9, %%r23 \n\t" \ +"xxswapd %%vs19, %%vs19 \n\t" \ +"stxsdx %%vs19, 0, %%r23 \n\t" \ +"stxsdx %%vs20, %%r9, %%r24 \n\t" \ +"xxswapd %%vs20, %%vs20 \n\t" \ +"stxsdx %%vs20, 0, %%r24 \n\t" \ +"stxsdx %%vs21, %%r9, %%r25 \n\t" \ +"xxswapd %%vs21, %%vs21 \n\t" \ +"stxsdx %%vs21, 0, %%r25 \n\t" \ +"stxsdx %%vs22, %%r9, %%r26 \n\t" \ +"xxswapd %%vs22, %%vs22 \n\t" \ +"stxsdx %%vs22, 0, %%r26 \n\t" \ +"stxsdx %%vs23, %%r9, %%r27 \n\t" \ +"xxswapd %%vs23, %%vs23 \n\t" \ +"stxsdx %%vs23, 0, %%r27 \n\t" \ +" \n\t" \ +DGEN_NEXT_COL_CMATRIX \ +" \n\t" \ +"stxsdx %%vs24, %%r9, %%r22 \n\t" \ +"xxswapd %%vs24, %%vs24 \n\t" \ +"stxsdx %%vs24, 0, %%r22 \n\t" \ +"stxsdx %%vs25, %%r9, %%r23 \n\t" \ +"xxswapd %%vs25, %%vs25 \n\t" \ +"stxsdx %%vs25, 0, %%r23 \n\t" \ +"stxsdx %%vs26, %%r9, %%r24 \n\t" \ +"xxswapd %%vs26, %%vs26 \n\t" \ +"stxsdx %%vs26, 0, %%r24 \n\t" \ +"stxsdx %%vs27, %%r9, %%r25 \n\t" \ +"xxswapd %%vs27, %%vs27 \n\t" \ +"stxsdx %%vs27, 0, %%r25 \n\t" \ +"stxsdx %%vs28, %%r9, %%r26 \n\t" \ +"xxswapd %%vs28, %%vs28 \n\t" \ +"stxsdx %%vs28, 0, %%r26 \n\t" \ +"stxsdx %%vs29, %%r9, %%r27 \n\t" \ +"xxswapd %%vs29, %%vs29 \n\t" \ +"stxsdx %%vs29, 0, %%r27 \n\t" \ +" \n\t" \ +DGEN_NEXT_COL_CMATRIX \ +" \n\t" \ +"stxsdx %%vs30, %%r9, %%r22 \n\t" \ +"xxswapd %%vs30, %%vs30 \n\t" \ +"stxsdx %%vs30, 0, %%r22 \n\t" \ +"stxsdx %%vs31, %%r9, %%r23 \n\t" \ +"xxswapd %%vs31, %%vs31 \n\t" \ +"stxsdx %%vs31, 0, %%r23 \n\t" \ +"stxsdx %%vs32, %%r9, %%r24 \n\t" \ +"xxswapd %%vs32, %%vs32 \n\t" \ +"stxsdx %%vs32, 0, %%r24 \n\t" \ +"stxsdx %%vs33, %%r9, %%r25 \n\t" \ +"xxswapd %%vs33, %%vs33 \n\t" \ +"stxsdx %%vs33, 0, %%r25 \n\t" \ +"stxsdx %%vs34, %%r9, %%r26 \n\t" \ +"xxswapd %%vs34, %%vs34 \n\t" \ +"stxsdx %%vs34, 0, %%r26 \n\t" \ +"stxsdx %%vs35, %%r9, %%r27 \n\t" \ +"xxswapd %%vs35, %%vs35 \n\t" \ +"stxsdx %%vs35, 0, %%r27 \n\t" + +// store result into C's memory location +// assume C is col stored +#define DCOL_STORE \ +"stxv %%vs0, 0(%%r16) \n\t" \ +"stxv %%vs1, 16(%%r16) \n\t" \ +"stxv %%vs2, 32(%%r16) \n\t" \ +"stxv %%vs3, 48(%%r16) \n\t" \ +"stxv %%vs4, 64(%%r16) \n\t" \ +"stxv %%vs5, 80(%%r16) \n\t" \ +"stxv %%vs6, 0(%%r17) \n\t" \ +"stxv %%vs7, 16(%%r17) \n\t" \ +"stxv %%vs8, 32(%%r17) \n\t" \ +"stxv %%vs9, 48(%%r17) \n\t" \ +"stxv %%vs10, 64(%%r17) \n\t" \ +"stxv %%vs11, 80(%%r17) \n\t" \ +"stxv %%vs12, 0(%%r18) \n\t" \ +"stxv %%vs13, 16(%%r18) \n\t" \ +"stxv %%vs14, 32(%%r18) \n\t" \ +"stxv %%vs15, 48(%%r18) \n\t" \ +"stxv %%vs16, 64(%%r18) \n\t" \ +"stxv %%vs17, 80(%%r18) \n\t" \ +"stxv %%vs18, 0(%%r19) \n\t" \ +"stxv %%vs19, 16(%%r19) \n\t" \ +"stxv %%vs20, 32(%%r19) \n\t" \ +"stxv %%vs21, 48(%%r19) \n\t" \ +"stxv %%vs22, 64(%%r19) \n\t" \ +"stxv %%vs23, 80(%%r19) \n\t" \ +"stxv %%vs24, 0(%%r20) \n\t" \ +"stxv %%vs25, 16(%%r20) \n\t" \ +"stxv %%vs26, 32(%%r20) \n\t" \ +"stxv %%vs27, 48(%%r20) \n\t" \ +"stxv %%vs28, 64(%%r20) \n\t" \ +"stxv %%vs29, 80(%%r20) \n\t" \ +"stxv %%vs30, 0(%%r21) \n\t" \ +"stxv %%vs31, 16(%%r21) \n\t" \ +"stxv %%vs32, 32(%%r21) \n\t" \ +"stxv %%vs33, 48(%%r21) \n\t" \ +"stxv %%vs34, 64(%%r21) \n\t" \ +"stxv %%vs35, 80(%%r21) \n\t" + diff --git a/frame/1m/packm/bli_packm_md.h b/kernels/power9/bli_kernels_power9.h similarity index 91% rename from frame/1m/packm/bli_packm_md.h rename to kernels/power9/bli_kernels_power9.h index bb9d6d6135..9f4d08ccb2 100644 --- a/frame/1m/packm/bli_packm_md.h +++ b/kernels/power9/bli_kernels_power9.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, The University of Texas at Austin Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,6 +32,7 @@ */ -#include "bli_packm_blk_var1_md.h" -#include "bli_packm_struc_cxk_md.h" +// -- level-3 -- +// gemm (asm d12x6) +GEMM_UKR_PROT( double, d, gemm_power9_asm_12x6 ) \ No newline at end of file diff --git a/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c b/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c index a56ef16e5e..7890ad347d 100644 --- a/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c +++ b/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c @@ -42,7 +42,9 @@ void bli_sgemm_sandybridge_asm_8x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -57,27 +59,29 @@ void bli_sgemm_sandybridge_asm_8x8 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( s, 8, 8, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(var(b_next), r15) // load address of b_next. - + vmovaps(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovsldup(mem(rbx, 0*32), ymm2) // elements of a and b. vpermilps(imm(0x4e), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) lea(mem(rcx, rdi, 4), r10) // load address of c + 4*cs_c; - + lea(mem(rdi, rdi, 2), r14) // r14 = 3*cs_c; prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*cs_c @@ -87,7 +91,7 @@ void bli_sgemm_sandybridge_asm_8x8 prefetch(0, mem(r10, rdi, 1, 7*8)) // prefetch c + 5*cs_c prefetch(0, mem(r10, rdi, 2, 7*8)) // prefetch c + 6*cs_c prefetch(0, mem(r10, r14, 1, 7*8)) // prefetch c + 7*cs_c - + vxorps(ymm8, ymm8, ymm8) vxorps(ymm9, ymm9, ymm9) vxorps(ymm10, ymm10, ymm10) @@ -96,18 +100,18 @@ void bli_sgemm_sandybridge_asm_8x8 vxorps(ymm13, ymm13, ymm13) vxorps(ymm14, ymm14, ymm14) vxorps(ymm15, ymm15, ymm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.SLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 16*32)) vmulps(ymm0, ymm2, ymm6) @@ -117,14 +121,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 1*32), ymm1) vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm0, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 1*32), ymm2) @@ -132,13 +136,13 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - + // iteration 1 vmulps(ymm1, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) @@ -147,14 +151,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 2*32), ymm0) vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm1, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 2*32), ymm2) @@ -162,14 +166,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - - + + // iteration 2 prefetch(0, mem(rax, 18*32)) vmulps(ymm0, ymm2, ymm6) @@ -179,7 +183,7 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 3*32), ymm1) add(imm(4*8*4), rax) // a += 4*8 (unroll x mr) vpermilps(imm(0x4e), ymm2, ymm3) @@ -187,7 +191,7 @@ void bli_sgemm_sandybridge_asm_8x8 vmulps(ymm0, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm0, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 3*32), ymm2) @@ -195,14 +199,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - - + + // iteration 3 vmulps(ymm1, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) @@ -212,14 +216,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 0*32), ymm0) vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm1, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 0*32), ymm2) @@ -227,35 +231,35 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - - - - + + + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.SLOOPKLEFT) // EDGE LOOP - - + + prefetch(0, mem(rax, 16*32)) vmulps(ymm0, ymm2, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) @@ -264,7 +268,7 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 1*32), ymm1) add(imm(8*1*4), rax) // a += 8 (1 x mr) vpermilps(imm(0x4e), ymm2, ymm3) @@ -272,7 +276,7 @@ void bli_sgemm_sandybridge_asm_8x8 vmulps(ymm0, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm0, ymm2, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 1*32), ymm2) @@ -281,122 +285,122 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(ymm1, ymm0) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - - - + + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.SPOSTACCUM) - + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab22 ab20 ab26 ab24 // ab32 ab30 ab36 ab34 // ab44 ab46 ab40 ab42 - // ab54 ab56 ab50 ab52 + // ab54 ab56 ab50 ab52 // ab66 ab64 ab62 ab60 // ab76 ) ab74 ) ab72 ) ab70 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab23 ab21 ab27 ab25 // ab33 ab31 ab37 ab35 // ab45 ab47 ab41 ab43 - // ab55 ab57 ab51 ab53 + // ab55 ab57 ab51 ab53 // ab67 ab65 ab63 ab61 // ab77 ) ab75 ) ab73 ) ab71 ) - + vmovaps(ymm15, ymm7) vshufps(imm(0xe4), ymm13, ymm15, ymm15) vshufps(imm(0xe4), ymm7, ymm13, ymm13) - + vmovaps(ymm11, ymm7) vshufps(imm(0xe4), ymm9, ymm11, ymm11) vshufps(imm(0xe4), ymm7, ymm9, ymm9) - + vmovaps(ymm14, ymm7) vshufps(imm(0xe4), ymm12, ymm14, ymm14) vshufps(imm(0xe4), ymm7, ymm12, ymm12) - + vmovaps(ymm10, ymm7) vshufps(imm(0xe4), ymm8, ymm10, ymm10) vshufps(imm(0xe4), ymm7, ymm8, ymm8) - + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab20 ab22 ab24 ab26 // ab30 ab32 ab34 ab36 // ab44 ab46 ab40 ab42 - // ab54 ab56 ab50 ab52 + // ab54 ab56 ab50 ab52 // ab64 ab66 ab60 ab62 // ab74 ) ab76 ) ab70 ) ab72 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab21 ab23 ab25 ab27 // ab31 ab33 ab35 ab37 // ab45 ab47 ab41 ab43 - // ab55 ab57 ab51 ab53 + // ab55 ab57 ab51 ab53 // ab65 ab67 ab61 ab63 // ab75 ) ab77 ) ab71 ) ab73 ) - + vmovaps(ymm15, ymm7) vperm2f128(imm(0x30), ymm11, ymm15, ymm15) vperm2f128(imm(0x12), ymm11, ymm7, ymm11) - + vmovaps(ymm13, ymm7) vperm2f128(imm(0x30), ymm9, ymm13, ymm13) vperm2f128(imm(0x12), ymm9, ymm7, ymm9) - + vmovaps(ymm14, ymm7) vperm2f128(imm(0x30), ymm10, ymm14, ymm14) vperm2f128(imm(0x12), ymm10, ymm7, ymm10) - + vmovaps(ymm12, ymm7) vperm2f128(imm(0x30), ymm8, ymm12, ymm12) vperm2f128(imm(0x12), ymm8, ymm7, ymm8) - + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab20 ab22 ab24 ab26 // ab30 ab32 ab34 ab36 // ab40 ab42 ab44 ab46 - // ab50 ab52 ab54 ab56 + // ab50 ab52 ab54 ab56 // ab60 ab62 ab64 ab66 // ab70 ) ab72 ) ab74 ) ab76 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab21 ab23 ab25 ab27 // ab31 ab33 ab35 ab37 // ab41 ab43 ab45 ab47 - // ab51 ab53 ab55 ab57 + // ab51 ab53 ab55 ab57 // ab61 ab63 ab65 ab67 // ab71 ) ab73 ) ab75 ) ab77 ) - - - + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), ymm0) // load alpha and duplicate vbroadcastss(mem(rbx), ymm4) // load beta and duplicate - + vmulps(ymm0, ymm8, ymm8) // scale by alpha vmulps(ymm0, ymm9, ymm9) vmulps(ymm0, ymm10, ymm10) @@ -405,618 +409,118 @@ void bli_sgemm_sandybridge_asm_8x8 vmulps(ymm0, ymm13, ymm13) vmulps(ymm0, ymm14, ymm14) vmulps(ymm0, ymm15, ymm15) - - - - - - + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm4) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. - jz(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - // update c00:c70 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm15, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c01:c71 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm14, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c02:c72 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm13, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c03:c73 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm12, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c04:c74 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm11, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c05:c75 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm10, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c06:c76 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm9, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c07:c77 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm8, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORED) - - - vmovups(mem(rcx), ymm0) // load c00:c70, - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm15, ymm0, ymm0) // add the gemm result, - vmovups(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm1) // load c01:c71, - vmulps(ymm4, ymm1, ymm1) // scale by beta, - vaddps(ymm14, ymm1, ymm1) // add the gemm result, - vmovups(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm0) // load c02:c72, - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm13, ymm0, ymm0) // add the gemm result, - vmovups(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm1) // load c03:c73, - vmulps(ymm4, ymm1, ymm1) // scale by beta, - vaddps(ymm12, ymm1, ymm1) // add the gemm result, - vmovups(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm0) // load c04:c74, - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm11, ymm0, ymm0) // add the gemm result, - vmovups(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm1) // load c05:c75, - vmulps(ymm4, ymm1, ymm1) // scale by beta, - vaddps(ymm10, ymm1, ymm1) // add the gemm result, - vmovups(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm0) // load c06:c76, - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm9, ymm0, ymm0) // add the gemm result, - vmovups(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm1) // load c07:c77, - vmulps(ymm4, ymm1, ymm1) // scale by beta, - vaddps(ymm8, ymm1, ymm1) // add the gemm result, - vmovups(ymm1, mem(rcx)) // and store back to memory. - - - jmp(.SDONE) // jump to end. - - - - + + vmovups(mem(rcx), ymm0) // load c00:c70, + vmulps(ymm4, ymm0, ymm0) // scale by beta, + vaddps(ymm15, ymm0, ymm0) // add the gemm result, + vmovups(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm1) // load c01:c71, + vmulps(ymm4, ymm1, ymm1) // scale by beta, + vaddps(ymm14, ymm1, ymm1) // add the gemm result, + vmovups(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm0) // load c02:c72, + vmulps(ymm4, ymm0, ymm0) // scale by beta, + vaddps(ymm13, ymm0, ymm0) // add the gemm result, + vmovups(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm1) // load c03:c73, + vmulps(ymm4, ymm1, ymm1) // scale by beta, + vaddps(ymm12, ymm1, ymm1) // add the gemm result, + vmovups(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm0) // load c04:c74, + vmulps(ymm4, ymm0, ymm0) // scale by beta, + vaddps(ymm11, ymm0, ymm0) // add the gemm result, + vmovups(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm1) // load c05:c75, + vmulps(ymm4, ymm1, ymm1) // scale by beta, + vaddps(ymm10, ymm1, ymm1) // add the gemm result, + vmovups(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm0) // load c06:c76, + vmulps(ymm4, ymm0, ymm0) // scale by beta, + vaddps(ymm9, ymm0, ymm0) // add the gemm result, + vmovups(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm1) // load c07:c77, + vmulps(ymm4, ymm1, ymm1) // scale by beta, + vaddps(ymm8, ymm1, ymm1) // add the gemm result, + vmovups(ymm1, mem(rcx)) // and store back to memory. + + jmp(.SDONE) // jump to end. + label(.SBETAZERO) - - cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. - jz(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - // update c00:c70 - vmovups(ymm15, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c01:c71 - vmovups(ymm14, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c02:c72 - vmovups(ymm13, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c03:c73 - vmovups(ymm12, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c04:c74 - vmovups(ymm11, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c05:c75 - vmovups(ymm10, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c06:c76 - vmovups(ymm9, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c07:c77 - vmovups(ymm8, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORBZ) - - - vmovups(ymm15, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm14, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm13, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm12, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm11, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm8, mem(rcx)) // and store back to memory. - - - - - + + vmovups(ymm15, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm14, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm13, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm12, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm11, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm8, mem(rcx)) // and store back to memory. + label(.SDONE) - + vzeroupper() - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1024,11 +528,15 @@ void bli_sgemm_sandybridge_asm_8x8 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_sandybridge_asm_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -1043,34 +551,36 @@ void bli_dgemm_sandybridge_asm_8x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( d, 8, 4, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. sub(imm(4*64), r15) - + vmovapd(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovapd(mem(rbx, 0*32), ymm2) // elements of a and b. vpermilpd(imm(0x5), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorpd(ymm8, ymm8, ymm8) vxorpd(ymm9, ymm9, ymm9) vxorpd(ymm10, ymm10, ymm10) @@ -1079,19 +589,19 @@ void bli_dgemm_sandybridge_asm_8x4 vxorpd(ymm13, ymm13, ymm13) vxorpd(ymm14, ymm14, ymm14) vxorpd(ymm15, ymm15, ymm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - + add(imm(4*4*8), r15) // b_next += 4*4 (unroll x nr) - + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -1100,7 +610,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + prefetch(0, mem(rax, 16*32)) vmulpd(ymm1, ymm2, ymm6) vmovapd(mem(rbx, 1*32), ymm2) @@ -1108,20 +618,20 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) prefetch(0, mem(r15, 0*32)) // prefetch b_next[0*4] - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - + + // iteration 1 vmovapd(mem(rax, 3*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -1130,7 +640,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + prefetch(0, mem(rax, 18*32)) vmulpd(ymm1, ymm2, ymm6) vmovapd(mem(rbx, 2*32), ymm2) @@ -1138,19 +648,19 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 4*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - + + // iteration 2 vmovapd(mem(rax, 5*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -1159,7 +669,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + prefetch(0, mem(rax, 20*32)) vmulpd(ymm1, ymm2, ymm6) vmovapd(mem(rbx, 3*32), ymm2) @@ -1168,20 +678,20 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 6*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) prefetch(0, mem(r15, 2*32)) // prefetch b_next[2*4] - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - + + // iteration 3 vmovapd(mem(rax, 7*32), ymm1) add(imm(4*8*8), rax) // a += 4*8 (unroll x mr) @@ -1191,7 +701,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + //prefetch(0, mem(rax, 22*32)) prefetch(0, mem(rax, 14*32)) vmulpd(ymm1, ymm2, ymm6) @@ -1200,41 +710,41 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 0*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - - + + + //add(imm(4*8*8), rax) // a += 4*8 (unroll x mr) //add(imm(4*4*8), rbx) // b += 4*4 (unroll x nr) - + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + vmovapd(mem(rax, 1*32), ymm1) add(imm(8*1*8), rax) // a += 8 (1 x mr) vmulpd(ymm0, ymm2, ymm6) @@ -1243,7 +753,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + prefetch(0, mem(rax, 14*32)) vmulpd(ymm1, ymm2, ymm6) vmovapd(mem(rbx, 1*32), ymm2) @@ -1252,101 +762,101 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 0*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - + + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab11 ab10 ab13 ab12 + // ab11 ab10 ab13 ab12 // ab22 ab23 ab20 ab21 // ab33 ) ab32 ) ab31 ) ab30 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab51 ab50 ab53 ab52 + // ab51 ab50 ab53 ab52 // ab62 ab63 ab60 ab61 // ab73 ) ab72 ) ab71 ) ab70 ) - + vmovapd(ymm15, ymm7) vshufpd(imm(0xa), ymm15, ymm13, ymm15) vshufpd(imm(0xa), ymm13, ymm7, ymm13) - + vmovapd(ymm11, ymm7) vshufpd(imm(0xa), ymm11, ymm9, ymm11) vshufpd(imm(0xa), ymm9, ymm7, ymm9) - + vmovapd(ymm14, ymm7) vshufpd(imm(0xa), ymm14, ymm12, ymm14) vshufpd(imm(0xa), ymm12, ymm7, ymm12) - + vmovapd(ymm10, ymm7) vshufpd(imm(0xa), ymm10, ymm8, ymm10) vshufpd(imm(0xa), ymm8, ymm7, ymm8) - + // ymm15: ymm13: ymm11: ymm9: // ( ab01 ( ab00 ( ab03 ( ab02 - // ab11 ab10 ab13 ab12 + // ab11 ab10 ab13 ab12 // ab23 ab22 ab21 ab20 // ab33 ) ab32 ) ab31 ) ab30 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab41 ( ab40 ( ab43 ( ab42 - // ab51 ab50 ab53 ab52 + // ab51 ab50 ab53 ab52 // ab63 ab62 ab61 ab60 // ab73 ) ab72 ) ab71 ) ab70 ) - + vmovapd(ymm15, ymm7) vperm2f128(imm(0x30), ymm15, ymm11, ymm15) vperm2f128(imm(0x12), ymm7, ymm11, ymm11) - + vmovapd(ymm13, ymm7) vperm2f128(imm(0x30), ymm13, ymm9, ymm13) vperm2f128(imm(0x12), ymm7, ymm9, ymm9) - + vmovapd(ymm14, ymm7) vperm2f128(imm(0x30), ymm14, ymm10, ymm14) vperm2f128(imm(0x12), ymm7, ymm10, ymm10) - + vmovapd(ymm12, ymm7) vperm2f128(imm(0x30), ymm12, ymm8, ymm12) vperm2f128(imm(0x12), ymm7, ymm8, ymm8) - + // ymm9: ymm11: ymm13: ymm15: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab20 ab21 ab22 ab23 // ab30 ) ab31 ) ab32 ) ab33 ) - + // ymm8: ymm10: ymm12: ymm14: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm2) // load beta and duplicate - + vmulpd(ymm0, ymm8, ymm8) // scale by alpha vmulpd(ymm0, ymm9, ymm9) vmulpd(ymm0, ymm10, ymm10) @@ -1355,343 +865,124 @@ void bli_dgemm_sandybridge_asm_8x4 vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(ymm0, ymm15, ymm15) - - - - - - + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm2) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.DCOLSTORED) // jump to column storage case - - - - label(.DGENSTORED) - // update c00:c33 - - vextractf128(imm(1), ymm9, xmm1) - vmovlpd(mem(rcx), xmm0, xmm0) // load c00 and c10, - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm9, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c20 and c30, - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm11, xmm1) - vmovlpd(mem(rcx), xmm0, xmm0) // load c01 and c11, - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm11, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c21 and c31, - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm13, xmm1) - vmovlpd(mem(rcx), xmm0, xmm0) // load c02 and c12, - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm13, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c22 and c32, - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm15, xmm1) - vmovlpd(mem(rcx), xmm0, xmm0) // load c03 and c13, - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm15, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c23 and c33, - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, r13, 1)) - - // update c40:c73 - - vextractf128(imm(1), ymm8, xmm1) - vmovlpd(mem(rdx), xmm0, xmm0) // load c40 and c50, - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm8, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c60 and c70, - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm10, xmm1) - vmovlpd(mem(rdx), xmm0, xmm0) // load c41 and c51, - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm10, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c61 and c71, - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm12, xmm1) - vmovlpd(mem(rdx), xmm0, xmm0) // load c42 and c52, - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm12, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c62 and c72, - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm14, xmm1) - vmovlpd(mem(rdx), xmm0, xmm0) // load c43 and c53, - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm14, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c63 and c73, - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, r13, 1)) - - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORED) - // update c00:c33 - - vmovupd(mem(rcx), ymm0) // load c00:c30, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm9, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovupd(mem(rcx), ymm0) // load c01:c31, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm11, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovupd(mem(rcx), ymm0) // load c02:c32, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm13, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovupd(mem(rcx), ymm0) // load c03:c33, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm15, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rcx)) // and store back to memory. - - // update c40:c73 - - vmovupd(mem(rdx), ymm0) // load c40:c70, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm8, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rdx)) // and store back to memory. - add(rdi, rdx) // c += cs_c; - - vmovupd(mem(rdx), ymm0) // load c41:c71, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm10, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rdx)) // and store back to memory. - add(rdi, rdx) // c += cs_c; - - vmovupd(mem(rdx), ymm0) // load c42:c72, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm12, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rdx)) // and store back to memory. - add(rdi, rdx) // c += cs_c; - - vmovupd(mem(rdx), ymm0) // load c43:c73, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm14, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rdx)) // and store back to memory. - - - jmp(.DDONE) // jump to end. - - - - + + // update c00:c33 + + vmovupd(mem(rcx), ymm0) // load c00:c30, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm9, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovupd(mem(rcx), ymm0) // load c01:c31, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm11, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovupd(mem(rcx), ymm0) // load c02:c32, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm13, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovupd(mem(rcx), ymm0) // load c03:c33, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm15, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rcx)) // and store back to memory. + + // update c40:c73 + + vmovupd(mem(rdx), ymm0) // load c40:c70, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm8, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rdx)) // and store back to memory. + add(rdi, rdx) // c += cs_c; + + vmovupd(mem(rdx), ymm0) // load c41:c71, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm10, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rdx)) // and store back to memory. + add(rdi, rdx) // c += cs_c; + + vmovupd(mem(rdx), ymm0) // load c42:c72, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm12, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rdx)) // and store back to memory. + add(rdi, rdx) // c += cs_c; + + vmovupd(mem(rdx), ymm0) // load c43:c73, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm14, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rdx)) // and store back to memory. + + jmp(.DDONE) // jump to end. + label(.DBETAZERO) - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - // update c00:c33 - - vextractf128(imm(1), ymm9, xmm1) - vmovlpd(xmm9, mem(rcx)) // store to c00:c30 - vmovhpd(xmm9, mem(rcx, rsi, 1)) - vmovlpd(xmm1, mem(rcx, r12, 1)) - vmovhpd(xmm1, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm11, xmm1) - vmovlpd(xmm11, mem(rcx)) // store to c01:c31 - vmovhpd(xmm11, mem(rcx, rsi, 1)) - vmovlpd(xmm1, mem(rcx, r12, 1)) - vmovhpd(xmm1, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm13, xmm1) - vmovlpd(xmm13, mem(rcx)) // store to c02:c32 - vmovhpd(xmm13, mem(rcx, rsi, 1)) - vmovlpd(xmm1, mem(rcx, r12, 1)) - vmovhpd(xmm1, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm15, xmm1) - vmovlpd(xmm15, mem(rcx)) // store to c03:c33 - vmovhpd(xmm15, mem(rcx, rsi, 1)) - vmovlpd(xmm1, mem(rcx, r12, 1)) - vmovhpd(xmm1, mem(rcx, r13, 1)) - - // update c40:c73 - - vextractf128(imm(1), ymm8, xmm1) - vmovlpd(xmm8, mem(rdx)) // store to c40:c70 - vmovhpd(xmm8, mem(rdx, rsi, 1)) - vmovlpd(xmm1, mem(rdx, r12, 1)) - vmovhpd(xmm1, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm10, xmm1) - vmovlpd(xmm10, mem(rdx)) // store to c41:c71 - vmovhpd(xmm10, mem(rdx, rsi, 1)) - vmovlpd(xmm1, mem(rdx, r12, 1)) - vmovhpd(xmm1, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm12, xmm1) - vmovlpd(xmm12, mem(rdx)) // store to c42:c72 - vmovhpd(xmm12, mem(rdx, rsi, 1)) - vmovlpd(xmm1, mem(rdx, r12, 1)) - vmovhpd(xmm1, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm14, xmm1) - vmovlpd(xmm14, mem(rdx)) // store to c43:c73 - vmovhpd(xmm14, mem(rdx, rsi, 1)) - vmovlpd(xmm1, mem(rdx, r12, 1)) - vmovhpd(xmm1, mem(rdx, r13, 1)) - - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORBZ) - // update c00:c33 - - vmovupd(ymm9, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm11, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm13, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm15, mem(rcx)) // store c03:c33 - - // update c40:c73 - - vmovupd(ymm8, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm10, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm12, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm14, mem(rdx)) // store c43:c73 - - - - - + + // update c00:c33 + + vmovupd(ymm9, mem(rcx)) // store c00:c30 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm11, mem(rcx)) // store c01:c31 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm13, mem(rcx)) // store c02:c32 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm15, mem(rcx)) // store c03:c33 + + // update c40:c73 + + vmovupd(ymm8, mem(rdx)) // store c40:c70 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm10, mem(rdx)) // store c41:c71 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm12, mem(rdx)) // store c42:c72 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm14, mem(rdx)) // store c43:c73 + label(.DDONE) - - vzeroupper() - + vzeroupper() - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next)/*, // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next)/*, // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1699,11 +990,15 @@ void bli_dgemm_sandybridge_asm_8x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } void bli_cgemm_sandybridge_asm_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1718,34 +1013,36 @@ void bli_cgemm_sandybridge_asm_8x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( c, 8, 4, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. sub(imm(4*64), r15) - + vmovaps(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovsldup(mem(rbx, 0*32), ymm2) vpermilps(imm(0x4e), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(scomplex) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorps(ymm8, ymm8, ymm8) vxorps(ymm9, ymm9, ymm9) vxorps(ymm10, ymm10, ymm10) @@ -1754,19 +1051,19 @@ void bli_cgemm_sandybridge_asm_8x4 vxorps(ymm13, ymm13, ymm13) vxorps(ymm14, ymm14, ymm14) vxorps(ymm15, ymm15, ymm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.CLOOPKITER) // MAIN LOOP - + add(imm(4*4*8), r15) // b_next += 4*4 (unroll x nr) - + // iteration 0 prefetch(0, mem(rax, 8*32)) vmovaps(mem(rax, 1*32), ymm1) @@ -1776,20 +1073,20 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 0*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) @@ -1797,32 +1094,32 @@ void bli_cgemm_sandybridge_asm_8x4 vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) prefetch(0, mem(r15, 0*32)) // prefetch b_next[0*4] - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 2*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 1 prefetch(0, mem(rax, 10*32)) vmovaps(mem(rax, 3*32), ymm1) @@ -1832,52 +1129,52 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 2*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 4*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 2 prefetch(0, mem(rax, 12*32)) vmovaps(mem(rax, 5*32), ymm1) @@ -1887,20 +1184,20 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 2*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) @@ -1908,32 +1205,32 @@ void bli_cgemm_sandybridge_asm_8x4 vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) prefetch(0, mem(r15, 2*32)) // prefetch b_next[2*4] - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 3*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 6*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 3 prefetch(0, mem(rax, 14*32)) vmovaps(mem(rax, 7*32), ymm1) @@ -1943,74 +1240,74 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 3*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 4*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 8*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + add(imm(8*4*8), rax) // a += 8*4 (unroll x mr) add(imm(4*4*8), rbx) // b += 4*4 (unroll x nr) - - + + dec(rsi) // i -= 1; jne(.CLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.CLOOPKLEFT) // EDGE LOOP - + // iteration 0 prefetch(0, mem(rax, 8*32)) vmovaps(mem(rax, 1*32), ymm1) @@ -2020,228 +1317,228 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 0*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 2*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + add(imm(8*1*8), rax) // a += 8 (1 x mr) add(imm(4*1*8), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.CLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.CPOSTACCUM) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab21 ab20 ab23 ab22 - // ab31 ab30 ab33 ab32 - // ab42 ab43 ab40 ab41 - // ab52 ab53 ab50 ab51 - // ab63 ab62 ab61 ab60 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab21 ab20 ab23 ab22 + // ab31 ab30 ab33 ab32 + // ab42 ab43 ab40 ab41 + // ab52 ab53 ab50 ab51 + // ab63 ab62 ab61 ab60 // ab73 ) ab72 ) ab71 ) ab70 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba1 aba0 aba3 aba2 - // abb1 abb0 abb3 abb2 - // abc2 abc3 abc0 abc1 - // abd2 abd3 abd0 abd1 - // abe3 abe2 abe1 abe0 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba1 aba0 aba3 aba2 + // abb1 abb0 abb3 abb2 + // abc2 abc3 abc0 abc1 + // abd2 abd3 abd0 abd1 + // abe3 abe2 abe1 abe0 // abf3 abf2 abf1 abf0 ) - + vmovaps(ymm15, ymm7) vshufps(imm(0xe4), ymm13, ymm15, ymm15) vshufps(imm(0xe4), ymm7, ymm13, ymm13) - + vmovaps(ymm11, ymm7) vshufps(imm(0xe4), ymm9, ymm11, ymm11) vshufps(imm(0xe4), ymm7, ymm9, ymm9) - + vmovaps(ymm14, ymm7) vshufps(imm(0xe4), ymm12, ymm14, ymm14) vshufps(imm(0xe4), ymm7, ymm12, ymm12) - + vmovaps(ymm10, ymm7) vshufps(imm(0xe4), ymm8, ymm10, ymm10) vshufps(imm(0xe4), ymm7, ymm8, ymm8) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab20 ab21 ab22 ab23 - // ab30 ab31 ab32 ab33 - // ab42 ab43 ab40 ab41 - // ab52 ab53 ab50 ab51 - // ab62 ab63 ab60 ab61 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab20 ab21 ab22 ab23 + // ab30 ab31 ab32 ab33 + // ab42 ab43 ab40 ab41 + // ab52 ab53 ab50 ab51 + // ab62 ab63 ab60 ab61 // ab72 ) ab73 ) ab70 ) ab71 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba0 aba1 aba2 aba3 - // abb0 abb1 abb2 abb3 - // abc2 abc3 abc0 abc1 - // abd2 abd3 abd0 abd1 - // abe2 abe3 abe0 abe1 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba0 aba1 aba2 aba3 + // abb0 abb1 abb2 abb3 + // abc2 abc3 abc0 abc1 + // abd2 abd3 abd0 abd1 + // abe2 abe3 abe0 abe1 // abf2 ) abf3 ) abf0 ) abf1 ) - + vmovaps(ymm15, ymm7) vperm2f128(imm(0x12), ymm15, ymm11, ymm15) vperm2f128(imm(0x30), ymm7, ymm11, ymm11) - + vmovaps(ymm13, ymm7) vperm2f128(imm(0x12), ymm13, ymm9, ymm13) vperm2f128(imm(0x30), ymm7, ymm9, ymm9) - + vmovaps(ymm14, ymm7) vperm2f128(imm(0x12), ymm14, ymm10, ymm14) vperm2f128(imm(0x30), ymm7, ymm10, ymm10) - + vmovaps(ymm12, ymm7) vperm2f128(imm(0x12), ymm12, ymm8, ymm12) vperm2f128(imm(0x30), ymm7, ymm8, ymm8) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab20 ab21 ab22 ab23 - // ab30 ab31 ab32 ab33 - // ab40 ab41 ab42 ab43 - // ab50 ab51 ab52 ab53 - // ab60 ab61 ab62 ab63 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab20 ab21 ab22 ab23 + // ab30 ab31 ab32 ab33 + // ab40 ab41 ab42 ab43 + // ab50 ab51 ab52 ab53 + // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba0 aba1 aba2 aba3 - // abb0 abb1 abb2 abb3 - // abc0 abc1 abc2 abc3 - // abd0 abd1 abd2 abd3 - // abe0 abe1 abe2 abe3 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba0 aba1 aba2 aba3 + // abb0 abb1 abb2 abb3 + // abc0 abc1 abc2 abc3 + // abd0 abd1 abd2 abd3 + // abe0 abe1 abe2 abe3 // abf0 ) abf1 ) abf2 ) abf3 ) - - - - + + + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vbroadcastss(mem(rax), ymm7) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), ymm6) // load alpha_i and duplicate - + vpermilps(imm(0xb1), ymm15, ymm3) vmulps(ymm7, ymm15, ymm15) vmulps(ymm6, ymm3, ymm3) vaddsubps(ymm3, ymm15, ymm15) - + vpermilps(imm(0xb1), ymm14, ymm2) vmulps(ymm7, ymm14, ymm14) vmulps(ymm6, ymm2, ymm2) vaddsubps(ymm2, ymm14, ymm14) - + vpermilps(imm(0xb1), ymm13, ymm1) vmulps(ymm7, ymm13, ymm13) vmulps(ymm6, ymm1, ymm1) vaddsubps(ymm1, ymm13, ymm13) - + vpermilps(imm(0xb1), ymm12, ymm0) vmulps(ymm7, ymm12, ymm12) vmulps(ymm6, ymm0, ymm0) vaddsubps(ymm0, ymm12, ymm12) - + vpermilps(imm(0xb1), ymm11, ymm3) vmulps(ymm7, ymm11, ymm11) vmulps(ymm6, ymm3, ymm3) vaddsubps(ymm3, ymm11, ymm11) - + vpermilps(imm(0xb1), ymm10, ymm2) vmulps(ymm7, ymm10, ymm10) vmulps(ymm6, ymm2, ymm2) vaddsubps(ymm2, ymm10, ymm10) - + vpermilps(imm(0xb1), ymm9, ymm1) vmulps(ymm7, ymm9, ymm9) vmulps(ymm6, ymm1, ymm1) vaddsubps(ymm1, ymm9, ymm9) - + vpermilps(imm(0xb1), ymm8, ymm0) vmulps(ymm7, ymm8, ymm8) vmulps(ymm6, ymm0, ymm0) vaddsubps(ymm0, ymm8, ymm8) - - - - + + + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), ymm7) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), ymm6) // load beta_i and duplicate - - - - - - - + + + + + + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(scomplex) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm7) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2249,410 +1546,144 @@ void bli_cgemm_sandybridge_asm_8x4 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 0, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.CCOLSTORED) // jump to column storage case - - - - label(.CGENSTORED) - - // update c00:c70 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c00,10) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c20,30) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c40,50) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c60,70) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c00,c10) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c20,c30) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c40,c50) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c60,c70) - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c80,90) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca0,b0) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc0,d0) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce0,f0) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c80,c90) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca0,cb0) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc0,cd0) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce0,cf0) - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c01,11) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c21,31) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c41,51) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c61,71) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c01,c11) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c21,c31) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c41,c51) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c61,c71) - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c81,91) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca1,b1) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc1,d1) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce1,f1) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c81,c91) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca1,cb1) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc1,cd1) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce1,cf1) - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c02,12) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c22,32) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c42,52) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c62,72) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c02,c12) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c22,c32) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c42,c52) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c62,c72) - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c82,92) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca2,b2) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc2,d2) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce2,f2) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c82,c92) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca2,cb2) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc2,cd2) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce2,cf2) - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c03,13) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c23,33) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c43,53) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c63,73) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c03,c13) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c23,c33) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c43,c53) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c63,c73) - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c83,93) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca3,b3) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc3,d3) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce3,f3) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c83,c93) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca3,cb3) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc3,cd3) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce3,cf3) - add(rdi, rdx) // c += cs_c; - - - - jmp(.CDONE) // jump to end. - - - - label(.CCOLSTORED) - - // update c00:c70 - - vmovups(mem(rcx), ymm0) // load c00:c70 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rcx)) // store c00:c70 - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vmovups(mem(rdx), ymm0) // load c80:f0 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rdx)) // store c80:cf0 - add(rdi, rdx) // c += cs_c; - - // update c00:c70 - - vmovups(mem(rcx), ymm0) // load c01:c71 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rcx)) // store c01:c71 - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vmovups(mem(rdx), ymm0) // load c81:f1 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rdx)) // store c81:cf1 - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - - vmovups(mem(rcx), ymm0) // load c02:c72 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rcx)) // store c02:c72 - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - - vmovups(mem(rdx), ymm0) // load c82:f2 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rdx)) // store c82:cf2 - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - - vmovups(mem(rcx), ymm0) // load c03:c73 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rcx)) // store c03:c73 - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - - vmovups(mem(rdx), ymm0) // load c83:f3 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rdx)) // store c83:cf3 - add(rdi, rdx) // c += cs_c; - - - - jmp(.CDONE) // jump to end. - - - + + // update c00:c70 + + vmovups(mem(rcx), ymm0) // load c00:c70 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rcx)) // store c00:c70 + add(rdi, rcx) // c += cs_c; + + // update c80:cf0 + + vmovups(mem(rdx), ymm0) // load c80:f0 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rdx)) // store c80:cf0 + add(rdi, rdx) // c += cs_c; + + // update c00:c70 + + vmovups(mem(rcx), ymm0) // load c01:c71 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rcx)) // store c01:c71 + add(rdi, rcx) // c += cs_c; + + // update c81:cf1 + + vmovups(mem(rdx), ymm0) // load c81:f1 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rdx)) // store c81:cf1 + add(rdi, rdx) // c += cs_c; + + // update c02:c72 + + vmovups(mem(rcx), ymm0) // load c02:c72 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rcx)) // store c02:c72 + add(rdi, rcx) // c += cs_c; + + // update c82:cf2 + + vmovups(mem(rdx), ymm0) // load c82:f2 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rdx)) // store c82:cf2 + add(rdi, rdx) // c += cs_c; + + // update c03:c73 + + vmovups(mem(rcx), ymm0) // load c03:c73 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rcx)) // store c03:c73 + add(rdi, rcx) // c += cs_c; + + // update c83:cf3 + + vmovups(mem(rdx), ymm0) // load c83:f3 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rdx)) // store c83:cf3 + add(rdi, rdx) // c += cs_c; + + jmp(.CDONE) // jump to end. + label(.CBETAZERO) - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.CCOLSTORBZ) // jump to column storage case - - - - label(.CGENSTORBZ) - - // update c00:c70 - - vextractf128(imm(1), ymm15, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm15, mem(rcx)) // store (c00,c10) - vmovhpd(xmm15, mem(rcx, rsi, 1)) // store (c20,c30) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c40,c50) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c60,c70) - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vextractf128(imm(1), ymm14, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm14, mem(rdx)) // store (c80,c90) - vmovhpd(xmm14, mem(rdx, rsi, 1)) // store (ca0,cb0) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc0,cd0) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce0,cf0) - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - - vextractf128(imm(1), ymm13, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm13, mem(rcx)) // store (c01,c11) - vmovhpd(xmm13, mem(rcx, rsi, 1)) // store (c21,c31) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c41,c51) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c61,c71) - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vextractf128(imm(1), ymm12, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm12, mem(rdx)) // store (c81,c91) - vmovhpd(xmm12, mem(rdx, rsi, 1)) // store (ca1,cb1) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc1,cd1) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce1,cf1) - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - - vextractf128(imm(1), ymm11, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm11, mem(rcx)) // store (c02,c12) - vmovhpd(xmm11, mem(rcx, rsi, 1)) // store (c22,c32) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c42,c52) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c62,c72) - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - - vextractf128(imm(1), ymm10, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm10, mem(rdx)) // store (c82,c92) - vmovhpd(xmm10, mem(rdx, rsi, 1)) // store (ca2,cb2) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc2,cd2) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce2,cf2) - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - - vextractf128(imm(1), ymm9, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm9, mem(rcx)) // store (c03,c13) - vmovhpd(xmm9, mem(rcx, rsi, 1)) // store (c23,c33) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c43,c53) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c63,c73) - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - - vextractf128(imm(1), ymm8, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm8, mem(rdx)) // store (c83,c93) - vmovhpd(xmm8, mem(rdx, rsi, 1)) // store (ca3,cb3) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc3,cd3) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce3,cf3) - add(rdi, rdx) // c += cs_c; - - - - jmp(.CDONE) // jump to end. - - - - label(.CCOLSTORBZ) - - - vmovups(ymm15, mem(rcx)) // store c00:c70 - add(rdi, rcx) // c += cs_c; - - vmovups(ymm14, mem(rdx)) // store c80:cf0 - add(rdi, rdx) // c += cs_c; - - vmovups(ymm13, mem(rcx)) // store c01:c71 - add(rdi, rcx) // c += cs_c; - - vmovups(ymm12, mem(rdx)) // store c81:cf1 - add(rdi, rdx) // c += cs_c; - - vmovups(ymm11, mem(rcx)) // store c02:c72 - add(rdi, rcx) // c += cs_c; - - vmovups(ymm10, mem(rdx)) // store c82:cf2 - add(rdi, rdx) // c += cs_c; - - vmovups(ymm9, mem(rcx)) // store c03:c73 - add(rdi, rcx) // c += cs_c; - - vmovups(ymm8, mem(rdx)) // store c83:cf3 - add(rdi, rdx) // c += cs_c; - - - - - + + vmovups(ymm15, mem(rcx)) // store c00:c70 + add(rdi, rcx) // c += cs_c; + + vmovups(ymm14, mem(rdx)) // store c80:cf0 + add(rdi, rdx) // c += cs_c; + + vmovups(ymm13, mem(rcx)) // store c01:c71 + add(rdi, rcx) // c += cs_c; + + vmovups(ymm12, mem(rdx)) // store c81:cf1 + add(rdi, rdx) // c += cs_c; + + vmovups(ymm11, mem(rcx)) // store c02:c72 + add(rdi, rcx) // c += cs_c; + + vmovups(ymm10, mem(rdx)) // store c82:cf2 + add(rdi, rdx) // c += cs_c; + + vmovups(ymm9, mem(rcx)) // store c03:c73 + add(rdi, rcx) // c += cs_c; + + vmovups(ymm8, mem(rdx)) // store c83:cf3 + add(rdi, rdx) // c += cs_c; + label(.CDONE) - - vzeroupper() - + vzeroupper() - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next)/*, // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next)/*, // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2660,13 +1691,17 @@ void bli_cgemm_sandybridge_asm_8x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } void bli_zgemm_sandybridge_asm_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -2681,34 +1716,36 @@ void bli_zgemm_sandybridge_asm_4x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( z, 4, 4, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. - + vmovapd(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovddup(mem(rbx, 0+0*32), ymm2) vmovddup(mem(rbx, 0+1*32), ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorpd(ymm8, ymm8, ymm8) vxorpd(ymm9, ymm9, ymm9) vxorpd(ymm10, ymm10, ymm10) @@ -2717,18 +1754,18 @@ void bli_zgemm_sandybridge_asm_4x4 vxorpd(ymm13, ymm13, ymm13) vxorpd(ymm14, ymm14, ymm14) vxorpd(ymm15, ymm15, ymm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.ZLOOPKITER) // MAIN LOOP - - + + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2737,7 +1774,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 16*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+0*32), ymm2) @@ -2745,45 +1782,45 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+1*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+2*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+3*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + // iteration 1 vmovapd(mem(rax, 3*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2792,7 +1829,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 18*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+2*32), ymm2) @@ -2800,45 +1837,45 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+3*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+4*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+5*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 4*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + // iteration 2 vmovapd(mem(rax, 5*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2847,7 +1884,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 20*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+4*32), ymm2) @@ -2855,45 +1892,45 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+5*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+6*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+7*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 6*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + // iteration 3 vmovapd(mem(rax, 7*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2902,7 +1939,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 22*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+6*32), ymm2) @@ -2910,67 +1947,67 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+7*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+8*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+9*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 8*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + add(imm(4*4*16), rbx) // b += 4*4 (unroll x nr) add(imm(4*4*16), rax) // a += 4*4 (unroll x mr) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.ZLOOPKLEFT) // EDGE LOOP - + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2979,7 +2016,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 16*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+0*32), ymm2) @@ -2987,166 +2024,166 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+1*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+2*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+3*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + add(imm(4*1*16), rax) // a += 4 (1 x mr) add(imm(4*1*16), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.ZPOSTACCUM) - + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab21 ab20 ab23 ab22 // ab31 ) ab30 ) ab33 ) ab32 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab61 ab60 ab63 ab62 // ab71 ) ab70 ) ab73 ) ab72 ) - - + + vmovapd(ymm15, ymm7) vperm2f128(imm(0x12), ymm15, ymm13, ymm15) vperm2f128(imm(0x30), ymm7, ymm13, ymm13) - + vmovapd(ymm11, ymm7) vperm2f128(imm(0x12), ymm11, ymm9, ymm11) vperm2f128(imm(0x30), ymm7, ymm9, ymm9) - + vmovapd(ymm14, ymm7) vperm2f128(imm(0x12), ymm14, ymm12, ymm14) vperm2f128(imm(0x30), ymm7, ymm12, ymm12) - + vmovapd(ymm10, ymm7) vperm2f128(imm(0x12), ymm10, ymm8, ymm10) vperm2f128(imm(0x30), ymm7, ymm8, ymm8) - - + + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab20 ab21 ab22 ab23 // ab30 ) ab31 ) ab32 ) ab33 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - - + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vbroadcastsd(mem(rax), ymm7) // load alpha_r and duplicate vbroadcastsd(mem(rax, 8), ymm6) // load alpha_i and duplicate - + vpermilpd(imm(0x5), ymm15, ymm3) vmulpd(ymm7, ymm15, ymm15) vmulpd(ymm6, ymm3, ymm3) vaddsubpd(ymm3, ymm15, ymm15) - + vpermilpd(imm(0x5), ymm14, ymm2) vmulpd(ymm7, ymm14, ymm14) vmulpd(ymm6, ymm2, ymm2) vaddsubpd(ymm2, ymm14, ymm14) - + vpermilpd(imm(0x5), ymm13, ymm1) vmulpd(ymm7, ymm13, ymm13) vmulpd(ymm6, ymm1, ymm1) vaddsubpd(ymm1, ymm13, ymm13) - + vpermilpd(imm(0x5), ymm12, ymm0) vmulpd(ymm7, ymm12, ymm12) vmulpd(ymm6, ymm0, ymm0) vaddsubpd(ymm0, ymm12, ymm12) - + vpermilpd(imm(0x5), ymm11, ymm3) vmulpd(ymm7, ymm11, ymm11) vmulpd(ymm6, ymm3, ymm3) vaddsubpd(ymm3, ymm11, ymm11) - + vpermilpd(imm(0x5), ymm10, ymm2) vmulpd(ymm7, ymm10, ymm10) vmulpd(ymm6, ymm2, ymm2) vaddsubpd(ymm2, ymm10, ymm10) - + vpermilpd(imm(0x5), ymm9, ymm1) vmulpd(ymm7, ymm9, ymm9) vmulpd(ymm6, ymm1, ymm1) vaddsubpd(ymm1, ymm9, ymm9) - + vpermilpd(imm(0x5), ymm8, ymm0) vmulpd(ymm7, ymm8, ymm8) vmulpd(ymm6, ymm0, ymm0) vaddsubpd(ymm0, ymm8, ymm8) - - - - + + + + mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rbx), ymm7) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm6) // load beta_i and duplicate - - - - - - - + + + + + + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(dcomplex) lea(mem(, rsi, 2), rsi) lea(mem(rcx, rsi, 2), rdx) // load address of c + 2*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm7) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -3154,355 +2191,142 @@ void bli_zgemm_sandybridge_asm_4x4 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 0, jump to beta == 0 case - - - cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16. - jz(.ZCOLSTORED) // jump to column storage case - - - - label(.ZGENSTORED) - // update c00:c30 - - vmovupd(mem(rcx), xmm0) // load (c00,c10) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c20,c30) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c00,c10) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c20,c30) - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vmovupd(mem(rdx), xmm0) // load (c40,c50) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c60,c70) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c40,c50) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c60,c70) - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vmovupd(mem(rcx), xmm0) // load (c01,c11) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c21,c31) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c01,c11) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c21,c31) - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vmovupd(mem(rdx), xmm0) // load (c41,c51) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c61,c71) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c41,c51) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c61,c71) - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vmovupd(mem(rcx), xmm0) // load (c02,c12) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c22,c32) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c02,c12) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c22,c32) - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vmovupd(mem(rdx), xmm0) // load (c42,c52) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c62,c72) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c42,c52) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c62,c72) - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vmovupd(mem(rcx), xmm0) // load (c03,c13) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c23,c33) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c03,c13) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c23,c33) - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vmovupd(mem(rdx), xmm0) // load (c43,c53) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c63,c73) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c43,c53) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c63,c73) - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORED) - // update c00:c30 - - vmovupd(mem(rcx), ymm0) // load c00:c30 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vmovupd(mem(rdx), ymm0) // load c40:c70 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vmovupd(mem(rcx), ymm0) // load c01:c31 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vmovupd(mem(rdx), ymm0) // load c41:c71 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vmovupd(mem(rcx), ymm0) // load c02:c32 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vmovupd(mem(rdx), ymm0) // load c42:c72 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vmovupd(mem(rcx), ymm0) // load c03:c33 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rcx)) // store c03:c33 - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vmovupd(mem(rdx), ymm0) // load c43:c73 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rdx)) // store c43:c73 - - - - jmp(.ZDONE) // jump to end. - - - + + // update c00:c30 + + vmovupd(mem(rcx), ymm0) // load c00:c30 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rcx)) // store c00:c30 + add(rdi, rcx) // c += cs_c; + + // update c40:c70 + + vmovupd(mem(rdx), ymm0) // load c40:c70 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rdx)) // store c40:c70 + add(rdi, rdx) // c += cs_c; + + // update c01:c31 + + vmovupd(mem(rcx), ymm0) // load c01:c31 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rcx)) // store c01:c31 + add(rdi, rcx) // c += cs_c; + + // update c41:c71 + + vmovupd(mem(rdx), ymm0) // load c41:c71 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rdx)) // store c41:c71 + add(rdi, rdx) // c += cs_c; + + // update c02:c32 + + vmovupd(mem(rcx), ymm0) // load c02:c32 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rcx)) // store c02:c32 + add(rdi, rcx) // c += cs_c; + + // update c42:c72 + + vmovupd(mem(rdx), ymm0) // load c42:c72 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rdx)) // store c42:c72 + add(rdi, rdx) // c += cs_c; + + // update c03:c33 + + vmovupd(mem(rcx), ymm0) // load c03:c33 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rcx)) // store c03:c33 + add(rdi, rcx) // c += cs_c; + + // update c43:c73 + + vmovupd(mem(rdx), ymm0) // load c43:c73 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rdx)) // store c43:c73 + + jmp(.ZDONE) // jump to end. + label(.ZBETAZERO) - - cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16. - jz(.ZCOLSTORBZ) // jump to column storage case - - - - label(.ZGENSTORBZ) - // update c00:c30 - - vextractf128(imm(1), ymm15, xmm2) - vmovupd(xmm15, mem(rcx)) // store (c00,c10) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c20,c30) - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vextractf128(imm(1), ymm14, xmm2) - vmovupd(xmm14, mem(rdx)) // store (c40,c50) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c60,c70) - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vextractf128(imm(1), ymm13, xmm2) - vmovupd(xmm13, mem(rcx)) // store (c01,c11) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c21,c31) - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vextractf128(imm(1), ymm12, xmm2) - vmovupd(xmm12, mem(rdx)) // store (c41,c51) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c61,c71) - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vextractf128(imm(1), ymm11, xmm2) - vmovupd(xmm11, mem(rcx)) // store (c02,c12) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c22,c32) - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vextractf128(imm(1), ymm10, xmm2) - vmovupd(xmm10, mem(rdx)) // store (c42,c52) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c62,c72) - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vextractf128(imm(1), ymm9, xmm2) - vmovupd(xmm9, mem(rcx)) // store (c03,c13) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c23,c33) - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vextractf128(imm(1), ymm8, xmm2) - vmovupd(xmm8, mem(rdx)) // store (c43,c53) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c63,c73) - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORBZ) - - - vmovupd(ymm15, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm14, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm13, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm12, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm11, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm10, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm9, mem(rcx)) // store c03:c33 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm8, mem(rdx)) // store c43:c73 - - - - - + + vmovupd(ymm15, mem(rcx)) // store c00:c30 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm14, mem(rdx)) // store c40:c70 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm13, mem(rcx)) // store c01:c31 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm12, mem(rdx)) // store c41:c71 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm11, mem(rcx)) // store c02:c32 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm10, mem(rdx)) // store c42:c72 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm9, mem(rcx)) // store c03:c33 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm8, mem(rdx)) // store c43:c73 + label(.ZDONE) - - vzeroupper() - + vzeroupper() - + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -3510,6 +2334,8 @@ void bli_zgemm_sandybridge_asm_4x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c b/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c index 6a1bb04f54..6bf991082b 100644 --- a/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c +++ b/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c @@ -32,14 +32,17 @@ */ -#include +#include +#include #include "blis.h" #if 0 void bli_sgemm_sandybridge_int_8x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -52,11 +55,11 @@ void bli_sgemm_sandybridge_int_8x8 } #endif - - void bli_dgemm_sandybridge_int_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -66,19 +69,22 @@ void bli_dgemm_sandybridge_int_8x4 cntx_t* restrict cntx ) { + //void* a_next = bli_auxinfo_next_a( data ); void* b_next = bli_auxinfo_next_b( data ); // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 2; - uint64_t k_left = k0 % 2; + uint64_t k_iter = k / 2; + uint64_t k_left = k % 2; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t i; - double *c00, *c01, *c02, *c03; - double *c40, *c41, *c42, *c43; + GEMM_UKR_SETUP_CT( d, 8, 4, false ); + + double *c00, *c01, *c02, *c03; + double *c40, *c41, *c42, *c43; // Quad registers. __m256d va0_3, va4_7; @@ -87,23 +93,20 @@ void bli_dgemm_sandybridge_int_8x4 __m256d vb; __m256d vB0; - __m256d va0_3b_0, va4_7b_0; - __m256d va0_3b_1, va4_7b_1; - __m256d va0_3b_2, va4_7b_2; - __m256d va0_3b_3, va4_7b_3; - - __m256d va0_3b0, va4_7b0; - __m256d va0_3b1, va4_7b1; - __m256d va0_3b2, va4_7b2; - __m256d va0_3b3, va4_7b3; + __m256d va0_3b_0, va4_7b_0; + __m256d va0_3b_1, va4_7b_1; + __m256d va0_3b_2, va4_7b_2; + __m256d va0_3b_3, va4_7b_3; + __m256d va0_3b0, va4_7b0; + __m256d va0_3b1, va4_7b1; + __m256d va0_3b2, va4_7b2; + __m256d va0_3b3, va4_7b3; - __m256d valpha, vbeta, vtmp; + __m256d valpha, vbeta, vtmp; __m256d vc0_3_0, vc0_3_1, vc0_3_2, vc0_3_3; __m256d vc4_7_0, vc4_7_1, vc4_7_2, vc4_7_3; - __m128d aa, bb; - __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(a) ); __asm__ volatile( "prefetcht2 0(%0) \n\t" : :"r"(b_next) ); __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(c) ); @@ -129,19 +132,19 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_3 = _mm256_setzero_pd(); // Load va0_3 - va0_3 = _mm256_load_pd( a ); + va0_3 = _mm256_load_pd( a ); // Load va4_7 - va4_7 = _mm256_load_pd( a + 4 ); + va4_7 = _mm256_load_pd( a + 4 ); - // Load vb (b0,b1,b2,b3) - vb0 = _mm256_load_pd( b ); + // Load vb (b0,b1,b2,b3) + vb0 = _mm256_load_pd( b ); for( i = 0; i < k_iter; ++i ) { __asm__ volatile( "prefetcht0 192(%0) \n\t" : :"r"(a) ); // Load va0_3 (Prefetch) - vA0_3 = _mm256_load_pd( a + 8 ); + vA0_3 = _mm256_load_pd( a + 8 ); // Iteration 0. vtmp = _mm256_mul_pd( va0_3, vb0 ); @@ -151,10 +154,10 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp ); // Load va4_7 (Prefetch) - vA4_7 = _mm256_load_pd( a + 12 ); + vA4_7 = _mm256_load_pd( a + 12 ); // Shuffle vb (b1,b0,b3,b2) - vb1 = _mm256_shuffle_pd( vb0, vb0, 0x5 ); + vb1 = _mm256_shuffle_pd( vb0, vb0, 0x5 ); vtmp = _mm256_mul_pd( va0_3, vb1 ); va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp ); @@ -163,10 +166,10 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp ); // Permute vb (b3,b2,b1,b0) - vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 ); + vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 ); // Load vb (b0,b1,b2,b3) (Prefetch) - vB0 = _mm256_load_pd( b + 4 ); + vB0 = _mm256_load_pd( b + 4 ); vtmp = _mm256_mul_pd( va0_3, vb2 ); va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp ); @@ -175,7 +178,7 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp ); // Shuffle vb (b3,b2,b1,b0) - vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 ); + vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 ); vtmp = _mm256_mul_pd( va0_3, vb3 ); va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp ); @@ -186,14 +189,14 @@ void bli_dgemm_sandybridge_int_8x4 // Iteration 1. __asm__ volatile( "prefetcht0 512(%0) \n\t" : :"r"(a) ); - + // Load va0_3 (Next iteration) - va0_3 = _mm256_load_pd( a + 16 ); + va0_3 = _mm256_load_pd( a + 16 ); vtmp = _mm256_mul_pd( vA0_3, vB0 ); va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp ); - vb1 = _mm256_shuffle_pd( vB0, vB0, 0x5 ); + vb1 = _mm256_shuffle_pd( vB0, vB0, 0x5 ); vtmp = _mm256_mul_pd( vA4_7, vB0 ); va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp ); @@ -202,9 +205,9 @@ void bli_dgemm_sandybridge_int_8x4 va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp ); // Load va4_7 (Next iteration) - va4_7 = _mm256_load_pd( a + 20 ); + va4_7 = _mm256_load_pd( a + 20 ); - vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 ); + vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 ); vtmp = _mm256_mul_pd( vA4_7, vb1 ); va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp ); @@ -212,13 +215,13 @@ void bli_dgemm_sandybridge_int_8x4 vtmp = _mm256_mul_pd( vA0_3, vb2 ); va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp ); - vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 ); + vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 ); vtmp = _mm256_mul_pd( vA4_7, vb2 ); va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp ); // Load vb0(Next iteration) - vb0 = _mm256_load_pd( b + 8 ); + vb0 = _mm256_load_pd( b + 8 ); vtmp = _mm256_mul_pd( vA0_3, vb3 ); va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp ); @@ -236,12 +239,12 @@ void bli_dgemm_sandybridge_int_8x4 // Iteration 0. // Load va0_3 - va0_3 = _mm256_load_pd( a ); + va0_3 = _mm256_load_pd( a ); // Load va4_7 - va4_7 = _mm256_load_pd( a + 4 ); + va4_7 = _mm256_load_pd( a + 4 ); - // Load vb (b0,b1,b2,b3) - vb = _mm256_load_pd( b ); + // Load vb (b0,b1,b2,b3) + vb = _mm256_load_pd( b ); vtmp = _mm256_mul_pd( va0_3, vb ); va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp ); @@ -250,7 +253,7 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp ); // Shuffle vb (b1,b0,b3,b2) - vb = _mm256_shuffle_pd( vb, vb, 0x5 ); + vb = _mm256_shuffle_pd( vb, vb, 0x5 ); vtmp = _mm256_mul_pd( va0_3, vb ); va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp ); @@ -259,7 +262,7 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp ); // Permute vb (b3,b2,b1,b0) - vb = _mm256_permute2f128_pd( vb, vb, 0x1 ); + vb = _mm256_permute2f128_pd( vb, vb, 0x1 ); vtmp = _mm256_mul_pd( va0_3, vb ); va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp ); @@ -268,7 +271,7 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp ); // Shuffle vb (b3,b2,b1,b0) - vb = _mm256_shuffle_pd( vb, vb, 0x5 ); + vb = _mm256_shuffle_pd( vb, vb, 0x5 ); vtmp = _mm256_mul_pd( va0_3, vb ); va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp ); @@ -309,131 +312,73 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b1 = _mm256_permute2f128_pd( vtmpa_4_7b_1, vtmpa_4_7b_3, 0x30 ); va4_7b2 = _mm256_permute2f128_pd( vtmpa_4_7b_3, vtmpa_4_7b_1, 0x30 ); - if( rs_c == 1 ) + __m128d vzero = _mm_setzero_pd( ); + + if( _mm_comieq_sd( _mm256_castpd256_pd128(vbeta), vzero ) ) { // Calculate address - c00 = ( c + 0*rs_c + 0*cs_c ); - // Load - //vc0_3_0 = _mm256_load_pd( c + 0*rs_c + 0*cs_c ); - vc0_3_0 = _mm256_load_pd( c00 ); + c00 = ( c + 0 + 0*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b0); - // Scale by beta - vc0_3_0 = _mm256_mul_pd( vbeta, vc0_3_0 ); - // Add gemm result - vc0_3_0 = _mm256_add_pd( vc0_3_0, vtmp ); // Store back to memory - _mm256_store_pd( c00, vc0_3_0 ); - + _mm256_store_pd( c00, vtmp ); + // Calculate address - c40 = ( c + 4*rs_c + 0*cs_c ); - // Load - //vc4_7_0 = _mm256_load_pd( c + 4*rs_c + 0*cs_c ); - vc4_7_0 = _mm256_load_pd( c40 ); + c40 = ( c + 4 + 0*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b0); - // Scale by beta - vc4_7_0 = _mm256_mul_pd( vbeta, vc4_7_0 ); - // Add gemm result - vc4_7_0 = _mm256_add_pd( vc4_7_0, vtmp ); // Store back to memory - _mm256_store_pd( c40, vc4_7_0 ); - + _mm256_store_pd( c40, vtmp ); + // Calculate address - c01 = ( c + 0*rs_c + 1*cs_c ); - // Load - //vc0_3_1 = _mm256_load_pd( c + 0*rs_c + 1*cs_c ); - vc0_3_1 = _mm256_load_pd( c01 ); + c01 = ( c + 0 + 1*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b1); - // Scale by beta - vc0_3_1 = _mm256_mul_pd( vbeta, vc0_3_1 ); - // Add gemm result - vc0_3_1 = _mm256_add_pd( vc0_3_1, vtmp ); // Store back to memory - _mm256_store_pd( c01, vc0_3_1 ); - + _mm256_store_pd( c01, vtmp ); + // Calculate address - c41 = ( c + 4*rs_c + 1*cs_c ); - // Load - //vc4_7_1 = _mm256_load_pd( c + 4*rs_c + 1*cs_c ); - vc4_7_1 = _mm256_load_pd( c41 ); + c41 = ( c + 4 + 1*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b1); - // Scale by beta - vc4_7_1 = _mm256_mul_pd( vbeta, vc4_7_1 ); - // Add gemm result - vc4_7_1 = _mm256_add_pd( vc4_7_1, vtmp ); // Store back to memory - _mm256_store_pd( c41, vc4_7_1 ); - + _mm256_store_pd( c41, vtmp ); + // Calculate address - c02 = ( c + 0*rs_c + 2*cs_c ); - // Load - //vc0_3_2 = _mm256_load_pd( c + 0*rs_c + 2*cs_c ); - vc0_3_2 = _mm256_load_pd( c02 ); + c02 = ( c + 0 + 2*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b2); - // Scale by beta - vc0_3_2 = _mm256_mul_pd( vbeta, vc0_3_2 ); - // Add gemm result - vc0_3_2 = _mm256_add_pd( vc0_3_2, vtmp ); // Store back to memory - _mm256_store_pd( c02, vc0_3_2 ); - + _mm256_store_pd( c02, vtmp ); + // Calculate address - c42 = ( c + 4*rs_c + 2*cs_c ); - // Load - //vc4_7_2 = _mm256_load_pd( c + 4*rs_c + 2*cs_c ); - vc4_7_2 = _mm256_load_pd( c42 ); + c42 = ( c + 4 + 2*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b2); - // Scale by beta - vc4_7_2 = _mm256_mul_pd( vbeta, vc4_7_2 ); - // Add gemm result - vc4_7_2 = _mm256_add_pd( vc4_7_2, vtmp ); // Store back to memory - _mm256_store_pd( c42, vc4_7_2 ); - + _mm256_store_pd( c42, vtmp ); + // Calculate address - c03 = ( c + 0*rs_c + 3*cs_c ); - // Load - //vc0_3_3 = _mm256_load_pd( c + 0*rs_c + 3*cs_c ); - vc0_3_3 = _mm256_load_pd( c03 ); + c03 = ( c + 0 + 3*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b3); - // Scale by beta - vc0_3_3 = _mm256_mul_pd( vbeta, vc0_3_3 ); - // Add gemm result - vc0_3_3 = _mm256_add_pd( vc0_3_3, vtmp ); // Store back to memory - _mm256_store_pd( c03, vc0_3_3 ); - + _mm256_store_pd( c03, vtmp ); + // Calculate address - c43 = ( c + 4*rs_c + 3*cs_c ); - // Load - //vc4_7_3 = _mm256_load_pd( c + 4*rs_c + 3*cs_c ); - vc4_7_3 = _mm256_load_pd( c43 ); + c43 = ( c + 4 + 3*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b3); - // Scale by beta - vc4_7_3 = _mm256_mul_pd( vbeta, vc4_7_3 ); - // Add gemm result - vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp ); // Store back to memory - _mm256_store_pd( c43, vc4_7_3 ); - + _mm256_store_pd( c43, vtmp ); } else { // Calculate address - c00 = ( c + 0*rs_c + 0*cs_c ); + c00 = ( c + 0 + 0*cs_c ); // Load - //vc0_3_0 = _mm256_load_pd( c + 0*rs_c + 0*cs_c ); - vc0_3_0 = _mm256_set_pd( *(c + 3*rs_c + 0*cs_c ), - *(c + 2*rs_c + 0*cs_c ), - *(c + 1*rs_c + 0*cs_c ), - *(c + 0*rs_c + 0*cs_c ) ); + //vc0_3_0 = _mm256_load_pd( c + 0 + 0*cs_c ); + vc0_3_0 = _mm256_load_pd( c00 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b0); // Scale by beta @@ -441,24 +386,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc0_3_0 = _mm256_add_pd( vc0_3_0, vtmp ); // Store back to memory - //_mm256_store_pd( c00, vc0_3_0 ); - - aa = _mm256_extractf128_pd( vc0_3_0, 0 ) ; - bb = _mm256_extractf128_pd( vc0_3_0, 1 ) ; - - _mm_storel_pd( c + 0*rs_c + 0*cs_c, aa ); - _mm_storeh_pd( c + 1*rs_c + 0*cs_c, aa ); - _mm_storel_pd( c + 2*rs_c + 0*cs_c, bb ); - _mm_storeh_pd( c + 3*rs_c + 0*cs_c, bb ); + _mm256_store_pd( c00, vc0_3_0 ); // Calculate address - c40 = ( c + 4*rs_c + 0*cs_c ); + c40 = ( c + 4 + 0*cs_c ); // Load - //vc4_7_0 = _mm256_load_pd( c + 4*rs_c + 0*cs_c ); - vc4_7_0 = _mm256_set_pd( *(c + 7*rs_c + 0*cs_c ), - *(c + 6*rs_c + 0*cs_c ), - *(c + 5*rs_c + 0*cs_c ), - *(c + 4*rs_c + 0*cs_c ) ); + //vc4_7_0 = _mm256_load_pd( c + 4 + 0*cs_c ); + vc4_7_0 = _mm256_load_pd( c40 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b0); // Scale by beta @@ -466,24 +400,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc4_7_0 = _mm256_add_pd( vc4_7_0, vtmp ); // Store back to memory - //_mm256_store_pd( c40, vc4_7_0 ); - - aa = _mm256_extractf128_pd( vc4_7_0, 0 ) ; - bb = _mm256_extractf128_pd( vc4_7_0, 1 ) ; - - _mm_storel_pd( c + 4*rs_c + 0*cs_c, aa ); - _mm_storeh_pd( c + 5*rs_c + 0*cs_c, aa ); - _mm_storel_pd( c + 6*rs_c + 0*cs_c, bb ); - _mm_storeh_pd( c + 7*rs_c + 0*cs_c, bb ); + _mm256_store_pd( c40, vc4_7_0 ); // Calculate address - c01 = ( c + 0*rs_c + 1*cs_c ); + c01 = ( c + 0 + 1*cs_c ); // Load - //vc0_3_1 = _mm256_load_pd( c + 0*rs_c + 1*cs_c ); - vc0_3_1 = _mm256_set_pd( *(c + 3*rs_c + 1*cs_c ), - *(c + 2*rs_c + 1*cs_c ), - *(c + 1*rs_c + 1*cs_c ), - *(c + 0*rs_c + 1*cs_c ) ); + //vc0_3_1 = _mm256_load_pd( c + 0 + 1*cs_c ); + vc0_3_1 = _mm256_load_pd( c01 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b1); // Scale by beta @@ -491,24 +414,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc0_3_1 = _mm256_add_pd( vc0_3_1, vtmp ); // Store back to memory - //_mm256_store_pd( c01, vc0_3_1 ); - - aa = _mm256_extractf128_pd( vc0_3_1, 0 ) ; - bb = _mm256_extractf128_pd( vc0_3_1, 1 ) ; - - _mm_storel_pd( c + 0*rs_c + 1*cs_c, aa ); - _mm_storeh_pd( c + 1*rs_c + 1*cs_c, aa ); - _mm_storel_pd( c + 2*rs_c + 1*cs_c, bb ); - _mm_storeh_pd( c + 3*rs_c + 1*cs_c, bb ); + _mm256_store_pd( c01, vc0_3_1 ); // Calculate address - c41 = ( c + 4*rs_c + 1*cs_c ); + c41 = ( c + 4 + 1*cs_c ); // Load - //vc4_7_1 = _mm256_load_pd( c + 4*rs_c + 1*cs_c ); - vc4_7_1 = _mm256_set_pd( *(c + 7*rs_c + 1*cs_c ), - *(c + 6*rs_c + 1*cs_c ), - *(c + 5*rs_c + 1*cs_c ), - *(c + 4*rs_c + 1*cs_c ) ); + //vc4_7_1 = _mm256_load_pd( c + 4 + 1*cs_c ); + vc4_7_1 = _mm256_load_pd( c41 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b1); // Scale by beta @@ -516,24 +428,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc4_7_1 = _mm256_add_pd( vc4_7_1, vtmp ); // Store back to memory - //_mm256_store_pd( c41, vc4_7_1 ); - - aa = _mm256_extractf128_pd( vc4_7_1, 0 ) ; - bb = _mm256_extractf128_pd( vc4_7_1, 1 ) ; - - _mm_storel_pd( c + 4*rs_c + 1*cs_c, aa ); - _mm_storeh_pd( c + 5*rs_c + 1*cs_c, aa ); - _mm_storel_pd( c + 6*rs_c + 1*cs_c, bb ); - _mm_storeh_pd( c + 7*rs_c + 1*cs_c, bb ); + _mm256_store_pd( c41, vc4_7_1 ); // Calculate address - c02 = ( c + 0*rs_c + 2*cs_c ); + c02 = ( c + 0 + 2*cs_c ); // Load - //vc0_3_2 = _mm256_load_pd( c + 0*rs_c + 2*cs_c ); - vc0_3_2 = _mm256_set_pd( *(c + 3*rs_c + 2*cs_c ), - *(c + 2*rs_c + 2*cs_c ), - *(c + 1*rs_c + 2*cs_c ), - *(c + 0*rs_c + 2*cs_c ) ); + //vc0_3_2 = _mm256_load_pd( c + 0 + 2*cs_c ); + vc0_3_2 = _mm256_load_pd( c02 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b2); // Scale by beta @@ -541,24 +442,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc0_3_2 = _mm256_add_pd( vc0_3_2, vtmp ); // Store back to memory - //_mm256_store_pd( c02, vc0_3_2 ); - - aa = _mm256_extractf128_pd( vc0_3_2, 0 ) ; - bb = _mm256_extractf128_pd( vc0_3_2, 1 ) ; - - _mm_storel_pd( c + 0*rs_c + 2*cs_c, aa ); - _mm_storeh_pd( c + 1*rs_c + 2*cs_c, aa ); - _mm_storel_pd( c + 2*rs_c + 2*cs_c, bb ); - _mm_storeh_pd( c + 3*rs_c + 2*cs_c, bb ); + _mm256_store_pd( c02, vc0_3_2 ); // Calculate address - c42 = ( c + 4*rs_c + 2*cs_c ); + c42 = ( c + 4 + 2*cs_c ); // Load - //vc4_7_2 = _mm256_load_pd( c + 4*rs_c + 2*cs_c ); - vc4_7_2 = _mm256_set_pd( *(c + 7*rs_c + 2*cs_c ), - *(c + 6*rs_c + 2*cs_c ), - *(c + 5*rs_c + 2*cs_c ), - *(c + 4*rs_c + 2*cs_c ) ); + //vc4_7_2 = _mm256_load_pd( c + 4 + 2*cs_c ); + vc4_7_2 = _mm256_load_pd( c42 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b2); // Scale by beta @@ -566,24 +456,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc4_7_2 = _mm256_add_pd( vc4_7_2, vtmp ); // Store back to memory - //_mm256_store_pd( c42, vc4_7_2 ); - - aa = _mm256_extractf128_pd( vc4_7_2, 0 ) ; - bb = _mm256_extractf128_pd( vc4_7_2, 1 ) ; - - _mm_storel_pd( c + 4*rs_c + 2*cs_c, aa ); - _mm_storeh_pd( c + 5*rs_c + 2*cs_c, aa ); - _mm_storel_pd( c + 6*rs_c + 2*cs_c, bb ); - _mm_storeh_pd( c + 7*rs_c + 2*cs_c, bb ); + _mm256_store_pd( c42, vc4_7_2 ); // Calculate address - c03 = ( c + 0*rs_c + 3*cs_c ); + c03 = ( c + 0 + 3*cs_c ); // Load - //vc0_3_3 = _mm256_load_pd( c + 0*rs_c + 3*cs_c ); - vc0_3_3 = _mm256_set_pd( *(c + 3*rs_c + 3*cs_c ), - *(c + 2*rs_c + 3*cs_c ), - *(c + 1*rs_c + 3*cs_c ), - *(c + 0*rs_c + 3*cs_c ) ); + //vc0_3_3 = _mm256_load_pd( c + 0 + 3*cs_c ); + vc0_3_3 = _mm256_load_pd( c03 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b3); // Scale by beta @@ -591,24 +470,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc0_3_3 = _mm256_add_pd( vc0_3_3, vtmp ); // Store back to memory - //_mm256_store_pd( c03, vc0_3_3 ); - - aa = _mm256_extractf128_pd( vc0_3_3, 0 ) ; - bb = _mm256_extractf128_pd( vc0_3_3, 1 ) ; - - _mm_storel_pd( c + 0*rs_c + 3*cs_c, aa ); - _mm_storeh_pd( c + 1*rs_c + 3*cs_c, aa ); - _mm_storel_pd( c + 2*rs_c + 3*cs_c, bb ); - _mm_storeh_pd( c + 3*rs_c + 3*cs_c, bb ); + _mm256_store_pd( c03, vc0_3_3 ); // Calculate address - c43 = ( c + 4*rs_c + 3*cs_c ); + c43 = ( c + 4 + 3*cs_c ); // Load - //vc4_7_3 = _mm256_load_pd( c + 4*rs_c + 3*cs_c ); - vc4_7_3 = _mm256_set_pd( *(c + 7*rs_c + 3*cs_c ), - *(c + 6*rs_c + 3*cs_c ), - *(c + 5*rs_c + 3*cs_c ), - *(c + 4*rs_c + 3*cs_c ) ); + //vc4_7_3 = _mm256_load_pd( c + 4 + 3*cs_c ); + vc4_7_3 = _mm256_load_pd( c43 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b3); // Scale by beta @@ -616,17 +484,10 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp ); // Store back to memory - //_mm256_store_pd( c43, vc4_7_3 ); - - aa = _mm256_extractf128_pd( vc4_7_3, 0 ) ; - bb = _mm256_extractf128_pd( vc4_7_3, 1 ) ; - - _mm_storel_pd( c + 4*rs_c + 3*cs_c, aa ); - _mm_storeh_pd( c + 5*rs_c + 3*cs_c, aa ); - _mm_storel_pd( c + 6*rs_c + 3*cs_c, bb ); - _mm_storeh_pd( c + 7*rs_c + 3*cs_c, bb ); + _mm256_store_pd( c43, vc4_7_3 ); } + GEMM_UKR_FLUSH_CT( d ); } @@ -634,7 +495,9 @@ void bli_dgemm_sandybridge_int_8x4 #if 0 void bli_cgemm_sandybridge_int_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -652,7 +515,9 @@ void bli_cgemm_sandybridge_int_8x4 #if 0 void bli_zgemm_sandybridge_int_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, diff --git a/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c b/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c index 3a20cd8618..9943a170be 100644 --- a/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c +++ b/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c @@ -287,24 +287,28 @@ static int64_t offsets[16] __attribute__((aligned(64))) = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15}; -void bli_dgemm_skx_asm_16x12_l2( - dim_t k_, - double* restrict alpha, - double* restrict a, - double* restrict b, - double* restrict beta, - double* restrict c, inc_t rs_c_, inc_t cs_c_, - auxinfo_t* data, - cntx_t* restrict cntx - ) +void bli_dgemm_skx_asm_16x12_l2 + ( + dim_t m, + dim_t n, + dim_t k_, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* data, + cntx_t* restrict cntx + ) { (void)data; (void)cntx; - const int64_t* offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_; - const int64_t cs_c = cs_c_; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( d, 16, 12, false ); BEGIN_ASM() @@ -464,62 +468,26 @@ void bli_dgemm_skx_asm_16x12_l2( MOV(RAX, VAR(cs_c)) LEA(RAX, MEM(,RAX,8)) - MOV(RBX, VAR(rs_c)) - LEA(RBX, MEM(,RBX,8)) - - // Check if C is column stride. If not, jump to the slow scattered update - CMP(RBX, IMM(1)) - JNE(SCATTEREDUPDATE) - - VCOMISD(XMM(1), XMM(7)) - JE(COLSTORBZ) - UPDATE_C( 8, 9,10,11) - UPDATE_C(12,13,14,15) - UPDATE_C(16,17,18,19) - UPDATE_C(20,21,22,23) - UPDATE_C(24,25,26,27) - UPDATE_C(28,29,30,31) + VCOMISD(XMM(1), XMM(7)) + JE(COLSTORBZ) - JMP(END) - LABEL(COLSTORBZ) - - UPDATE_C_BZ( 8, 9,10,11) - UPDATE_C_BZ(12,13,14,15) - UPDATE_C_BZ(16,17,18,19) - UPDATE_C_BZ(20,21,22,23) - UPDATE_C_BZ(24,25,26,27) - UPDATE_C_BZ(28,29,30,31) + UPDATE_C( 8, 9,10,11) + UPDATE_C(12,13,14,15) + UPDATE_C(16,17,18,19) + UPDATE_C(20,21,22,23) + UPDATE_C(24,25,26,27) + UPDATE_C(28,29,30,31) JMP(END) - LABEL(SCATTEREDUPDATE) - - MOV(RDI, VAR(offsetPtr)) - VMOVDQA64(ZMM(2), MEM(RDI,0*64)) - VMOVDQA64(ZMM(3), MEM(RDI,1*64)) - VPBROADCASTQ(ZMM(6), RBX) - VPMULLQ(ZMM(2), ZMM(6), ZMM(2)) - VPMULLQ(ZMM(3), ZMM(6), ZMM(3)) - - VCOMISD(XMM(1), XMM(7)) - JE(SCATTERBZ) - - UPDATE_C_ROW_SCATTERED( 8, 9,10,11) - UPDATE_C_ROW_SCATTERED(12,13,14,15) - UPDATE_C_ROW_SCATTERED(16,17,18,19) - UPDATE_C_ROW_SCATTERED(20,21,22,23) - UPDATE_C_ROW_SCATTERED(24,25,26,27) - UPDATE_C_ROW_SCATTERED(28,29,30,31) - - JMP(END) - LABEL(SCATTERBZ) - - UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11) - UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15) - UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19) - UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23) - UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27) - UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31) + LABEL(COLSTORBZ) + + UPDATE_C_BZ( 8, 9,10,11) + UPDATE_C_BZ(12,13,14,15) + UPDATE_C_BZ(16,17,18,19) + UPDATE_C_BZ(20,21,22,23) + UPDATE_C_BZ(24,25,26,27) + UPDATE_C_BZ(28,29,30,31) LABEL(END) @@ -535,8 +503,7 @@ void bli_dgemm_skx_asm_16x12_l2( [beta] "m" (beta), [c] "m" (c), [rs_c] "m" (rs_c), - [cs_c] "m" (cs_c), - [offsetPtr] "m" (offsetPtr) + [cs_c] "m" (cs_c) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", @@ -545,4 +512,6 @@ void bli_dgemm_skx_asm_16x12_l2( "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c index 136f315323..e3bc52041d 100644 --- a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c +++ b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c @@ -153,24 +153,28 @@ static int64_t offsets[16] __attribute__((aligned(64))) = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15}; -void bli_dgemm_skx_asm_16x14( - dim_t k_, - double* restrict alpha, - double* restrict a, - double* restrict b, - double* restrict beta, - double* restrict c, inc_t rs_c_, inc_t cs_c_, - auxinfo_t* data, - cntx_t* restrict cntx - ) +void bli_dgemm_skx_asm_16x14 + ( + dim_t m, + dim_t n, + dim_t k_, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* data, + cntx_t* restrict cntx + ) { (void)data; (void)cntx; - const int64_t* offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_*8; - const int64_t cs_c = cs_c_*8; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( d, 16, 14, false ); BEGIN_ASM() @@ -220,6 +224,8 @@ void bli_dgemm_skx_asm_16x14( MOV(R12, VAR(rs_c)) MOV(R10, VAR(cs_c)) + LEA(R12, MEM(,R12,8)) + LEA(R10, MEM(,R10,8)) MOV(RDI, RSI) AND(RSI, IMM(3)) @@ -320,119 +326,41 @@ void bli_dgemm_skx_asm_16x14( MOV(RAX, R12) MOV(RBX, R10) - // Check if C is column stride. - CMP(RAX, IMM(8)) - JNE(SCATTEREDUPDATE) - - VCOMISD(XMM(1), XMM(2)) - JE(COLSTORBZ) - - UPDATE_C( 4, 5) - UPDATE_C( 6, 7) - UPDATE_C( 8, 9) - UPDATE_C(10,11) - UPDATE_C(12,13) - UPDATE_C(14,15) - UPDATE_C(16,17) - UPDATE_C(18,19) - UPDATE_C(20,21) - UPDATE_C(22,23) - UPDATE_C(24,25) - UPDATE_C(26,27) - UPDATE_C(28,29) - UPDATE_C(30,31) - - JMP(END) - LABEL(COLSTORBZ) - - UPDATE_C_BZ( 4, 5) - UPDATE_C_BZ( 6, 7) - UPDATE_C_BZ( 8, 9) - UPDATE_C_BZ(10,11) - UPDATE_C_BZ(12,13) - UPDATE_C_BZ(14,15) - UPDATE_C_BZ(16,17) - UPDATE_C_BZ(18,19) - UPDATE_C_BZ(20,21) - UPDATE_C_BZ(22,23) - UPDATE_C_BZ(24,25) - UPDATE_C_BZ(26,27) - UPDATE_C_BZ(28,29) - UPDATE_C_BZ(30,31) + VCOMISD(XMM(1), XMM(2)) + JE(COLSTORBZ) + + UPDATE_C( 4, 5) + UPDATE_C( 6, 7) + UPDATE_C( 8, 9) + UPDATE_C(10,11) + UPDATE_C(12,13) + UPDATE_C(14,15) + UPDATE_C(16,17) + UPDATE_C(18,19) + UPDATE_C(20,21) + UPDATE_C(22,23) + UPDATE_C(24,25) + UPDATE_C(26,27) + UPDATE_C(28,29) + UPDATE_C(30,31) JMP(END) - LABEL(SCATTEREDUPDATE) - - VMULPD(ZMM( 4), ZMM( 4), ZMM(0)) - VMULPD(ZMM( 5), ZMM( 5), ZMM(0)) - VMULPD(ZMM( 6), ZMM( 6), ZMM(0)) - VMULPD(ZMM( 7), ZMM( 7), ZMM(0)) - VMULPD(ZMM( 8), ZMM( 8), ZMM(0)) - VMULPD(ZMM( 9), ZMM( 9), ZMM(0)) - VMULPD(ZMM(10), ZMM(10), ZMM(0)) - VMULPD(ZMM(11), ZMM(11), ZMM(0)) - VMULPD(ZMM(12), ZMM(12), ZMM(0)) - VMULPD(ZMM(13), ZMM(13), ZMM(0)) - VMULPD(ZMM(14), ZMM(14), ZMM(0)) - VMULPD(ZMM(15), ZMM(15), ZMM(0)) - VMULPD(ZMM(16), ZMM(16), ZMM(0)) - VMULPD(ZMM(17), ZMM(17), ZMM(0)) - VMULPD(ZMM(18), ZMM(18), ZMM(0)) - VMULPD(ZMM(19), ZMM(19), ZMM(0)) - VMULPD(ZMM(20), ZMM(20), ZMM(0)) - VMULPD(ZMM(21), ZMM(21), ZMM(0)) - VMULPD(ZMM(22), ZMM(22), ZMM(0)) - VMULPD(ZMM(23), ZMM(23), ZMM(0)) - VMULPD(ZMM(24), ZMM(24), ZMM(0)) - VMULPD(ZMM(25), ZMM(25), ZMM(0)) - VMULPD(ZMM(26), ZMM(26), ZMM(0)) - VMULPD(ZMM(27), ZMM(27), ZMM(0)) - VMULPD(ZMM(28), ZMM(28), ZMM(0)) - VMULPD(ZMM(29), ZMM(29), ZMM(0)) - VMULPD(ZMM(30), ZMM(30), ZMM(0)) - VMULPD(ZMM(31), ZMM(31), ZMM(0)) - - VCOMISD(XMM(1), XMM(2)) - - MOV(RDI, VAR(offsetPtr)) - VPBROADCASTQ(ZMM(0), RAX) - VPMULLQ(ZMM(2), ZMM(0), MEM(RDI)) - VPMULLQ(ZMM(3), ZMM(0), MEM(RDI,64)) - - JE(SCATTERBZ) - - UPDATE_C_COL_SCATTERED( 4, 5) - UPDATE_C_COL_SCATTERED( 6, 7) - UPDATE_C_COL_SCATTERED( 8, 9) - UPDATE_C_COL_SCATTERED(10,11) - UPDATE_C_COL_SCATTERED(12,13) - UPDATE_C_COL_SCATTERED(14,15) - UPDATE_C_COL_SCATTERED(16,17) - UPDATE_C_COL_SCATTERED(18,19) - UPDATE_C_COL_SCATTERED(20,21) - UPDATE_C_COL_SCATTERED(22,23) - UPDATE_C_COL_SCATTERED(24,25) - UPDATE_C_COL_SCATTERED(26,27) - UPDATE_C_COL_SCATTERED(28,29) - UPDATE_C_COL_SCATTERED(30,31) - - JMP(END) - LABEL(SCATTERBZ) - - UPDATE_C_BZ_COL_SCATTERED( 4, 5) - UPDATE_C_BZ_COL_SCATTERED( 6, 7) - UPDATE_C_BZ_COL_SCATTERED( 8, 9) - UPDATE_C_BZ_COL_SCATTERED(10,11) - UPDATE_C_BZ_COL_SCATTERED(12,13) - UPDATE_C_BZ_COL_SCATTERED(14,15) - UPDATE_C_BZ_COL_SCATTERED(16,17) - UPDATE_C_BZ_COL_SCATTERED(18,19) - UPDATE_C_BZ_COL_SCATTERED(20,21) - UPDATE_C_BZ_COL_SCATTERED(22,23) - UPDATE_C_BZ_COL_SCATTERED(24,25) - UPDATE_C_BZ_COL_SCATTERED(26,27) - UPDATE_C_BZ_COL_SCATTERED(28,29) - UPDATE_C_BZ_COL_SCATTERED(30,31) + LABEL(COLSTORBZ) + + UPDATE_C_BZ( 4, 5) + UPDATE_C_BZ( 6, 7) + UPDATE_C_BZ( 8, 9) + UPDATE_C_BZ(10,11) + UPDATE_C_BZ(12,13) + UPDATE_C_BZ(14,15) + UPDATE_C_BZ(16,17) + UPDATE_C_BZ(18,19) + UPDATE_C_BZ(20,21) + UPDATE_C_BZ(22,23) + UPDATE_C_BZ(24,25) + UPDATE_C_BZ(26,27) + UPDATE_C_BZ(28,29) + UPDATE_C_BZ(30,31) LABEL(END) @@ -449,8 +377,7 @@ void bli_dgemm_skx_asm_16x14( [beta] "m" (beta), [c] "m" (c), [rs_c] "m" (rs_c), - [cs_c] "m" (cs_c), - [offsetPtr] "m" (offsetPtr) + [cs_c] "m" (cs_c) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", @@ -459,4 +386,6 @@ void bli_dgemm_skx_asm_16x14( "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c b/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c index 40af496140..8808449b65 100644 --- a/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c +++ b/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c @@ -317,24 +317,28 @@ ahead*/ static int64_t offsets[16] __attribute__((aligned(64))) = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15}; -void bli_sgemm_skx_asm_32x12_l2( - dim_t k_, - float* restrict alpha, - float* restrict a, - float* restrict b, - float* restrict beta, - float* restrict c, inc_t rs_c_, inc_t cs_c_, - auxinfo_t* data, - cntx_t* restrict cntx - ) +void bli_sgemm_skx_asm_32x12_l2 + ( + dim_t m, + dim_t n, + dim_t k_, + float* restrict alpha, + float* restrict a, + float* restrict b, + float* restrict beta, + float* restrict c, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* data, + cntx_t* restrict cntx + ) { (void)data; (void)cntx; - const int64_t* offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_; - const int64_t cs_c = cs_c_; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( s, 32, 12, false ); BEGIN_ASM() @@ -381,7 +385,7 @@ void bli_sgemm_skx_asm_32x12_l2( #endif #ifdef PREFETCH_B_BEFORE - /* Prefetching 3 cachlines of B (4 iterations worth of data + /* Prefetching 3 cachlines of B (4 iterations worth of data (12 (NR) x 4 (sizeof(float)) x 4 iter /64 = 3 cachelines) */ PREFETCH(0, MEM(RBX,0*64)) PREFETCH(0, MEM(RBX,1*64)) @@ -485,66 +489,26 @@ void bli_sgemm_skx_asm_32x12_l2( MOV(RAX, VAR(cs_c)) LEA(RAX, MEM(,RAX,4)) - MOV(RBX, VAR(rs_c)) - LEA(RBX, MEM(,RBX,4)) - - - // Check if C is column major (rs_c = 1). If not, jump to the slow scattered update - CMP(RBX, IMM(4)) - JNE(SCATTEREDUPDATE) - - VCOMISS(XMM(1), XMM(7)) - JE(COLSTORBZ) - UPDATE_C( 8, 9,10,11) - UPDATE_C(12,13,14,15) - UPDATE_C(16,17,18,19) - UPDATE_C(20,21,22,23) - UPDATE_C(24,25,26,27) - UPDATE_C(28,29,30,31) + VCOMISS(XMM(1), XMM(7)) + JE(COLSTORBZ) - JMP(END) - LABEL(COLSTORBZ) - - UPDATE_C_BZ( 8, 9,10,11) - UPDATE_C_BZ(12,13,14,15) - UPDATE_C_BZ(16,17,18,19) - UPDATE_C_BZ(20,21,22,23) - UPDATE_C_BZ(24,25,26,27) - UPDATE_C_BZ(28,29,30,31) + UPDATE_C( 8, 9,10,11) + UPDATE_C(12,13,14,15) + UPDATE_C(16,17,18,19) + UPDATE_C(20,21,22,23) + UPDATE_C(24,25,26,27) + UPDATE_C(28,29,30,31) JMP(END) - LABEL(SCATTEREDUPDATE) - - LEA(RDX, MEM(RCX,RBX,8)) - LEA(RDX, MEM(RDX,RBX,8)) - - MOV(RDI, VAR(offsetPtr)) - VMOVDQA64(ZMM(2), MEM(RDI,0*64)) - VMOVDQA64(ZMM(3), MEM(RDI,1*64)) - VPBROADCASTQ(ZMM(6), RBX) - VPMULLQ(ZMM(2), ZMM(6), ZMM(2)) - VPMULLQ(ZMM(3), ZMM(6), ZMM(3)) - - VCOMISS(XMM(1), XMM(7)) - JE(SCATTERBZ) - - UPDATE_C_ROW_SCATTERED( 8, 9,10,11) - UPDATE_C_ROW_SCATTERED(12,13,14,15) - UPDATE_C_ROW_SCATTERED(16,17,18,19) - UPDATE_C_ROW_SCATTERED(20,21,22,23) - UPDATE_C_ROW_SCATTERED(24,25,26,27) - UPDATE_C_ROW_SCATTERED(28,29,30,31) - - JMP(END) - LABEL(SCATTERBZ) - - UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11) - UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15) - UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19) - UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23) - UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27) - UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31) + LABEL(COLSTORBZ) + + UPDATE_C_BZ( 8, 9,10,11) + UPDATE_C_BZ(12,13,14,15) + UPDATE_C_BZ(16,17,18,19) + UPDATE_C_BZ(20,21,22,23) + UPDATE_C_BZ(24,25,26,27) + UPDATE_C_BZ(28,29,30,31) LABEL(END) @@ -560,8 +524,7 @@ void bli_sgemm_skx_asm_32x12_l2( [beta] "m" (beta), [c] "m" (c), [rs_c] "m" (rs_c), - [cs_c] "m" (cs_c), - [offsetPtr] "m" (offsetPtr) + [cs_c] "m" (cs_c) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", @@ -570,4 +533,6 @@ void bli_sgemm_skx_asm_32x12_l2( "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } diff --git a/kernels/zen/1/bli_amaxv_zen_int.c b/kernels/zen/1/bli_amaxv_zen_int.c index aa1aa0e661..4ece5af291 100644 --- a/kernels/zen/1/bli_amaxv_zen_int.c +++ b/kernels/zen/1/bli_amaxv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2018 - 2019, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -65,6 +65,38 @@ typedef union double d[2]; }v2dd_t; +// return a mask which indicates either: +// - v1 > v2 +// - v1 is NaN and v2 is not +// assumes that idx(v1) > idx(v2) +// all "OQ" comparisons false if either operand NaN +#define CMP256( dt, v1, v2 ) \ + _mm256_or_p##dt( _mm256_cmp_p##dt( v1, v2, _CMP_GT_OQ ), /* v1 > v2 || */ \ + _mm256_andnot_p##dt( _mm256_cmp_p##dt( v2, v2, _CMP_UNORD_Q ), /* ( !isnan(v2) && */ \ + _mm256_cmp_p##dt( v1, v1, _CMP_UNORD_Q ) /* isnan(v1) ) */ \ + ) \ + ); + +// return a mask which indicates either: +// - v1 > v2 +// - v1 is NaN and v2 is not +// - v1 == v2 (maybe == NaN) and i1 < i2 +// all "OQ" comparisons false if either operand NaN +#define CMP128( dt, v1, v2, i1, i2 ) \ + _mm_or_p##dt( _mm_or_p##dt( _mm_cmp_p##dt( v1, v2, _CMP_GT_OQ ), /* ( v1 > v2 || */ \ + _mm_andnot_p##dt( _mm_cmp_p##dt( v2, v2, _CMP_UNORD_Q ), /* ( !isnan(v2) && */ \ + _mm_cmp_p##dt( v1, v1, _CMP_UNORD_Q ) /* isnan(v1) ) ) || */ \ + ) \ + ), \ + _mm_and_p##dt( _mm_or_p##dt( _mm_cmp_p##dt( v1, v2, _CMP_EQ_OQ ), /* ( ( v1 == v2 || */ \ + _mm_and_p##dt( _mm_cmp_p##dt( v1, v1, _CMP_UNORD_Q ), /* ( isnan(v1) && */ \ + _mm_cmp_p##dt( v2, v2, _CMP_UNORD_Q ) /* isnan(v2) ) ) && */ \ + ) \ + ), \ + _mm_cmp_p##dt( i1, i2, _CMP_LT_OQ ) /* i1 < i2 ) */ \ + ) \ + ); + // ----------------------------------------------------------------------------- void bli_samaxv_zen_int @@ -122,8 +154,8 @@ void bli_samaxv_zen_int the previous largest, save it and its index. If NaN is encountered, then treat it the same as if it were a valid value that was smaller than any previously seen. This - behavior mimics that of LAPACK's ?lange(). */ - if ( abs_chi1_max < abs_chi1 || isnan( abs_chi1 ) ) + behavior mimics that of LAPACK's i?amax(). */ + if ( abs_chi1_max < abs_chi1 || ( isnan( abs_chi1 ) && !isnan( abs_chi1_max ) ) ) { abs_chi1_max = abs_chi1; i_max_l = i; @@ -157,7 +189,7 @@ void bli_samaxv_zen_int // Get the absolute value of the vector element. x_vec.v = _mm256_andnot_ps( sign_mask.v, x_vec.v ); - mask_vec.v = _mm256_cmp_ps( x_vec.v, max_vec.v, _CMP_GT_OS ); + mask_vec.v = CMP256( s, x_vec.v, max_vec.v ); max_vec.v = _mm256_blendv_ps( max_vec.v, x_vec.v, mask_vec.v ); maxInx_vec.v = _mm256_blendv_ps( maxInx_vec.v, idx_vec.v, mask_vec.v ); @@ -166,33 +198,34 @@ void bli_samaxv_zen_int x += num_vec_elements; } - max_vec_lo.v = _mm256_extractf128_ps( max_vec.v, 0 ); - max_vec_hi.v = _mm256_extractf128_ps( max_vec.v, 1 ); - mask_vec_lo.v = _mm_cmp_ps( max_vec_hi.v, max_vec_lo.v, _CMP_GT_OS ); - - max_vec_lo.v = _mm_blendv_ps( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v ); - + max_vec_lo.v = _mm256_extractf128_ps( max_vec.v, 0 ); + max_vec_hi.v = _mm256_extractf128_ps( max_vec.v, 1 ); maxInx_vec_lo.v = _mm256_extractf128_ps( maxInx_vec.v, 0 ); maxInx_vec_hi.v = _mm256_extractf128_ps( maxInx_vec.v, 1 ); + + mask_vec_lo.v = CMP128( s, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v ); + + max_vec_lo.v = _mm_blendv_ps( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v ); maxInx_vec_lo.v = _mm_blendv_ps( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v ); max_vec_hi.v = _mm_permute_ps( max_vec_lo.v, 14 ); maxInx_vec_hi.v = _mm_permute_ps( maxInx_vec_lo.v, 14 ); - mask_vec_lo.v = _mm_cmp_ps( max_vec_hi.v, max_vec_lo.v, _CMP_GT_OS ); + + mask_vec_lo.v = CMP128( s, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v ); max_vec_lo.v = _mm_blendv_ps( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v ); maxInx_vec_lo.v = _mm_blendv_ps( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v ); - if ( max_vec_lo.f[0] > max_vec_lo.f[1] ) - { - abs_chi1_max = max_vec_lo.f[0]; - i_max_l = maxInx_vec_lo.f[0]; - } - else - { - abs_chi1_max = max_vec_lo.f[1]; - i_max_l = maxInx_vec_lo.f[1]; - } + max_vec_hi.v = _mm_permute_ps( max_vec_lo.v, 1 ); + maxInx_vec_hi.v = _mm_permute_ps( maxInx_vec_lo.v, 1 ); + + mask_vec_lo.v = CMP128( s, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v ); + + max_vec_lo.v = _mm_blendv_ps( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v ); + maxInx_vec_lo.v = _mm_blendv_ps( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v ); + + abs_chi1_max = max_vec_lo.f[0]; + i_max_l = maxInx_vec_lo.f[0]; for ( i = n - n_left; i < n; i++ ) { @@ -208,8 +241,8 @@ void bli_samaxv_zen_int the previous largest, save it and its index. If NaN is encountered, then treat it the same as if it were a valid value that was smaller than any previously seen. This - behavior mimics that of LAPACK's ?lange(). */ - if ( abs_chi1_max < abs_chi1 || isnan( abs_chi1 ) ) + behavior mimics that of LAPACK's i?amax(). */ + if ( abs_chi1_max < abs_chi1 || ( isnan( abs_chi1 ) && !isnan( abs_chi1_max ) ) ) { abs_chi1_max = abs_chi1; i_max_l = i; @@ -219,6 +252,12 @@ void bli_samaxv_zen_int } } + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + /* Store final index to output variable. */ *i_max = i_max_l; } @@ -280,8 +319,8 @@ void bli_damaxv_zen_int the previous largest, save it and its index. If NaN is encountered, then treat it the same as if it were a valid value that was smaller than any previously seen. This - behavior mimics that of LAPACK's ?lange(). */ - if ( abs_chi1_max < abs_chi1 || isnan( abs_chi1 ) ) + behavior mimics that of LAPACK's i?amax(). */ + if ( abs_chi1_max < abs_chi1 || ( isnan( abs_chi1 ) && !isnan( abs_chi1_max ) ) ) { abs_chi1_max = abs_chi1; i_max_l = i; @@ -315,7 +354,7 @@ void bli_damaxv_zen_int // Get the absolute value of the vector element. x_vec.v = _mm256_andnot_pd( sign_mask.v, x_vec.v ); - mask_vec.v = _mm256_cmp_pd( x_vec.v, max_vec.v, _CMP_GT_OS ); + mask_vec.v = CMP256( d, x_vec.v, max_vec.v ); max_vec.v = _mm256_blendv_pd( max_vec.v, x_vec.v, mask_vec.v ); maxInx_vec.v = _mm256_blendv_pd( maxInx_vec.v, idx_vec.v, mask_vec.v ); @@ -324,26 +363,26 @@ void bli_damaxv_zen_int x += num_vec_elements; } - max_vec_lo.v = _mm256_extractf128_pd( max_vec.v, 0 ); - max_vec_hi.v = _mm256_extractf128_pd( max_vec.v, 1 ); - mask_vec_lo.v = _mm_cmp_pd( max_vec_hi.v, max_vec_lo.v, _CMP_GT_OS ); - - max_vec_lo.v = _mm_blendv_pd( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v ); - + max_vec_lo.v = _mm256_extractf128_pd( max_vec.v, 0 ); + max_vec_hi.v = _mm256_extractf128_pd( max_vec.v, 1 ); maxInx_vec_lo.v = _mm256_extractf128_pd( maxInx_vec.v, 0 ); maxInx_vec_hi.v = _mm256_extractf128_pd( maxInx_vec.v, 1 ); + + mask_vec_lo.v = CMP128( d, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v ); + + max_vec_lo.v = _mm_blendv_pd( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v ); maxInx_vec_lo.v = _mm_blendv_pd( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v ); + + max_vec_hi.v = _mm_permute_pd( max_vec_lo.v, 1 ); + maxInx_vec_hi.v = _mm_permute_pd( maxInx_vec_lo.v, 1 ); + + mask_vec_lo.v = CMP128( d, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v ); - if ( max_vec_lo.d[0] > max_vec_lo.d[1] ) - { - abs_chi1_max = max_vec_lo.d[0]; - i_max_l = maxInx_vec_lo.d[0]; - } - else - { - abs_chi1_max = max_vec_lo.d[1]; - i_max_l = maxInx_vec_lo.d[1]; - } + max_vec_lo.v = _mm_blendv_pd( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v ); + maxInx_vec_lo.v = _mm_blendv_pd( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v ); + + abs_chi1_max = max_vec_lo.d[0]; + i_max_l = maxInx_vec_lo.d[0]; for ( i = n - n_left; i < n; i++ ) { @@ -357,10 +396,9 @@ void bli_damaxv_zen_int /* If the absolute value of the current element exceeds that of the previous largest, save it and its index. If NaN is - encountered, then treat it the same as if it were a valid - value that was smaller than any previously seen. This - behavior mimics that of LAPACK's ?lange(). */ - if ( abs_chi1_max < abs_chi1 || isnan( abs_chi1 ) ) + encountered, return the index of the first NaN. This + behavior mimics that of LAPACK's i?amax(). */ + if ( abs_chi1_max < abs_chi1 || ( isnan( abs_chi1 ) && !isnan( abs_chi1_max ) ) ) { abs_chi1_max = abs_chi1; i_max_l = i; @@ -370,6 +408,12 @@ void bli_damaxv_zen_int } } + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + /* Store final index to output variable. */ *i_max = i_max_l; } diff --git a/kernels/zen/1/bli_axpyv_zen_int.c b/kernels/zen/1/bli_axpyv_zen_int.c index 42668a0a7e..686580b290 100644 --- a/kernels/zen/1/bli_axpyv_zen_int.c +++ b/kernels/zen/1/bli_axpyv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2019, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -136,6 +136,13 @@ void bli_saxpyv_zen_int y0 += n_elem_per_reg * n_iter_unroll; } + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + const float alphac = *alpha; // If there are leftover iterations, perform them with scalar code. @@ -233,6 +240,13 @@ void bli_daxpyv_zen_int y0 += n_elem_per_reg * n_iter_unroll; } + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + const double alphac = *alpha; // If there are leftover iterations, perform them with scalar code. diff --git a/kernels/zen/1/bli_axpyv_zen_int10.c b/kernels/zen/1/bli_axpyv_zen_int10.c index d2780d39c9..873b7da536 100644 --- a/kernels/zen/1/bli_axpyv_zen_int10.c +++ b/kernels/zen/1/bli_axpyv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2019, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -228,6 +228,13 @@ void bli_saxpyv_zen_int10 y0 += 1*n_elem_per_reg; } + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + for ( ; (i + 0) < n; i += 1 ) { *y0 += (*alpha) * (*x0); @@ -427,6 +434,13 @@ void bli_daxpyv_zen_int10 y0 += 1*n_elem_per_reg; } + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + for ( ; i < n; i += 1 ) { *y0 += (*alpha) * (*x0); diff --git a/kernels/zen/1/bli_copyv_zen_int.c b/kernels/zen/1/bli_copyv_zen_int.c new file mode 100644 index 0000000000..5fd2b15760 --- /dev/null +++ b/kernels/zen/1/bli_copyv_zen_int.c @@ -0,0 +1,330 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +// ----------------------------------------------------------------------------- + +void bli_scopyv_zen_int + ( + conj_t conjx, + dim_t n, + float* restrict x, inc_t incx, + float* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t num_elem_per_reg = 8; + dim_t i = 0; + __m256 xv[16]; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + if ( incx == 1 && incy == 1 ) + { +#if 0 + PRAGMA_SIMD + for (i = 0; i < n; i++) + { + y[i] = x[i]; + } +#endif +#if 0 + memcpy(y, x, n << 2); +#endif +#if 1 + + // For loop with n & ~0x7F => n & 0xFFFFFF80 masks the lower bits and results in multiples of 128 + // for example if n = 255 + // n & ~0x7F results in 128: copy from 0 to 128 happens in first loop + // n & ~0x3F results in 192: copy from 128 to 192 happens in second loop + // n & ~0x1F results in 224: copy from 128 to 192 happens in third loop and so on. + for ( i = 0; i < (n & (~0x7F)); i += 128 ) + { + xv[0] = _mm256_loadu_ps(x + num_elem_per_reg * 0); + xv[1] = _mm256_loadu_ps(x + num_elem_per_reg * 1); + xv[2] = _mm256_loadu_ps(x + num_elem_per_reg * 2); + xv[3] = _mm256_loadu_ps(x + num_elem_per_reg * 3); + xv[4] = _mm256_loadu_ps(x + num_elem_per_reg * 4); + xv[5] = _mm256_loadu_ps(x + num_elem_per_reg * 5); + xv[6] = _mm256_loadu_ps(x + num_elem_per_reg * 6); + xv[7] = _mm256_loadu_ps(x + num_elem_per_reg * 7); + xv[8] = _mm256_loadu_ps(x + num_elem_per_reg * 8); + xv[9] = _mm256_loadu_ps(x + num_elem_per_reg * 9); + xv[10] = _mm256_loadu_ps(x + num_elem_per_reg * 10); + xv[11] = _mm256_loadu_ps(x + num_elem_per_reg * 11); + xv[12] = _mm256_loadu_ps(x + num_elem_per_reg * 12); + xv[13] = _mm256_loadu_ps(x + num_elem_per_reg * 13); + xv[14] = _mm256_loadu_ps(x + num_elem_per_reg * 14); + xv[15] = _mm256_loadu_ps(x + num_elem_per_reg * 15); + + _mm256_storeu_ps(y + num_elem_per_reg * 0, xv[0]); + _mm256_storeu_ps(y + num_elem_per_reg * 1, xv[1]); + _mm256_storeu_ps(y + num_elem_per_reg * 2, xv[2]); + _mm256_storeu_ps(y + num_elem_per_reg * 3, xv[3]); + _mm256_storeu_ps(y + num_elem_per_reg * 4, xv[4]); + _mm256_storeu_ps(y + num_elem_per_reg * 5, xv[5]); + _mm256_storeu_ps(y + num_elem_per_reg * 6, xv[6]); + _mm256_storeu_ps(y + num_elem_per_reg * 7, xv[7]); + _mm256_storeu_ps(y + num_elem_per_reg * 8, xv[8]); + _mm256_storeu_ps(y + num_elem_per_reg * 9, xv[9]); + _mm256_storeu_ps(y + num_elem_per_reg * 10, xv[10]); + _mm256_storeu_ps(y + num_elem_per_reg * 11, xv[11]); + _mm256_storeu_ps(y + num_elem_per_reg * 12, xv[12]); + _mm256_storeu_ps(y + num_elem_per_reg * 13, xv[13]); + _mm256_storeu_ps(y + num_elem_per_reg * 14, xv[14]); + _mm256_storeu_ps(y + num_elem_per_reg * 15, xv[15]); + + y += 128; + x += 128; + } + for ( ; i < (n & (~0x3F)); i += 64 ) + { + xv[0] = _mm256_loadu_ps(x + num_elem_per_reg * 0); + xv[1] = _mm256_loadu_ps(x + num_elem_per_reg * 1); + xv[2] = _mm256_loadu_ps(x + num_elem_per_reg * 2); + xv[3] = _mm256_loadu_ps(x + num_elem_per_reg * 3); + xv[4] = _mm256_loadu_ps(x + num_elem_per_reg * 4); + xv[5] = _mm256_loadu_ps(x + num_elem_per_reg * 5); + xv[6] = _mm256_loadu_ps(x + num_elem_per_reg * 6); + xv[7] = _mm256_loadu_ps(x + num_elem_per_reg * 7); + + _mm256_storeu_ps(y + num_elem_per_reg * 0, xv[0]); + _mm256_storeu_ps(y + num_elem_per_reg * 1, xv[1]); + _mm256_storeu_ps(y + num_elem_per_reg * 2, xv[2]); + _mm256_storeu_ps(y + num_elem_per_reg * 3, xv[3]); + _mm256_storeu_ps(y + num_elem_per_reg * 4, xv[4]); + _mm256_storeu_ps(y + num_elem_per_reg * 5, xv[5]); + _mm256_storeu_ps(y + num_elem_per_reg * 6, xv[6]); + _mm256_storeu_ps(y + num_elem_per_reg * 7, xv[7]); + + y += 64; + x += 64; + } + for ( ; i < (n & (~0x1F)); i += 32 ) + { + xv[0] = _mm256_loadu_ps(x + num_elem_per_reg * 0); + xv[1] = _mm256_loadu_ps(x + num_elem_per_reg * 1); + xv[2] = _mm256_loadu_ps(x + num_elem_per_reg * 2); + xv[3] = _mm256_loadu_ps(x + num_elem_per_reg * 3); + + _mm256_storeu_ps(y + num_elem_per_reg * 0, xv[0]); + _mm256_storeu_ps(y + num_elem_per_reg * 1, xv[1]); + _mm256_storeu_ps(y + num_elem_per_reg * 2, xv[2]); + _mm256_storeu_ps(y + num_elem_per_reg * 3, xv[3]); + + y += 32; + x += 32; + } + for ( ; i < (n & (~0x0F)); i += 16 ) + { + xv[0] = _mm256_loadu_ps(x + num_elem_per_reg * 0); + xv[1] = _mm256_loadu_ps(x + num_elem_per_reg * 1); + + _mm256_storeu_ps(y + num_elem_per_reg * 0, xv[0]); + _mm256_storeu_ps(y + num_elem_per_reg * 1, xv[1]); + + y += 16; + x += 16; + } + for ( ; i < (n & (~0x07)); i += 8 ) + { + xv[0] = _mm256_loadu_ps(x + num_elem_per_reg * 0); + _mm256_storeu_ps(y + num_elem_per_reg * 0, xv[0]); + y += 8; + x += 8; + } + for ( ; i < n; ++i ) + { + *y++ = *x++; + } +#endif + } + else + { + for ( dim_t i = 0; i < n; ++i ) + { + *y = *x; + x += incx; + y += incy; + } + } +} + +// ----------------------------------------------------------------------------- + +void bli_dcopyv_zen_int + ( + conj_t conjx, + dim_t n, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t num_elem_per_reg = 4; + dim_t i = 0; + __m256d xv[16]; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + if ( incx == 1 && incy == 1 ) + { +#if 0 + PRAGMA_SIMD + for (i = 0; i < n; ++i) + { + y[i] = x[i]; + } +#endif +#if 0 + memcpy(y, x, n << 3); +#endif +#if 1 + // n & (~0x3F) = n & 0xFFFFFFC0 -> this masks the numbers less than 64, + // the copy operation will be done for the multiples of 64 + for ( i = 0; i < (n & (~0x3F)); i += 64 ) + { + xv[0] = _mm256_loadu_pd(x + num_elem_per_reg * 0); + xv[1] = _mm256_loadu_pd(x + num_elem_per_reg * 1); + xv[2] = _mm256_loadu_pd(x + num_elem_per_reg * 2); + xv[3] = _mm256_loadu_pd(x + num_elem_per_reg * 3); + xv[4] = _mm256_loadu_pd(x + num_elem_per_reg * 4); + xv[5] = _mm256_loadu_pd(x + num_elem_per_reg * 5); + xv[6] = _mm256_loadu_pd(x + num_elem_per_reg * 6); + xv[7] = _mm256_loadu_pd(x + num_elem_per_reg * 7); + xv[8] = _mm256_loadu_pd(x + num_elem_per_reg * 8); + xv[9] = _mm256_loadu_pd(x + num_elem_per_reg * 9); + xv[10] = _mm256_loadu_pd(x + num_elem_per_reg * 10); + xv[11] = _mm256_loadu_pd(x + num_elem_per_reg * 11); + xv[12] = _mm256_loadu_pd(x + num_elem_per_reg * 12); + xv[13] = _mm256_loadu_pd(x + num_elem_per_reg * 13); + xv[14] = _mm256_loadu_pd(x + num_elem_per_reg * 14); + xv[15] = _mm256_loadu_pd(x + num_elem_per_reg * 15); + _mm256_storeu_pd(y + num_elem_per_reg * 0, xv[0]); + _mm256_storeu_pd(y + num_elem_per_reg * 1, xv[1]); + _mm256_storeu_pd(y + num_elem_per_reg * 2, xv[2]); + _mm256_storeu_pd(y + num_elem_per_reg * 3, xv[3]); + _mm256_storeu_pd(y + num_elem_per_reg * 4, xv[4]); + _mm256_storeu_pd(y + num_elem_per_reg * 5, xv[5]); + _mm256_storeu_pd(y + num_elem_per_reg * 6, xv[6]); + _mm256_storeu_pd(y + num_elem_per_reg * 7, xv[7]); + _mm256_storeu_pd(y + num_elem_per_reg * 8, xv[8]); + _mm256_storeu_pd(y + num_elem_per_reg * 9, xv[9]); + _mm256_storeu_pd(y + num_elem_per_reg * 10, xv[10]); + _mm256_storeu_pd(y + num_elem_per_reg * 11, xv[11]); + _mm256_storeu_pd(y + num_elem_per_reg * 12, xv[12]); + _mm256_storeu_pd(y + num_elem_per_reg * 13, xv[13]); + _mm256_storeu_pd(y + num_elem_per_reg * 14, xv[14]); + _mm256_storeu_pd(y + num_elem_per_reg * 15, xv[15]); + y += num_elem_per_reg * 16; + x += num_elem_per_reg * 16; + } + for ( ; i < (n & (~0x1F)); i += 32 ) + { + xv[0] = _mm256_loadu_pd(x + num_elem_per_reg * 0); + xv[1] = _mm256_loadu_pd(x + num_elem_per_reg * 1); + xv[2] = _mm256_loadu_pd(x + num_elem_per_reg * 2); + xv[3] = _mm256_loadu_pd(x + num_elem_per_reg * 3); + xv[4] = _mm256_loadu_pd(x + num_elem_per_reg * 4); + xv[5] = _mm256_loadu_pd(x + num_elem_per_reg * 5); + xv[6] = _mm256_loadu_pd(x + num_elem_per_reg * 6); + xv[7] = _mm256_loadu_pd(x + num_elem_per_reg * 7); + + _mm256_storeu_pd(y + num_elem_per_reg * 0, xv[0]); + _mm256_storeu_pd(y + num_elem_per_reg * 1, xv[1]); + _mm256_storeu_pd(y + num_elem_per_reg * 2, xv[2]); + _mm256_storeu_pd(y + num_elem_per_reg * 3, xv[3]); + _mm256_storeu_pd(y + num_elem_per_reg * 4, xv[4]); + _mm256_storeu_pd(y + num_elem_per_reg * 5, xv[5]); + _mm256_storeu_pd(y + num_elem_per_reg * 6, xv[6]); + _mm256_storeu_pd(y + num_elem_per_reg * 7, xv[7]); + + y += num_elem_per_reg * 8; + x += num_elem_per_reg * 8; + } + for ( ; i < (n & (~0xF)); i += 16 ) + { + xv[0] = _mm256_loadu_pd(x + num_elem_per_reg * 0); + xv[1] = _mm256_loadu_pd(x + num_elem_per_reg * 1); + xv[2] = _mm256_loadu_pd(x + num_elem_per_reg * 2); + xv[3] = _mm256_loadu_pd(x + num_elem_per_reg * 3); + + _mm256_storeu_pd(y + num_elem_per_reg * 0, xv[0]); + _mm256_storeu_pd(y + num_elem_per_reg * 1, xv[1]); + _mm256_storeu_pd(y + num_elem_per_reg * 2, xv[2]); + _mm256_storeu_pd(y + num_elem_per_reg * 3, xv[3]); + + y += num_elem_per_reg * 4; + x += num_elem_per_reg * 4; + } + for ( ; i < (n & (~0x07)); i += 8 ) + { + xv[0] = _mm256_loadu_pd(x + num_elem_per_reg * 0); + xv[1] = _mm256_loadu_pd(x + num_elem_per_reg * 1); + + _mm256_storeu_pd(y + num_elem_per_reg * 0, xv[0]); + _mm256_storeu_pd(y + num_elem_per_reg * 1, xv[1]); + + y += num_elem_per_reg * 2; + x += num_elem_per_reg * 2; + } + for ( ; i < (n & (~0x03)); i += 4 ) + { + xv[0] = _mm256_loadu_pd(x + num_elem_per_reg * 0); + _mm256_storeu_pd(y + num_elem_per_reg * 0, xv[0]); + y += num_elem_per_reg; + x += num_elem_per_reg; + } + for ( ; i < n; ++i ) + { + *y++ = *x++; + } +#endif + } + else + { + for ( i = 0; i < n; ++i ) + { + *y = *x; + + x += incx; + y += incy; + } + } +} + diff --git a/kernels/zen/1/bli_dotv_zen_int.c b/kernels/zen/1/bli_dotv_zen_int.c index 1c87a0f87a..01022d353a 100644 --- a/kernels/zen/1/bli_dotv_zen_int.c +++ b/kernels/zen/1/bli_dotv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2019, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -151,6 +151,13 @@ void bli_sdotv_zen_int rho0 += rho0v.f[0] + rho0v.f[1] + rho0v.f[2] + rho0v.f[3] + rho0v.f[4] + rho0v.f[5] + rho0v.f[6] + rho0v.f[7]; + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + // If there are leftover iterations, perform them with scalar code. for ( i = 0; i < n_left; ++i ) { @@ -265,6 +272,13 @@ void bli_ddotv_zen_int // Accumulate the final rho vector into a single scalar result. rho0 += rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3]; + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + // If there are leftover iterations, perform them with scalar code. for ( i = 0; i < n_left; ++i ) { diff --git a/kernels/zen/1/bli_dotv_zen_int10.c b/kernels/zen/1/bli_dotv_zen_int10.c index 79fdde969c..8c445849b0 100644 --- a/kernels/zen/1/bli_dotv_zen_int10.c +++ b/kernels/zen/1/bli_dotv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2020, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -73,11 +73,11 @@ void bli_sdotv_zen_int10 float* restrict x0; float* restrict y0; - float rho0; + float rho0 = 0.0; __m256 xv[10]; __m256 yv[10]; - v8sf_t rhov[2]; + v8sf_t rhov[10]; // If the vector dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim1( n ) ) @@ -96,8 +96,16 @@ void bli_sdotv_zen_int10 { rhov[0].v = _mm256_setzero_ps(); rhov[1].v = _mm256_setzero_ps(); - - for ( i = 0; (i + 79) < n; i += 80 ) + rhov[2].v = _mm256_setzero_ps(); + rhov[3].v = _mm256_setzero_ps(); + rhov[4].v = _mm256_setzero_ps(); + rhov[5].v = _mm256_setzero_ps(); + rhov[6].v = _mm256_setzero_ps(); + rhov[7].v = _mm256_setzero_ps(); + rhov[8].v = _mm256_setzero_ps(); + rhov[9].v = _mm256_setzero_ps(); + + for ( i = 0 ; (i + 79) < n; i += 80 ) { // 80 elements will be processed per loop; 10 FMAs will run per loop. xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); @@ -124,19 +132,25 @@ void bli_sdotv_zen_int10 rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[5], yv[5], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[6], yv[6], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[7], yv[7], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[8], yv[8], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[9], yv[9], rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[4].v ); + rhov[5].v = _mm256_fmadd_ps( xv[5], yv[5], rhov[5].v ); + rhov[6].v = _mm256_fmadd_ps( xv[6], yv[6], rhov[6].v ); + rhov[7].v = _mm256_fmadd_ps( xv[7], yv[7], rhov[7].v ); + rhov[8].v = _mm256_fmadd_ps( xv[8], yv[8], rhov[8].v ); + rhov[9].v = _mm256_fmadd_ps( xv[9], yv[9], rhov[9].v ); x0 += 10*n_elem_per_reg; y0 += 10*n_elem_per_reg; } + rhov[0].v += rhov[5].v; + rhov[1].v += rhov[6].v; + rhov[2].v += rhov[7].v; + rhov[3].v += rhov[8].v; + rhov[4].v += rhov[9].v; + for ( ; (i + 39) < n; i += 40 ) { xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); @@ -153,34 +167,17 @@ void bli_sdotv_zen_int10 rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[0].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[4].v ); x0 += 5*n_elem_per_reg; y0 += 5*n_elem_per_reg; } - for ( ; (i + 31) < n; i += 32 ) - { - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - - yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); - - rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[1].v ); - - x0 += 4*n_elem_per_reg; - y0 += 4*n_elem_per_reg; - } + rhov[0].v += rhov[2].v; + rhov[1].v += rhov[3].v; + rhov[0].v += rhov[4].v; for ( ; (i + 15) < n; i += 16 ) { @@ -197,6 +194,8 @@ void bli_sdotv_zen_int10 y0 += 2*n_elem_per_reg; } + rhov[0].v += rhov[1].v; + for ( ; (i + 7) < n; i += 8 ) { xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); @@ -211,19 +210,21 @@ void bli_sdotv_zen_int10 for ( ; (i + 0) < n; i += 1 ) { - rhov[0].f[0] += x0[i] * y0[i]; + rho0 += (*x0) * (*y0); + x0 += 1; + y0 += 1; } - v8sf_t onev; - - onev.v = _mm256_set1_ps( 1.0f ); - - rhov[0].v = _mm256_dp_ps( rhov[0].v, onev.v, 0xf1 ); - rhov[1].v = _mm256_dp_ps( rhov[1].v, onev.v, 0xf1 ); + rho0 += rhov[0].f[0] + rhov[0].f[1] + + rhov[0].f[2] + rhov[0].f[3] + + rhov[0].f[4] + rhov[0].f[5] + + rhov[0].f[6] + rhov[0].f[7]; - // Manually add the results from above to finish the sum. - rho0 += rhov[0].f[0] + rhov[0].f[4]; - rho0 += rhov[1].f[0] + rhov[1].f[4]; + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); } else { @@ -263,11 +264,11 @@ void bli_ddotv_zen_int10 double* restrict x0; double* restrict y0; - double rho0; + double rho0 = 0.0; __m256d xv[10]; __m256d yv[10]; - v4df_t rhov[2]; + v4df_t rhov[10]; // If the vector dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim1( n ) ) @@ -286,6 +287,14 @@ void bli_ddotv_zen_int10 { rhov[0].v = _mm256_setzero_pd(); rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + rhov[8].v = _mm256_setzero_pd(); + rhov[9].v = _mm256_setzero_pd(); for ( i = 0; (i + 39) < n; i += 40 ) { @@ -314,19 +323,25 @@ void bli_ddotv_zen_int10 rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[5], yv[5], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[6], yv[6], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[7], yv[7], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[8], yv[8], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[9], yv[9], rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v ); + rhov[5].v = _mm256_fmadd_pd( xv[5], yv[5], rhov[5].v ); + rhov[6].v = _mm256_fmadd_pd( xv[6], yv[6], rhov[6].v ); + rhov[7].v = _mm256_fmadd_pd( xv[7], yv[7], rhov[7].v ); + rhov[8].v = _mm256_fmadd_pd( xv[8], yv[8], rhov[8].v ); + rhov[9].v = _mm256_fmadd_pd( xv[9], yv[9], rhov[9].v ); x0 += 10*n_elem_per_reg; y0 += 10*n_elem_per_reg; } + rhov[0].v += rhov[5].v; + rhov[1].v += rhov[6].v; + rhov[2].v += rhov[7].v; + rhov[3].v += rhov[8].v; + rhov[4].v += rhov[9].v; + for ( ; (i + 19) < n; i += 20 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); @@ -343,14 +358,16 @@ void bli_ddotv_zen_int10 rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[0].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v ); x0 += 5*n_elem_per_reg; y0 += 5*n_elem_per_reg; } + rhov[0].v += rhov[4].v; + for ( ; (i + 15) < n; i += 16 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); @@ -365,13 +382,16 @@ void bli_ddotv_zen_int10 rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v ); x0 += 4*n_elem_per_reg; y0 += 4*n_elem_per_reg; } + rhov[0].v += rhov[2].v; + rhov[1].v += rhov[3].v; + for ( ; (i + 7) < n; i += 8 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); @@ -387,6 +407,8 @@ void bli_ddotv_zen_int10 y0 += 2*n_elem_per_reg; } + rhov[0].v += rhov[1].v; + for ( ; (i + 3) < n; i += 4 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); @@ -401,12 +423,20 @@ void bli_ddotv_zen_int10 for ( ; (i + 0) < n; i += 1 ) { - rhov[0].d[0] += x0[i] * y0[i]; + rho0 += (*x0) * (*y0); + + x0 += 1; + y0 += 1; } // Manually add the results from above to finish the sum. - rho0 += rhov[0].d[0] + rhov[0].d[1] + rhov[0].d[2] + rhov[0].d[3]; - rho0 += rhov[1].d[0] + rhov[1].d[1] + rhov[1].d[2] + rhov[1].d[3]; + rho0 += rhov[0].d[0] + rhov[0].d[1] + rhov[0].d[2] + rhov[0].d[3]; + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); } else { diff --git a/kernels/zen/1/bli_dotxv_zen_int.c b/kernels/zen/1/bli_dotxv_zen_int.c index 53b582b773..99ea517104 100644 --- a/kernels/zen/1/bli_dotxv_zen_int.c +++ b/kernels/zen/1/bli_dotxv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2019, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -157,6 +157,13 @@ void bli_sdotxv_zen_int rho0 = rho0v.f[0] + rho0v.f[1] + rho0v.f[2] + rho0v.f[3] + rho0v.f[4] + rho0v.f[5] + rho0v.f[6] + rho0v.f[7]; + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + // If there are leftover iterations, perform them with scalar code. for ( i = 0; i < n_left; ++i ) { @@ -277,6 +284,13 @@ void bli_ddotxv_zen_int // Accumulate the final rho vector into a single scalar result. rho0 = rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3]; + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + // If there are leftover iterations, perform them with scalar code. for ( i = 0; i < n_left; ++i ) { diff --git a/kernels/zen/1/bli_scalv_zen_int.c b/kernels/zen/1/bli_scalv_zen_int.c index 3c58212b07..9f76e88e18 100644 --- a/kernels/zen/1/bli_scalv_zen_int.c +++ b/kernels/zen/1/bli_scalv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without diff --git a/kernels/zen/1/bli_scalv_zen_int10.c b/kernels/zen/1/bli_scalv_zen_int10.c index 32812d3df2..c8488890fd 100644 --- a/kernels/zen/1/bli_scalv_zen_int10.c +++ b/kernels/zen/1/bli_scalv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2017 - 2022, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -80,9 +80,11 @@ void bli_sscalv_zen_int10 // If alpha is zero, use setv. if ( PASTEMAC(s,eq0)( *alpha ) ) { - float* zero = bli_s0; - ssetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_SETV_KER, cntx ); + float* zero = bli_s0; + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + ssetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_SETV_KER, cntx ); f ( BLIS_NO_CONJUGATE, @@ -91,6 +93,7 @@ void bli_sscalv_zen_int10 x, incx, cntx ); + return; } @@ -270,8 +273,11 @@ void bli_dscalv_zen_int10 // If alpha is zero, use setv. if ( PASTEMAC(d,eq0)( *alpha ) ) { - double* zero = bli_d0; - dsetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_SETV_KER, cntx ); + double* zero = bli_d0; + + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + dsetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_SETV_KER, cntx ); f ( @@ -281,6 +287,7 @@ void bli_dscalv_zen_int10 x, incx, cntx ); + return; } @@ -433,3 +440,33 @@ void bli_dscalv_zen_int10 } } +// ----------------------------------------------------------------------------- + +// +// NOTE: This function definition is provided as a placeholder in order to allow +// function names of scalv kernels to be hard-coded in bli_gemv_unf_var2_amd.c. +// + +void bli_cscalv_zen_int10 + ( + conj_t conjalpha, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict x, inc_t incx, + cntx_t* restrict cntx + ) +{ + const num_t dt = BLIS_SCOMPLEX; + + cscalv_ker_ft f = bli_cntx_get_l1v_ker_dt( dt, BLIS_SCALV_KER, cntx ); + + f + ( + conjalpha, + n, + alpha, + x, incx, + cntx + ); +} + diff --git a/kernels/zen/1/bli_setv_zen_int.c b/kernels/zen/1/bli_setv_zen_int.c new file mode 100644 index 0000000000..16e02c94da --- /dev/null +++ b/kernels/zen/1/bli_setv_zen_int.c @@ -0,0 +1,228 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +// ----------------------------------------------------------------------------- + +void bli_ssetv_zen_int + ( + conj_t conjalpha, + dim_t n, + float* restrict alpha, + float* restrict x, inc_t incx, + cntx_t* restrict cntx + ) +{ + const dim_t num_elem_per_reg = 8; + dim_t i = 0; + __m256 alphav; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + if ( incx == 1 ) + { + alphav = _mm256_broadcast_ss( alpha ); + + // For loop with n & ~0x7F => n & 0xFFFFFF80 masks the lower bits and results in multiples of 128 + // for example if n = 255 + // n & ~0x7F results in 128: copy from 0 to 128 happens in first loop + // n & ~0x3F results in 192: copy from 128 to 192 happens in second loop + // n & ~0x1F results in 224: copy from 128 to 192 happens in third loop and so on. + for ( i = 0; i < (n & (~0x7F)); i += 128 ) + { + _mm256_storeu_ps(x + num_elem_per_reg * 0, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 1, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 2, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 3, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 4, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 5, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 6, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 7, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 8, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 9, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 10, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 11, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 12, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 13, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 14, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 15, alphav); + + x += 128; + } + for ( ; i < (n & (~0x3F)); i += 64 ) + { + _mm256_storeu_ps(x + num_elem_per_reg * 0, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 1, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 2, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 3, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 4, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 5, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 6, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 7, alphav); + + x += 64; + } + for ( ; i < (n & (~0x1F)); i += 32 ) + { + _mm256_storeu_ps(x + num_elem_per_reg * 0, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 1, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 2, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 3, alphav); + + x += 32; + } + for ( ; i < (n & (~0x0F)); i += 16 ) + { + _mm256_storeu_ps(x + num_elem_per_reg * 0, alphav); + _mm256_storeu_ps(x + num_elem_per_reg * 1, alphav); + + x += 16; + } + for ( ; i < (n & (~0x07)); i += 8 ) + { + _mm256_storeu_ps(x + num_elem_per_reg * 0, alphav); + x += 8; + } + for ( ; i < n; ++i ) + { + *x++ = *alpha; + } + } + else + { + for ( dim_t i = 0; i < n; ++i ) + { + *x = *alpha; + x += incx; + } + } +} + +void bli_dsetv_zen_int + ( + conj_t conjalpha, + dim_t n, + double* restrict alpha, + double* restrict x, inc_t incx, + cntx_t* restrict cntx + ) +{ + const dim_t num_elem_per_reg = 4; + dim_t i = 0; + __m256d alphav; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + if ( incx == 1 ) + { + // Broadcast the alpha scalar to all elements of a vector register. + alphav = _mm256_broadcast_sd( alpha ); + + // n & (~0x3F) = n & 0xFFFFFFC0 -> this masks the numbers less than 64, + // the copy operation will be done for the multiples of 64 + for ( i = 0; i < (n & (~0x3F)); i += 64 ) + { + _mm256_storeu_pd(x + num_elem_per_reg * 0, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 1, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 2, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 3, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 4, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 5, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 6, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 7, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 8, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 9, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 10, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 11, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 12, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 13, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 14, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 15, alphav); + + x += num_elem_per_reg * 16; + } + for ( ; i < (n & (~0x1F)); i += 32 ) + { + _mm256_storeu_pd(x + num_elem_per_reg * 0, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 1, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 2, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 3, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 4, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 5, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 6, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 7, alphav); + + x += num_elem_per_reg * 8; + } + for ( ; i < (n & (~0xF)); i += 16 ) + { + _mm256_storeu_pd(x + num_elem_per_reg * 0, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 1, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 2, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 3, alphav); + + x += num_elem_per_reg * 4; + } + for ( ; i < (n & (~0x07)); i += 8 ) + { + _mm256_storeu_pd(x + num_elem_per_reg * 0, alphav); + _mm256_storeu_pd(x + num_elem_per_reg * 1, alphav); + + x += num_elem_per_reg * 2; + } + for ( ; i < (n & (~0x03)); i += 4 ) + { + _mm256_storeu_pd(x + num_elem_per_reg * 0, alphav); + x += num_elem_per_reg; + } + for ( ; i < n; ++i ) + { + *x++ = *alpha; + } + } + else + { + for ( i = 0; i < n; ++i ) + { + *x = *alpha; + + x += incx; + } + } +} + diff --git a/kernels/zen/1/bli_swapv_zen_int8.c b/kernels/zen/1/bli_swapv_zen_int8.c new file mode 100644 index 0000000000..aa7a6e3398 --- /dev/null +++ b/kernels/zen/1/bli_swapv_zen_int8.c @@ -0,0 +1,344 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + + +/* Union data structure to access AVX registers + One 256-bit AVX register holds 8 SP elements. */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + +/* Union data structure to access AVX registers +* One 256-bit AVX register holds 4 DP elements. */ +typedef union +{ + __m256d v; + double d[4] __attribute__((aligned(64))); +} v4df_t; + +// ----------------------------------------------------------------------------- + +void bli_sswapv_zen_int8 + ( + dim_t n, + float* restrict x, inc_t incx, + float* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + + const dim_t n_elem_per_reg = 8; + dim_t i = 0; + + float* restrict x0; + float* restrict y0; + + __m256 xv[8]; + __m256 yv[8]; + + // If the vector dimension is zero, return early. + if ( bli_zero_dim1( n ) ) return; + + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + for ( i = 0; ( i + 63 ) < n; i += 64 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + yv[5] = _mm256_loadu_ps( y0 + 5*n_elem_per_reg ); + yv[6] = _mm256_loadu_ps( y0 + 6*n_elem_per_reg ); + yv[7] = _mm256_loadu_ps( y0 + 7*n_elem_per_reg ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), yv[0]); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), yv[1]); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), yv[2]); + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), yv[3]); + _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), yv[4]); + _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), yv[5]); + _mm256_storeu_ps( (x0 + 6*n_elem_per_reg), yv[6]); + _mm256_storeu_ps( (x0 + 7*n_elem_per_reg), yv[7]); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), xv[0]); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), xv[1]); + _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), xv[2]); + _mm256_storeu_ps( (y0 + 3*n_elem_per_reg), xv[3]); + _mm256_storeu_ps( (y0 + 4*n_elem_per_reg), xv[4]); + _mm256_storeu_ps( (y0 + 5*n_elem_per_reg), xv[5]); + _mm256_storeu_ps( (y0 + 6*n_elem_per_reg), xv[6]); + _mm256_storeu_ps( (y0 + 7*n_elem_per_reg), xv[7]); + + x0 += 8*n_elem_per_reg; + y0 += 8*n_elem_per_reg; + } + + for ( ; ( i + 31 ) < n; i += 32 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), xv[0]); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), xv[1]); + _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), xv[2]); + _mm256_storeu_ps( (y0 + 3*n_elem_per_reg), xv[3]); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), yv[0]); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), yv[1]); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), yv[2]); + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), yv[3]); + + x0 += 4*n_elem_per_reg; + y0 += 4*n_elem_per_reg; + } + + for ( ; ( i + 15 ) < n; i += 16 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), xv[0]); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), xv[1]); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), yv[0]); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), yv[1]); + + x0 += 2*n_elem_per_reg; + y0 += 2*n_elem_per_reg; + } + + for ( ; ( i + 7 ) < n; i += 8 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), yv[0]); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), xv[0]); + + x0 += 1*n_elem_per_reg; + y0 += 1*n_elem_per_reg; + } + + for ( ; (i + 0) < n; i += 1 ) + { + PASTEMAC(s,swaps)( x[i], y[i] ); + } + } + else + { + for ( i = 0; i < n; ++i ) + { + PASTEMAC(s,swaps)( (*x0), (*y0) ); + + x0 += incx; + y0 += incy; + } + } + +} + +//-------------------------------------------------------------------------------- + +void bli_dswapv_zen_int8 + ( + dim_t n, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 4; + dim_t i = 0; + + double* restrict x0; + double* restrict y0; + + __m256d xv[8]; + __m256d yv[8]; + + // If the vector dimension is zero, return early. + if ( bli_zero_dim1( n ) ) return; + + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + for ( ; ( i + 31 ) < n; i += 32 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + yv[5] = _mm256_loadu_pd( y0 + 5*n_elem_per_reg ); + yv[6] = _mm256_loadu_pd( y0 + 6*n_elem_per_reg ); + yv[7] = _mm256_loadu_pd( y0 + 7*n_elem_per_reg ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), yv[0]); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), yv[1]); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), yv[2]); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), yv[3]); + _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), yv[4]); + _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), yv[5]); + _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), yv[6]); + _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), yv[7]); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), xv[0]); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), xv[1]); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), xv[2]); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), xv[3]); + _mm256_storeu_pd( (y0 + 4*n_elem_per_reg), xv[4]); + _mm256_storeu_pd( (y0 + 5*n_elem_per_reg), xv[5]); + _mm256_storeu_pd( (y0 + 6*n_elem_per_reg), xv[6]); + _mm256_storeu_pd( (y0 + 7*n_elem_per_reg), xv[7]); + + x0 += 8*n_elem_per_reg; + y0 += 8*n_elem_per_reg; + } + + for ( ; ( i + 15 ) < n; i += 16 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), xv[0]); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), xv[1]); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), xv[2]); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), xv[3]); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), yv[0]); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), yv[1]); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), yv[2]); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), yv[3]); + + x0 += 4*n_elem_per_reg; + y0 += 4*n_elem_per_reg; + } + + for ( ; ( i + 7 ) < n; i += 8 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), xv[0]); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), xv[1]); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), yv[0]); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), yv[1]); + + x0 += 2*n_elem_per_reg; + y0 += 2*n_elem_per_reg; + } + + for ( ; ( i + 3 ) < n; i += 4 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), xv[0]); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), yv[0]); + + x0 += 1*n_elem_per_reg; + y0 += 1*n_elem_per_reg; + } + + for ( ; (i + 0) < n; i += 1 ) + { + PASTEMAC(d,swaps)( x[i], y[i] ); + } + } + else + { + for ( i = 0; i < n; ++i ) + { + PASTEMAC(d,swaps)( (*x0), (*y0) ); + + x0 += incx; + y0 += incy; + } + } +} + diff --git a/kernels/zen/1f/bli_axpyf_zen_int_4.c b/kernels/zen/1f/bli_axpyf_zen_int_4.c new file mode 100644 index 0000000000..5ddb56ac57 --- /dev/null +++ b/kernels/zen/1f/bli_axpyf_zen_int_4.c @@ -0,0 +1,277 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + + + void bli_caxpyf_zen_int_4 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + scomplex* restrict alpha, + scomplex* restrict a, inc_t inca, inc_t lda, + scomplex* restrict x, inc_t incx, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + inc_t fuse_fac = 4; + inc_t i; + + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm10; + __m256 ymm12, ymm13; + + float* ap[4]; + float* y0 = (float*)y; + + scomplex chi0; + scomplex chi1; + scomplex chi2; + scomplex chi3; + + + dim_t setPlusOne = 1; + + if ( bli_is_conj(conja) ) + { + setPlusOne = -1; + } + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_ceq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != fuse_fac ) + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + caxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + scomplex* a1 = a + (0 )*inca + (i )*lda; + scomplex* chi1 = x + (i )*incx; + scomplex* y1 = y + (0 )*incy; + scomplex alpha_chi1; + + bli_ccopycjs( conjx, *chi1, alpha_chi1 ); + bli_cscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + + return; + } + + + // At this point, we know that b_n is exactly equal to the fusing factor. + if(bli_is_noconj(conjx)) + { + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + } + else + { + scomplex *pchi0 = x + 0*incx ; + scomplex *pchi1 = x + 1*incx ; + scomplex *pchi2 = x + 2*incx ; + scomplex *pchi3 = x + 3*incx ; + + bli_ccopycjs( conjx, *pchi0, chi0 ); + bli_ccopycjs( conjx, *pchi1, chi1 ); + bli_ccopycjs( conjx, *pchi2, chi2 ); + bli_ccopycjs( conjx, *pchi3, chi3 ); + } + + // Scale each chi scalar by alpha. + bli_cscals( *alpha, chi0 ); + bli_cscals( *alpha, chi1 ); + bli_cscals( *alpha, chi2 ); + bli_cscals( *alpha, chi3 ); + + lda *= 2; + incx *= 2; + incy *= 2; + inca *= 2; + + ap[0] = (float*)a; + ap[1] = (float*)a + lda; + ap[2] = ap[1] + lda; + ap[3] = ap[2] + lda; + + if( inca == 2 && incy == 2 ) + { + inc_t n1 = m/4; + inc_t n2 = m%4; + + ymm12 = _mm256_setzero_ps(); + ymm13 = _mm256_setzero_ps(); + + // broadcast real & imag parts of 4 elements of x + ymm0 = _mm256_broadcast_ss(&chi0.real); // real part of x0 + ymm1 = _mm256_broadcast_ss(&chi0.imag); // imag part of x0 + ymm2 = _mm256_broadcast_ss(&chi1.real); // real part of x1 + ymm3 = _mm256_broadcast_ss(&chi1.imag); // imag part of x1 + ymm4 = _mm256_broadcast_ss(&chi2.real); // real part of x2 + ymm5 = _mm256_broadcast_ss(&chi2.imag); // imag part of x2 + ymm6 = _mm256_broadcast_ss(&chi3.real); // real part of x3 + ymm7 = _mm256_broadcast_ss(&chi3.imag); // imag part of x3 + + for(i = 0; i < n1; i++) + { + //load first two columns of A + ymm8 = _mm256_loadu_ps(ap[0] + 0); + ymm10 = _mm256_loadu_ps(ap[1] + 0); + + ymm12 = _mm256_mul_ps(ymm8, ymm0); + ymm13 = _mm256_mul_ps(ymm8, ymm1); + + ymm12 = _mm256_fmadd_ps(ymm10, ymm2, ymm12); + ymm13 = _mm256_fmadd_ps(ymm10, ymm3, ymm13); + + //load 3rd and 4th columns of A + ymm8 = _mm256_loadu_ps(ap[2] + 0); + ymm10 = _mm256_loadu_ps(ap[3] + 0); + + ymm12 = _mm256_fmadd_ps(ymm8, ymm4, ymm12); + ymm13 = _mm256_fmadd_ps(ymm8, ymm5, ymm13); + + ymm12 = _mm256_fmadd_ps(ymm10, ymm6, ymm12); + ymm13 = _mm256_fmadd_ps(ymm10, ymm7, ymm13); + + //load Y vector + ymm10 = _mm256_loadu_ps(y0 + 0); + + if(bli_is_noconj(conja)) + { + //printf("Inside no conj if\n"); + ymm13 = _mm256_permute_ps(ymm13, 0xB1); + ymm8 = _mm256_addsub_ps(ymm12, ymm13); + } + else + { + ymm12 = _mm256_permute_ps(ymm12, 0xB1); + ymm8 = _mm256_addsub_ps(ymm13, ymm12); + ymm8 = _mm256_permute_ps(ymm8, 0xB1); + } + + ymm12 = _mm256_add_ps(ymm8, ymm10); + + _mm256_storeu_ps((float*)(y0), ymm12); + + y0 += 8; + ap[0] += 8; + ap[1] += 8; + ap[2] += 8; + ap[3] += 8; + } + + // If there are leftover iterations, perform them with scalar code. + + for ( i = 0; (i + 0) < n2 ; ++i ) + { + + scomplex y0c = *(scomplex*)y0; + + const scomplex a0c = *(scomplex*)ap[0]; + const scomplex a1c = *(scomplex*)ap[1]; + const scomplex a2c = *(scomplex*)ap[2]; + const scomplex a3c = *(scomplex*)ap[3]; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + + *(scomplex*)y0 = y0c; + + ap[0] += 2; + ap[1] += 2; + ap[2] += 2; + ap[3] += 2; + y0 += 2; + } + //PASTEMAC(c,fprintm)(stdout, "Y after A*x in axpyf",m, 1, (scomplex*)y, 1, 1, "%4.1f", ""); + + } + else + { + for (i = 0 ; (i + 0) < m ; ++i ) + { + scomplex y0c = *(scomplex*)y0; + const scomplex a0c = *(scomplex*)ap[0]; + const scomplex a1c = *(scomplex*)ap[1]; + const scomplex a2c = *(scomplex*)ap[2]; + const scomplex a3c = *(scomplex*)ap[3]; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + + *(scomplex*)y0 = y0c; + + ap[0] += inca; + ap[1] += inca; + ap[2] += inca; + ap[3] += inca; + y0 += incy; + } + } +} diff --git a/kernels/zen/1f/bli_axpyf_zen_int_5.c b/kernels/zen/1f/bli_axpyf_zen_int_5.c new file mode 100644 index 0000000000..15a64d5966 --- /dev/null +++ b/kernels/zen/1f/bli_axpyf_zen_int_5.c @@ -0,0 +1,1231 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +/* Union data structure to access AVX registers + One 256-bit AVX register holds 8 SP elements. */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + +/* Union data structure to access AVX registers +* One 256-bit AVX register holds 4 DP elements. */ +typedef union +{ + __m256d v; + __m128d xmm[2]; + double d[4] __attribute__((aligned(64))); +} v4df_t; + +typedef union +{ + __m128d v; + double d[2] __attribute__((aligned(64))); +} v2df_t; + + +void bli_saxpyf_zen_int_5 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + float* restrict alpha, + float* restrict a, inc_t inca, inc_t lda, + float* restrict x, inc_t incx, + float* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t fuse_fac = 5; + + const dim_t n_elem_per_reg = 8; + const dim_t n_iter_unroll = 2; + + dim_t i; + + float* restrict a0; + float* restrict a1; + float* restrict a2; + float* restrict a3; + float* restrict a4; + + float* restrict y0; + + v8sf_t chi0v, chi1v, chi2v, chi3v; + v8sf_t chi4v; + + v8sf_t a00v, a01v, a02v, a03v; + v8sf_t a04v; + + v8sf_t a10v, a11v, a12v, a13v; + v8sf_t a14v; + + v8sf_t y0v, y1v; + + float chi0, chi1, chi2, chi3; + float chi4; + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_seq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != fuse_fac ) + { + if(cntx == NULL) cntx = bli_gks_query_cntx(); + saxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + float* a1 = a + (0 )*inca + (i )*lda; + float* chi1 = x + (i )*incx; + float* y1 = y + (0 )*incy; + float alpha_chi1; + + bli_scopycjs( conjx, *chi1, alpha_chi1 ); + bli_sscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + + a0 = a + 0*lda; + a1 = a + 1*lda; + a2 = a + 2*lda; + a3 = a + 3*lda; + a4 = a + 4*lda; + y0 = y; + + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + chi4 = *( x + 4*incx ); + + + // Scale each chi scalar by alpha. + bli_sscals( *alpha, chi0 ); + bli_sscals( *alpha, chi1 ); + bli_sscals( *alpha, chi2 ); + bli_sscals( *alpha, chi3 ); + bli_sscals( *alpha, chi4 ); + + // Broadcast the (alpha*chi?) scalars to all elements of vector registers. + chi0v.v = _mm256_broadcast_ss( &chi0 ); + chi1v.v = _mm256_broadcast_ss( &chi1 ); + chi2v.v = _mm256_broadcast_ss( &chi2 ); + chi3v.v = _mm256_broadcast_ss( &chi3 ); + chi4v.v = _mm256_broadcast_ss( &chi4 ); + + // If there are vectorized iterations, perform them with vector + // instructions. + if ( inca == 1 && incy == 1 ) + { + for ( i = 0; (i + 15) < m; i += 16 ) + { + // Load the input values. + y0v.v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + + a00v.v = _mm256_loadu_ps( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_ps( a0 + 1*n_elem_per_reg ); + + a01v.v = _mm256_loadu_ps( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_ps( a1 + 1*n_elem_per_reg ); + + a02v.v = _mm256_loadu_ps( a2 + 0*n_elem_per_reg ); + a12v.v = _mm256_loadu_ps( a2 + 1*n_elem_per_reg ); + + a03v.v = _mm256_loadu_ps( a3 + 0*n_elem_per_reg ); + a13v.v = _mm256_loadu_ps( a3 + 1*n_elem_per_reg ); + + a04v.v = _mm256_loadu_ps( a4 + 0*n_elem_per_reg ); + a14v.v = _mm256_loadu_ps( a4 + 1*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_ps( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_ps( a10v.v, chi0v.v, y1v.v ); + + y0v.v = _mm256_fmadd_ps( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_ps( a11v.v, chi1v.v, y1v.v ); + + y0v.v = _mm256_fmadd_ps( a02v.v, chi2v.v, y0v.v ); + y1v.v = _mm256_fmadd_ps( a12v.v, chi2v.v, y1v.v ); + + y0v.v = _mm256_fmadd_ps( a03v.v, chi3v.v, y0v.v ); + y1v.v = _mm256_fmadd_ps( a13v.v, chi3v.v, y1v.v ); + + y0v.v = _mm256_fmadd_ps( a04v.v, chi4v.v, y0v.v ); + y1v.v = _mm256_fmadd_ps( a14v.v, chi4v.v, y1v.v ); + + + // Store the output. + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), y1v.v ); + + y0 += n_iter_unroll * n_elem_per_reg; + a0 += n_iter_unroll * n_elem_per_reg; + a1 += n_iter_unroll * n_elem_per_reg; + a2 += n_iter_unroll * n_elem_per_reg; + a3 += n_iter_unroll * n_elem_per_reg; + a4 += n_iter_unroll * n_elem_per_reg; + } + + for( ; (i + 7) < m; i += 8 ) + { + // Load the input values. + y0v.v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + + a00v.v = _mm256_loadu_ps( a0 + 0*n_elem_per_reg ); + a01v.v = _mm256_loadu_ps( a1 + 0*n_elem_per_reg ); + a02v.v = _mm256_loadu_ps( a2 + 0*n_elem_per_reg ); + a03v.v = _mm256_loadu_ps( a3 + 0*n_elem_per_reg ); + a04v.v = _mm256_loadu_ps( a4 + 0*n_elem_per_reg ); + + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_ps( a00v.v, chi0v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a01v.v, chi1v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a02v.v, chi2v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a03v.v, chi3v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a04v.v, chi4v.v, y0v.v ); + + // Store the output. + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), y0v.v ); + + y0 += n_elem_per_reg; + a0 += n_elem_per_reg; + a1 += n_elem_per_reg; + a2 += n_elem_per_reg; + a3 += n_elem_per_reg; + a4 += n_elem_per_reg; + } + + // If there are leftover iterations, perform them with scalar code. + for ( ; (i + 0) < m ; ++i ) + { + double y0c = *y0; + + const float a0c = *a0; + const float a1c = *a1; + const float a2c = *a2; + const float a3c = *a3; + const float a4c = *a4; + + y0c += chi0 * a0c; + y0c += chi1 * a1c; + y0c += chi2 * a2c; + y0c += chi3 * a3c; + y0c += chi4 * a4c; + + *y0 = y0c; + + a0 += 1; + a1 += 1; + a2 += 1; + a3 += 1; + a4 += 1; + y0 += 1; + } + } + else + { + for ( i = 0; (i + 0) < m ; ++i ) + { + double y0c = *y0; + + const float a0c = *a0; + const float a1c = *a1; + const float a2c = *a2; + const float a3c = *a3; + const float a4c = *a4; + + y0c += chi0 * a0c; + y0c += chi1 * a1c; + y0c += chi2 * a2c; + y0c += chi3 * a3c; + y0c += chi4 * a4c; + + *y0 = y0c; + + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + a4 += inca; + y0 += incy; + } + + } +} + + +// ----------------------------------------------------------------------------- + +void bli_daxpyf_zen_int_5 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t fuse_fac = 5; + + const dim_t n_elem_per_reg = 4; + const dim_t n_iter_unroll = 2; + + dim_t i; + + double* restrict a0; + double* restrict a1; + double* restrict a2; + double* restrict a3; + double* restrict a4; + + double* restrict y0; + + v4df_t chi0v, chi1v, chi2v, chi3v; + v4df_t chi4v; + + v4df_t a00v, a01v, a02v, a03v; + v4df_t a04v; + + v4df_t a10v, a11v, a12v, a13v; + v4df_t a14v; + + v4df_t y0v, y1v; + + double chi0, chi1, chi2, chi3; + double chi4; + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != fuse_fac ) + { + daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + double* a1 = a + (0 )*inca + (i )*lda; + double* chi1 = x + (i )*incx; + double* y1 = y + (0 )*incy; + double alpha_chi1; + + bli_dcopycjs( conjx, *chi1, alpha_chi1 ); + bli_dscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + + a0 = a + 0*lda; + a1 = a + 1*lda; + a2 = a + 2*lda; + a3 = a + 3*lda; + a4 = a + 4*lda; + y0 = y; + + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + chi4 = *( x + 4*incx ); + + + // Scale each chi scalar by alpha. + bli_dscals( *alpha, chi0 ); + bli_dscals( *alpha, chi1 ); + bli_dscals( *alpha, chi2 ); + bli_dscals( *alpha, chi3 ); + bli_dscals( *alpha, chi4 ); + + // Broadcast the (alpha*chi?) scalars to all elements of vector registers. + chi0v.v = _mm256_broadcast_sd( &chi0 ); + chi1v.v = _mm256_broadcast_sd( &chi1 ); + chi2v.v = _mm256_broadcast_sd( &chi2 ); + chi3v.v = _mm256_broadcast_sd( &chi3 ); + chi4v.v = _mm256_broadcast_sd( &chi4 ); + + // If there are vectorized iterations, perform them with vector + // instructions. + if ( inca == 1 && incy == 1 ) + { + for ( i = 0; (i + 7) < m; i += 8 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + + a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); + a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg ); + + a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); + a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg ); + + a04v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); + a14v.v = _mm256_loadu_pd( a4 + 1*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + + y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); + + y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); + + y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a14v.v, chi4v.v, y1v.v ); + + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + + y0 += n_iter_unroll * n_elem_per_reg; + a0 += n_iter_unroll * n_elem_per_reg; + a1 += n_iter_unroll * n_elem_per_reg; + a2 += n_iter_unroll * n_elem_per_reg; + a3 += n_iter_unroll * n_elem_per_reg; + a4 += n_iter_unroll * n_elem_per_reg; + } + + for( ; (i + 3) < m; i += 4 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); + a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); + a04v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); + + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); + + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), y0v.v ); + + y0 += n_elem_per_reg; + a0 += n_elem_per_reg; + a1 += n_elem_per_reg; + a2 += n_elem_per_reg; + a3 += n_elem_per_reg; + a4 += n_elem_per_reg; + } + + // If there are leftover iterations, perform them with scalar code. + for ( ; (i + 0) < m ; ++i ) + { + double y0c = *y0; + + const double a0c = *a0; + const double a1c = *a1; + const double a2c = *a2; + const double a3c = *a3; + const double a4c = *a4; + + y0c += chi0 * a0c; + y0c += chi1 * a1c; + y0c += chi2 * a2c; + y0c += chi3 * a3c; + y0c += chi4 * a4c; + + *y0 = y0c; + + a0 += 1; + a1 += 1; + a2 += 1; + a3 += 1; + a4 += 1; + y0 += 1; + } + } + else + { + for ( i = 0; (i + 0) < m ; ++i ) + { + double y0c = *y0; + + const double a0c = *a0; + const double a1c = *a1; + const double a2c = *a2; + const double a3c = *a3; + const double a4c = *a4; + + y0c += chi0 * a0c; + y0c += chi1 * a1c; + y0c += chi2 * a2c; + y0c += chi3 * a3c; + y0c += chi4 * a4c; + + *y0 = y0c; + + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + a4 += inca; + y0 += incy; + } + + } +} + +// ----------------------------------------------------------------------------- + +static void bli_daxpyf_zen_int_16x2 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t fuse_fac = 2; + + const dim_t n_elem_per_reg = 4; + const dim_t n_iter_unroll = 4; + + dim_t i; + + double* restrict a0; + double* restrict a1; + + double* restrict y0; + + v4df_t chi0v, chi1v; + + v4df_t a00v, a01v; + + v4df_t a10v, a11v; + + v4df_t a20v, a21v; + + v4df_t a30v, a31v; + + v4df_t y0v, y1v, y2v, y3v; + + double chi0, chi1; + + v2df_t a40v, a41v; + + v2df_t y4v; + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != fuse_fac ) + { + daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + double* a1 = a + (0 )*inca + (i )*lda; + double* chi1 = x + (i )*incx; + double* y1 = y + (0 )*incy; + double alpha_chi1; + + bli_dcopycjs( conjx, *chi1, alpha_chi1 ); + bli_dscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + + a0 = a + 0*lda; + a1 = a + 1*lda; + + y0 = y; + + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + + + // Scale each chi scalar by alpha. + bli_dscals( *alpha, chi0 ); + bli_dscals( *alpha, chi1 ); + + // Broadcast the (alpha*chi?) scalars to all elements of vector registers. + chi0v.v = _mm256_broadcast_sd( &chi0 ); + chi1v.v = _mm256_broadcast_sd( &chi1 ); + + // If there are vectorized iterations, perform them with vector + // instructions. + if ( inca == 1 && incy == 1 ) + { + for ( i = 0; (i + 15) < m; i += 16 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg ); + a30v.v = _mm256_loadu_pd( a0 + 3*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg ); + a31v.v = _mm256_loadu_pd( a1 + 3*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v ); + y3v.v = _mm256_fmadd_pd( a30v.v, chi0v.v, y3v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v ); + y3v.v = _mm256_fmadd_pd( a31v.v, chi1v.v, y3v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + _mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v ); + _mm256_storeu_pd( (double *)(y0 + 3*n_elem_per_reg), y3v.v ); + + y0 += n_iter_unroll * n_elem_per_reg; + a0 += n_iter_unroll * n_elem_per_reg; + a1 += n_iter_unroll * n_elem_per_reg; + } + + for ( ; (i + 11) < m; i += 12 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + _mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v ); + + y0 += 3 * n_elem_per_reg; + a0 += 3 * n_elem_per_reg; + a1 += 3 * n_elem_per_reg; + } + for ( ; (i + 7) < m; i += 8 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + + y0 += 2 * n_elem_per_reg; + a0 += 2 * n_elem_per_reg; + a1 += 2 * n_elem_per_reg; + } + + for ( ; (i + 3) < m; i += 4 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + + y0 += n_elem_per_reg; + a0 += n_elem_per_reg; + a1 += n_elem_per_reg; + } + + for ( ; (i + 1) < m; i += 2 ) + { + // Load the input values. + y4v.v = _mm_loadu_pd( y0 + 0*n_elem_per_reg ); + + a40v.v = _mm_loadu_pd( a0 + 0*n_elem_per_reg ); + + a41v.v = _mm_loadu_pd( a1 + 0*n_elem_per_reg ); + + // perform : y += alpha * x; + y4v.v = _mm_fmadd_pd( a40v.v, chi0v.xmm[0], y4v.v ); + + y4v.v = _mm_fmadd_pd( a41v.v, chi1v.xmm[0], y4v.v ); + + // Store the output. + _mm_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y4v.v ); + + y0 += 2; + a0 += 2; + a1 += 2; + } + + // If there are leftover iterations, perform them with scalar code. + for ( ; (i + 0) < m ; ++i ) + { + double y0c = *y0; + + const double a0c = *a0; + const double a1c = *a1; + + y0c += chi0 * a0c; + y0c += chi1 * a1c; + + *y0 = y0c; + + a0 += 1; + a1 += 1; + y0 += 1; + } + } + else + { + for ( i = 0; (i + 0) < m ; ++i ) + { + double y0c = *y0; + + const double a0c = *a0; + const double a1c = *a1; + + y0c += chi0 * a0c; + y0c += chi1 * a1c; + + *y0 = y0c; + + a0 += inca; + a1 += inca; + y0 += incy; + } + + } +} + +// ----------------------------------------------------------------------------- +void bli_daxpyf_zen_int_16x4 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t fuse_fac = 4; + + const dim_t n_elem_per_reg = 4; + const dim_t n_iter_unroll = 4; + + dim_t i; + + double* restrict a0; + double* restrict a1; + double* restrict a2; + double* restrict a3; + + double* restrict y0; + + v4df_t chi0v, chi1v, chi2v, chi3v; + + v4df_t a00v, a01v, a02v, a03v; + + v4df_t a10v, a11v, a12v, a13v; + + v4df_t a20v, a21v, a22v, a23v; + + v4df_t a30v, a31v, a32v, a33v; + + v4df_t y0v, y1v, y2v, y3v; + + double chi0, chi1, chi2, chi3; + + v2df_t y4v; + + v2df_t a40v, a41v, a42v, a43v; + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != fuse_fac ) + { + if(cntx == NULL) cntx = bli_gks_query_cntx(); + daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + double* a1 = a + (0 )*inca + (i )*lda; + double* chi1 = x + (i )*incx; + double* y1 = y + (0 )*incy; + double alpha_chi1; + + bli_dcopycjs( conjx, *chi1, alpha_chi1 ); + bli_dscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + + a0 = a + 0*lda; + a1 = a + 1*lda; + a2 = a + 2*lda; + a3 = a + 3*lda; + + y0 = y; + + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + + // Scale each chi scalar by alpha. + bli_dscals( *alpha, chi0 ); + bli_dscals( *alpha, chi1 ); + bli_dscals( *alpha, chi2 ); + bli_dscals( *alpha, chi3 ); + + // Broadcast the (alpha*chi?) scalars to all elements of vector registers. + chi0v.v = _mm256_broadcast_sd( &chi0 ); + chi1v.v = _mm256_broadcast_sd( &chi1 ); + chi2v.v = _mm256_broadcast_sd( &chi2 ); + chi3v.v = _mm256_broadcast_sd( &chi3 ); + + // If there are vectorized iterations, perform them with vector + // instructions. + if ( inca == 1 && incy == 1 ) + { + for ( i = 0; (i + 15) < m; i += 16 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg ); + a30v.v = _mm256_loadu_pd( a0 + 3*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg ); + a31v.v = _mm256_loadu_pd( a1 + 3*n_elem_per_reg ); + + a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); + a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg ); + a22v.v = _mm256_loadu_pd( a2 + 2*n_elem_per_reg ); + a32v.v = _mm256_loadu_pd( a2 + 3*n_elem_per_reg ); + + a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); + a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg ); + a23v.v = _mm256_loadu_pd( a3 + 2*n_elem_per_reg ); + a33v.v = _mm256_loadu_pd( a3 + 3*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v ); + y3v.v = _mm256_fmadd_pd( a30v.v, chi0v.v, y3v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v ); + y3v.v = _mm256_fmadd_pd( a31v.v, chi1v.v, y3v.v ); + + y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a22v.v, chi2v.v, y2v.v ); + y3v.v = _mm256_fmadd_pd( a32v.v, chi2v.v, y3v.v ); + + y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a23v.v, chi3v.v, y2v.v ); + y3v.v = _mm256_fmadd_pd( a33v.v, chi3v.v, y3v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + _mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v ); + _mm256_storeu_pd( (double *)(y0 + 3*n_elem_per_reg), y3v.v ); + + y0 += n_iter_unroll * n_elem_per_reg; + a0 += n_iter_unroll * n_elem_per_reg; + a1 += n_iter_unroll * n_elem_per_reg; + a2 += n_iter_unroll * n_elem_per_reg; + a3 += n_iter_unroll * n_elem_per_reg; + } + + for ( ; (i + 11) < m; i += 12 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg ); + + a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); + a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg ); + a22v.v = _mm256_loadu_pd( a2 + 2*n_elem_per_reg ); + + a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); + a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg ); + a23v.v = _mm256_loadu_pd( a3 + 2*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v ); + + y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a22v.v, chi2v.v, y2v.v ); + + y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a23v.v, chi3v.v, y2v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + _mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v ); + + y0 += 3 * n_elem_per_reg; + a0 += 3 * n_elem_per_reg; + a1 += 3 * n_elem_per_reg; + a2 += 3 * n_elem_per_reg; + a3 += 3 * n_elem_per_reg; + } + + for ( ; (i + 7) < m; i += 8 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + + a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); + a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg ); + + a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); + a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + + y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); + + y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + + y0 += 2 * n_elem_per_reg; + a0 += 2 * n_elem_per_reg; + a1 += 2 * n_elem_per_reg; + a2 += 2 * n_elem_per_reg; + a3 += 2 * n_elem_per_reg; + } + + + for ( ; (i + 3) < m; i += 4) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + + a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); + + a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + + y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); + + y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + + y0 += n_elem_per_reg; + a0 += n_elem_per_reg; + a1 += n_elem_per_reg; + a2 += n_elem_per_reg; + a3 += n_elem_per_reg; + } +#if 1 + for ( ; (i + 1) < m; i += 2) + { + + // Load the input values. + y4v.v = _mm_loadu_pd( y0 + 0*n_elem_per_reg ); + + a40v.v = _mm_loadu_pd( a0 + 0*n_elem_per_reg ); + + a41v.v = _mm_loadu_pd( a1 + 0*n_elem_per_reg ); + + a42v.v = _mm_loadu_pd( a2 + 0*n_elem_per_reg ); + + a43v.v = _mm_loadu_pd( a3 + 0*n_elem_per_reg ); + + // perform : y += alpha * x; + y4v.v = _mm_fmadd_pd( a40v.v, chi0v.xmm[0], y4v.v ); + + y4v.v = _mm_fmadd_pd( a41v.v, chi1v.xmm[0], y4v.v ); + + y4v.v = _mm_fmadd_pd( a42v.v, chi2v.xmm[0], y4v.v ); + + y4v.v = _mm_fmadd_pd( a43v.v, chi3v.xmm[0], y4v.v ); + + // Store the output. + _mm_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y4v.v ); + + y0 += 2; + a0 += 2; + a1 += 2; + a2 += 2; + a3 += 2; + } +#endif + // If there are leftover iterations, perform them with scalar code. + for ( ; (i + 0) < m ; ++i ) + { + double y0c = *y0; + + const double a0c = *a0; + const double a1c = *a1; + const double a2c = *a2; + const double a3c = *a3; + + y0c += chi0 * a0c; + y0c += chi1 * a1c; + y0c += chi2 * a2c; + y0c += chi3 * a3c; + + *y0 = y0c; + + a0 += 1; + a1 += 1; + a2 += 1; + a3 += 1; + + y0 += 1; + } + } + else + { + for ( i = 0; (i + 0) < m ; ++i ) + { + double y0c = *y0; + + const double a0c = *a0; + const double a1c = *a1; + const double a2c = *a2; + const double a3c = *a3; + + y0c += chi0 * a0c; + y0c += chi1 * a1c; + y0c += chi2 * a2c; + y0c += chi3 * a3c; + + *y0 = y0c; + + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + + y0 += incy; + } + + } +} + + diff --git a/kernels/zen/1f/bli_axpyf_zen_int_8.c b/kernels/zen/1f/bli_axpyf_zen_int_8.c index 13bda01e4c..b958600ce6 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_8.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_8.c @@ -4,8 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin + Copyright (C) 2016 - 2018, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c index 4f07340356..e40c785d85 100644 --- a/kernels/zen/1f/bli_dotxf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c @@ -4,8 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin + Copyright (C) 2016 - 2018, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index 776c5a0fe7..b04ffea580 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,20 +36,21 @@ #include "xmmintrin.h" #include "blis.h" +#define AOCL_DTL_TRACE_ENTRY(x) ; +#define AOCL_DTL_TRACE_EXIT(x) ; +#define AOCL_DTL_TRACE_EXIT_ERR(x,y) ; + #ifdef BLIS_ENABLE_SMALL_MATRIX #define MR 32 #define D_MR (MR >> 1) #define NR 3 +#define D_BLIS_SMALL_MATRIX_K_THRES_ROME 256 #define BLIS_ENABLE_PREFETCH -#define F_SCRATCH_DIM (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES) -static float A_pack[F_SCRATCH_DIM] __attribute__((aligned(64))); #define D_BLIS_SMALL_MATRIX_THRES (BLIS_SMALL_MATRIX_THRES / 2 ) #define D_BLIS_SMALL_M_RECT_MATRIX_THRES (BLIS_SMALL_M_RECT_MATRIX_THRES / 2) #define D_BLIS_SMALL_K_RECT_MATRIX_THRES (BLIS_SMALL_K_RECT_MATRIX_THRES / 2) -#define D_SCRATCH_DIM (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES) -static double D_A_pack[D_SCRATCH_DIM] __attribute__((aligned(64))); #define BLIS_ATBN_M_THRES 40 // Threshold value of M for/below which small matrix code is called. #define AT_MR 4 // The kernel dimension of the A transpose GEMM kernel.(AT_MR * NR). static err_t bli_sgemm_small @@ -111,7 +112,10 @@ err_t bli_gemm_small cntl_t* cntl ) { + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + #ifdef BLIS_ENABLE_MULTITHREADING + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; #endif // If alpha is zero, scale by beta and return. @@ -128,7 +132,7 @@ err_t bli_gemm_small return BLIS_INVALID_ROW_STRIDE; } - num_t dt = ((*c).info & (0x7 << 0)); + num_t dt = bli_obj_dt(c); if (bli_obj_has_trans( a )) { @@ -157,6 +161,7 @@ err_t bli_gemm_small return bli_sgemm_small(alpha, a, b, beta, c, cntx, cntl); } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; }; @@ -172,69 +177,136 @@ static err_t bli_sgemm_small cntl_t* cntl ) { + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t N = bli_obj_width( c ); // number of columns of Matrix C + gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . + gint_t L = M * N; + + // when N is equal to 1 call GEMV instead of GEMM + if (N == 1) + { + bli_gemv + ( + alpha, + a, + b, + beta, + c + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + return BLIS_SUCCESS; + } - int M = bli_obj_length( c ); // number of rows of Matrix C - int N = bli_obj_width( c ); // number of columns of Matrix C - int K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . - // printf("alpha_cast = %f beta_cast = %f [ Trans = %d %d], [stride = %d %d %d] [m,n,k = %d %d %d]\n",*alpha_cast,*beta_cast, bli_obj_has_trans( a ), bli_obj_has_trans( b ), lda, ldb,ldc, M,N,K); - if (((M * N) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES)) - || ((M < BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < BLIS_SMALL_K_RECT_MATRIX_THRES))) + if ((((L) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES)) + || ((M < BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0))) { + guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. + guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. + guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C + guint_t row_idx, col_idx, k; - int lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. - int ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. - int ldc = bli_obj_col_stride( c ); // column stride of matrix C - int row_idx, col_idx, k; - float *A = a->buffer; // pointer to elements of Matrix A - float *B = b->buffer; // pointer to elements of Matrix B - float *C = c->buffer; // pointer to elements of Matrix C + float *A = bli_obj_buffer_at_off(a); // pointer to elements of Matrix A + float *B = bli_obj_buffer_at_off(b); // pointer to elements of Matrix B + float *C = bli_obj_buffer_at_off(c); // pointer to elements of Matrix C float *tA = A, *tB = B, *tC = C;//, *tA_pack; - float *tA_packed; // temprorary pointer to hold packed A memory pointer - int row_idx_packed; //packed A memory row index - int lda_packed; //lda of packed A - int col_idx_start; //starting index after A matrix is packed. + float *tA_packed; // temporary pointer to hold packed A memory pointer + + guint_t row_idx_packed; //packed A memory row index + guint_t lda_packed; //lda of packed A + guint_t col_idx_start; //starting index after A matrix is packed. dim_t tb_inc_row = 1; // row stride of matrix B dim_t tb_inc_col = ldb; // column stride of matrix B - __m256 ymm4, ymm5, ymm6, ymm7; + + __m256 ymm4, ymm5, ymm6, ymm7; __m256 ymm8, ymm9, ymm10, ymm11; __m256 ymm12, ymm13, ymm14, ymm15; __m256 ymm0, ymm1, ymm2, ymm3; - int n_remainder; // If the N is non multiple of 3.(N%3) - int m_remainder; // If the M is non multiple of 32.(M%32) - - float *alpha_cast, *beta_cast; // alpha, beta multiples - alpha_cast = (alpha->buffer); - beta_cast = (beta->buffer); - int required_packing_A = 1; - - // when N is equal to 1 call GEMV instead of GEMM - if (N == 1) - { - bli_gemv - ( - alpha, - a, - b, - beta, - c - ); - return BLIS_SUCCESS; + gint_t n_remainder; // If the N is non multiple of 3.(N%3) + gint_t m_remainder; // If the M is non multiple of 32.(M%32) + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + float *A_pack = NULL; + rntm_t rntm; + + const num_t dt_exec = bli_obj_dt( c ); + float* restrict alpha_cast = bli_obj_buffer_for_1x1( dt_exec, alpha ); + float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); + + /*Beta Zero Check*/ + bool is_beta_non_zero=0; + if ( !bli_obj_equals( beta, &BLIS_ZERO ) ){ + is_beta_non_zero = 1; } - //update the pointer math if matrix B needs to be transposed. - if (bli_obj_has_trans( b )) - { + //update the pointer math if matrix B needs to be transposed. + if (bli_obj_has_trans( b )) { tb_inc_col = 1; //switch row and column strides tb_inc_row = ldb; } - if ((N <= 3) || ((MR * K) > F_SCRATCH_DIM)) + /* + * This function was using global array to pack part of A input when needed. + * However, using this global array make the function non-reentrant. + * Instead of using a global array we should allocate buffer for each invocation. + * Since the buffer size is too big or stack and doing malloc every time will be too expensive, + * better approach is to get the buffer from the pre-allocated pool and return + * it the pool once we are doing. + * + * In order to get the buffer from pool, we need access to memory broker, + * currently this function is not invoked in such a way that it can receive + * the memory broker (via rntm). Following hack will get the global memory + * broker that can be use it to access the pool. + * + * Note there will be memory allocation at least on first innovation + * as there will not be any pool created for this size. + * Subsequent invocations will just reuse the buffer from the pool. + */ + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_pba_rntm_set_pba( &rntm ); + + // Get the current size of the buffer pool for A block packing. + // We will use the same size to avoid pool re-initialization + siz_t buffer_size = bli_pool_block_size(bli_pba_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_pba(&rntm))); + + // Based on the available memory in the buffer we will decide if + // we want to do packing or not. + // + // This kernel assumes that "A" will be un-packged if N <= 3. + // Usually this range (N <= 3) is handled by SUP, however, + // if SUP is disabled or for any other condition if we do + // enter this kernel with N <= 3, we want to make sure that + // "A" remains unpacked. + // + // If this check is removed it will result in the crash as + // reported in CPUPL-587. + // + + if ((N <= 3) || (((MR * K) << 2) > buffer_size)) { required_packing_A = 0; } + else + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_sgemm_small: Requesting mem pool block of size %lu\n", buffer_size); +#endif + // Get the buffer from the pool, if there is no pool with + // required size, it will be created. + bli_pba_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + /* * The computation loop runs for MRxN columns of C matrix, thus * accessing the MRxK A matrix data and KxNR B matrix data. @@ -243,7 +315,6 @@ static err_t bli_sgemm_small // Process MR rows of C matrix at a time. for (row_idx = 0; (row_idx + (MR - 1)) < M; row_idx += MR) { - col_idx_start = 0; tA_packed = A; row_idx_packed = row_idx; @@ -338,7 +409,6 @@ static err_t bli_sgemm_small } // alpha, beta multiplication. ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); //multiply A*B by alpha. ymm4 = _mm256_mul_ps(ymm4, ymm0); @@ -354,15 +424,39 @@ static err_t bli_sgemm_small ymm14 = _mm256_mul_ps(ymm14, ymm0); ymm15 = _mm256_mul_ps(ymm15, ymm0); - // multiply C by beta and accumulate col 1. - ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); + if(is_beta_non_zero) + { + ymm1 = _mm256_broadcast_ss(beta_cast); + // multiply C by beta and accumulate col 1. + ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_ps(tC + 24); + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); + + float* ttC = tC +ldc; + ymm2 = _mm256_loadu_ps(ttC); + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_ps(ttC + 8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_ps(ttC + 16); + ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_ps(ttC + 24); + ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11); + + ttC += ldc; + ymm2 = _mm256_loadu_ps(ttC); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(ttC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(ttC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_ps(ttC + 24); + ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15); + } _mm256_storeu_ps(tC, ymm4); _mm256_storeu_ps(tC + 8, ymm5); _mm256_storeu_ps(tC + 16, ymm6); @@ -370,14 +464,6 @@ static err_t bli_sgemm_small // multiply C by beta and accumulate, col 2. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11); _mm256_storeu_ps(tC, ymm8); _mm256_storeu_ps(tC + 8, ymm9); _mm256_storeu_ps(tC + 16, ymm10); @@ -385,14 +471,6 @@ static err_t bli_sgemm_small // multiply C by beta and accumulate, col 3. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15); _mm256_storeu_ps(tC, ymm12); _mm256_storeu_ps(tC + 8, ymm13); _mm256_storeu_ps(tC + 16, ymm14); @@ -482,7 +560,6 @@ static err_t bli_sgemm_small } // alpha, beta multiplication. ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); //multiply A*B by alpha. ymm4 = _mm256_mul_ps(ymm4, ymm0); @@ -498,15 +575,37 @@ static err_t bli_sgemm_small ymm14 = _mm256_mul_ps(ymm14, ymm0); ymm15 = _mm256_mul_ps(ymm15, ymm0); - // multiply C by beta and accumulate col 1. - ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); + if(is_beta_non_zero) + { + ymm1 = _mm256_broadcast_ss(beta_cast); + // multiply C by beta and accumulate col 1. + ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_ps(tC + 24); + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); + float* ttC = tC +ldc; + ymm2 = _mm256_loadu_ps(ttC); + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_ps(ttC + 8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_ps(ttC + 16); + ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_ps(ttC + 24); + ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11); + ttC = ttC +ldc; + ymm2 = _mm256_loadu_ps(ttC); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(ttC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(ttC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_ps(ttC + 24); + ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15); + } _mm256_storeu_ps(tC, ymm4); _mm256_storeu_ps(tC + 8, ymm5); _mm256_storeu_ps(tC + 16, ymm6); @@ -514,14 +613,6 @@ static err_t bli_sgemm_small // multiply C by beta and accumulate, col 2. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11); _mm256_storeu_ps(tC, ymm8); _mm256_storeu_ps(tC + 8, ymm9); _mm256_storeu_ps(tC + 16, ymm10); @@ -529,14 +620,6 @@ static err_t bli_sgemm_small // multiply C by beta and accumulate, col 3. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15); _mm256_storeu_ps(tC, ymm12); _mm256_storeu_ps(tC + 8, ymm13); _mm256_storeu_ps(tC + 16, ymm14); @@ -595,7 +678,6 @@ static err_t bli_sgemm_small } // alpha, beta multiplication. ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); //multiply A*B by alpha. ymm8 = _mm256_mul_ps(ymm8, ymm0); @@ -608,29 +690,34 @@ static err_t bli_sgemm_small ymm15 = _mm256_mul_ps(ymm15, ymm0); // multiply C by beta and accumulate, col 1. - ymm2 = _mm256_loadu_ps(tC + 0); - ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11); - _mm256_storeu_ps(tC + 0, ymm8); + if(is_beta_non_zero) + { + ymm1 = _mm256_broadcast_ss(beta_cast); + ymm2 = _mm256_loadu_ps(tC); + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_ps(tC + 24); + ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11); + + float* ttC = tC +ldc; + // multiply C by beta and accumulate, col 2. + ymm2 = _mm256_loadu_ps(ttC); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(ttC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(ttC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_ps(ttC + 24); + ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15); + } + _mm256_storeu_ps(tC, ymm8); _mm256_storeu_ps(tC + 8, ymm9); _mm256_storeu_ps(tC + 16, ymm10); _mm256_storeu_ps(tC + 24, ymm11); - - // multiply C by beta and accumulate, col 2. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15); _mm256_storeu_ps(tC, ymm12); _mm256_storeu_ps(tC + 8, ymm13); _mm256_storeu_ps(tC + 16, ymm14); @@ -679,7 +766,6 @@ static err_t bli_sgemm_small } // alpha, beta multiplication. ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); //multiply A*B by alpha. ymm12 = _mm256_mul_ps(ymm12, ymm0); @@ -687,15 +773,19 @@ static err_t bli_sgemm_small ymm14 = _mm256_mul_ps(ymm14, ymm0); ymm15 = _mm256_mul_ps(ymm15, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_ps(tC + 0); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15); + if(is_beta_non_zero) + { + ymm1 = _mm256_broadcast_ss(beta_cast); + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_ps(tC + 0); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_ps(tC + 24); + ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15); + } _mm256_storeu_ps(tC + 0, ymm12); _mm256_storeu_ps(tC + 8, ymm13); @@ -767,7 +857,6 @@ static err_t bli_sgemm_small } // alpha, beta multiplication. ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); //multiply A*B by alpha. ymm4 = _mm256_mul_ps(ymm4, ymm0); @@ -780,37 +869,43 @@ static err_t bli_sgemm_small ymm13 = _mm256_mul_ps(ymm13, ymm0); ymm14 = _mm256_mul_ps(ymm14, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); + if(is_beta_non_zero) + { + ymm1 = _mm256_broadcast_ss(beta_cast); + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); + float* ttC = tC +ldc; + ymm2 = _mm256_loadu_ps(ttC); + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_ps(ttC + 8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_ps(ttC + 16); + ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); + ttC += ldc; + ymm2 = _mm256_loadu_ps(ttC); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(ttC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(ttC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); + } _mm256_storeu_ps(tC, ymm4); _mm256_storeu_ps(tC + 8, ymm5); _mm256_storeu_ps(tC + 16, ymm6); // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); _mm256_storeu_ps(tC, ymm8); _mm256_storeu_ps(tC + 8, ymm9); _mm256_storeu_ps(tC + 16, ymm10); // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); _mm256_storeu_ps(tC, ymm12); _mm256_storeu_ps(tC + 8, ymm13); _mm256_storeu_ps(tC + 16, ymm14); @@ -861,7 +956,6 @@ static err_t bli_sgemm_small } // alpha, beta multiplication. ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); //multiply A*B by alpha. ymm8 = _mm256_mul_ps(ymm8, ymm0); @@ -871,25 +965,33 @@ static err_t bli_sgemm_small ymm13 = _mm256_mul_ps(ymm13, ymm0); ymm14 = _mm256_mul_ps(ymm14, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_ps(tC + 0); - ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); - _mm256_storeu_ps(tC + 0, ymm8); + if(is_beta_non_zero) + { + ymm1 = _mm256_broadcast_ss(beta_cast); + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_ps(tC); + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); + + float* ttC = tC +ldc; + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_ps(ttC); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(ttC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(ttC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); + } + + _mm256_storeu_ps(tC, ymm8); _mm256_storeu_ps(tC + 8, ymm9); _mm256_storeu_ps(tC + 16, ymm10); - // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); + _mm256_storeu_ps(tC, ymm12); _mm256_storeu_ps(tC + 8, ymm13); _mm256_storeu_ps(tC + 16, ymm14); @@ -933,21 +1035,23 @@ static err_t bli_sgemm_small } // alpha, beta multiplication. ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); //multiply A*B by alpha. ymm12 = _mm256_mul_ps(ymm12, ymm0); ymm13 = _mm256_mul_ps(ymm13, ymm0); ymm14 = _mm256_mul_ps(ymm14, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_ps(tC + 0); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); - + if(is_beta_non_zero) + { + ymm1 = _mm256_broadcast_ss(beta_cast); + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_ps(tC + 0); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); + } _mm256_storeu_ps(tC + 0, ymm12); _mm256_storeu_ps(tC + 8, ymm13); _mm256_storeu_ps(tC + 16, ymm14); @@ -1000,7 +1104,6 @@ static err_t bli_sgemm_small } // alpha, beta multiplication. ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); //multiply A*B by alpha. ymm4 = _mm256_mul_ps(ymm4, ymm0); @@ -1010,29 +1113,35 @@ static err_t bli_sgemm_small ymm8 = _mm256_mul_ps(ymm8, ymm0); ymm9 = _mm256_mul_ps(ymm9, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + if(is_beta_non_zero) + { + ymm1 = _mm256_broadcast_ss(beta_cast); + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + float* ttC = tC + ldc; + ymm2 = _mm256_loadu_ps(ttC); + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_ps(ttC + 8); + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); + ttC += ldc; + ymm2 = _mm256_loadu_ps(ttC); + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_ps(ttC + 8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); + } _mm256_storeu_ps(tC, ymm4); _mm256_storeu_ps(tC + 8, ymm5); // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); _mm256_storeu_ps(tC, ymm6); _mm256_storeu_ps(tC + 8, ymm7); // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); _mm256_storeu_ps(tC, ymm8); _mm256_storeu_ps(tC + 8, ymm9); @@ -1075,7 +1184,6 @@ static err_t bli_sgemm_small } // alpha, beta multiplication. ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); //multiply A*B by alpha. ymm4 = _mm256_mul_ps(ymm4, ymm0); @@ -1083,20 +1191,25 @@ static err_t bli_sgemm_small ymm6 = _mm256_mul_ps(ymm6, ymm0); ymm7 = _mm256_mul_ps(ymm7, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + if(is_beta_non_zero) + { + ymm1 = _mm256_broadcast_ss(beta_cast); + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + float* ttC = tC + ldc; + ymm2 = _mm256_loadu_ps(ttC); + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_ps(ttC + 8); + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); + } _mm256_storeu_ps(tC, ymm4); _mm256_storeu_ps(tC + 8, ymm5); // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); _mm256_storeu_ps(tC, ymm6); _mm256_storeu_ps(tC + 8, ymm7); @@ -1134,16 +1247,19 @@ static err_t bli_sgemm_small } // alpha, beta multiplication. ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); ymm4 = _mm256_mul_ps(ymm4, ymm0); ymm5 = _mm256_mul_ps(ymm5, ymm0); // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + if(is_beta_non_zero) + { + ymm1 = _mm256_broadcast_ss(beta_cast); + ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + } _mm256_storeu_ps(tC, ymm4); _mm256_storeu_ps(tC + 8, ymm5); @@ -1188,28 +1304,30 @@ static err_t bli_sgemm_small } // alpha, beta multiplication. ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); //multiply A*B by alpha. ymm4 = _mm256_mul_ps(ymm4, ymm0); ymm5 = _mm256_mul_ps(ymm5, ymm0); ymm6 = _mm256_mul_ps(ymm6, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + if(is_beta_non_zero) + { + ymm1 = _mm256_broadcast_ss(beta_cast); + ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + ldc); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_ps(tC + 2*ldc); + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); + } _mm256_storeu_ps(tC, ymm4); // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); _mm256_storeu_ps(tC, ymm5); // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); _mm256_storeu_ps(tC, ymm6); } n_remainder = N - col_idx; @@ -1243,21 +1361,23 @@ static err_t bli_sgemm_small } // alpha, beta multiplication. ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); //multiply A*B by alpha. ymm4 = _mm256_mul_ps(ymm4, ymm0); ymm5 = _mm256_mul_ps(ymm5, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + if(is_beta_non_zero) + { + ymm1 = _mm256_broadcast_ss(beta_cast); + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + ldc); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + } _mm256_storeu_ps(tC, ymm4); - // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_ps(tC); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); _mm256_storeu_ps(tC, ymm5); col_idx += 2; @@ -1290,13 +1410,15 @@ static err_t bli_sgemm_small } // alpha, beta multiplication. ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); - ymm4 = _mm256_mul_ps(ymm4, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + if(is_beta_non_zero) + { + ymm1 = _mm256_broadcast_ss(beta_cast); + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + } _mm256_storeu_ps(tC, ymm4); } @@ -1309,7 +1431,7 @@ static err_t bli_sgemm_small // to handle this case. if ((m_remainder) && (lda > 7)) { - float f_temp[8]; + float f_temp[8] = {0.0}; for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) { @@ -1370,7 +1492,9 @@ static err_t bli_sgemm_small f_temp[i] = tC[i]; } ymm2 = _mm256_loadu_ps(f_temp); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + if(is_beta_non_zero){ + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + } _mm256_storeu_ps(f_temp, ymm5); for (int i = 0; i < m_remainder; i++) { @@ -1383,7 +1507,9 @@ static err_t bli_sgemm_small f_temp[i] = tC[i]; } ymm2 = _mm256_loadu_ps(f_temp); - ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); + if(is_beta_non_zero){ + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); + } _mm256_storeu_ps(f_temp, ymm7); for (int i = 0; i < m_remainder; i++) { @@ -1396,7 +1522,9 @@ static err_t bli_sgemm_small f_temp[i] = tC[i]; } ymm2 = _mm256_loadu_ps(f_temp); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); + if(is_beta_non_zero){ + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); + } _mm256_storeu_ps(f_temp, ymm9); for (int i = 0; i < m_remainder; i++) { @@ -1454,7 +1582,9 @@ static err_t bli_sgemm_small f_temp[i] = tC[i]; } ymm2 = _mm256_loadu_ps(f_temp); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + if(is_beta_non_zero){ + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + } _mm256_storeu_ps(f_temp, ymm5); for (int i = 0; i < m_remainder; i++) { @@ -1467,8 +1597,10 @@ static err_t bli_sgemm_small f_temp[i] = tC[i]; } ymm2 = _mm256_loadu_ps(f_temp); - ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); - _mm256_storeu_ps(f_temp, ymm7); + if(is_beta_non_zero){ + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); + } + _mm256_storeu_ps(f_temp, ymm7); for (int i = 0; i < m_remainder; i++) { tC[i] = f_temp[i]; @@ -1509,7 +1641,6 @@ static err_t bli_sgemm_small ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); ymm0 = _mm256_broadcast_ss(alpha_cast); - ymm1 = _mm256_broadcast_ss(beta_cast); // multiply C by beta and accumulate. ymm5 = _mm256_mul_ps(ymm5, ymm0); @@ -1519,7 +1650,10 @@ static err_t bli_sgemm_small f_temp[i] = tC[i]; } ymm2 = _mm256_loadu_ps(f_temp); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + if(is_beta_non_zero){ + ymm1 = _mm256_broadcast_ss(beta_cast); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + } _mm256_storeu_ps(f_temp, ymm5); for (int i = 0; i < m_remainder; i++) { @@ -1550,15 +1684,36 @@ static err_t bli_sgemm_small } result *= (*alpha_cast); - (*tC) = (*tC) * (*beta_cast) + result; + if(is_beta_non_zero){ + (*tC) = (*tC) * (*beta_cast) + result; + }else{ + (*tC) = result; + } } } } + + // Return the buffer to pool + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s) ) { + +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_sgemm_small(): releasing mem pool block\n" ); +#endif + bli_pba_release(&rntm, + &local_mem_buf_A_s); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_SUCCESS; } else - return BLIS_NONCONFORMAL_DIMENSIONS; - + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } }; @@ -1574,29 +1729,57 @@ static err_t bli_dgemm_small ) { - int M = bli_obj_length( c ); // number of rows of Matrix C - int N = bli_obj_width( c ); // number of columns of Matrix C - int K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + + gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t N = bli_obj_width( c ); // number of columns of Matrix C + gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . + gint_t L = M * N; - // If alpha is zero, scale by beta and return. - // printf("alpha_cast = %f beta_cast = %f [ Trans = %d %d], [stride = %d %d %d] [m,n,k = %d %d %d]\n",*alpha_cast,*beta_cast, bli_obj_has_trans( a ), bli_obj_has_trans( b ), lda, ldb,ldc, M,N,K); - if (((M * N) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES)) - || ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) + // when N is equal to 1 call GEMV instead of GEMM + if (N == 1) { + bli_gemv + ( + alpha, + a, + b, + beta, + c + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return BLIS_SUCCESS; + } - int lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. - int ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. - int ldc = bli_obj_col_stride( c ); // column stride of matrix C - int row_idx, col_idx, k; - double *A = a->buffer; // pointer to elements of Matrix A - double *B = b->buffer; // pointer to elements of Matrix B - double *C = c->buffer; // pointer to elements of Matrix C + if (N<3) //Implemenation assumes that N is atleast 3. + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "N < 3, cannot be processed by small gemm" + ); + return BLIS_NOT_YET_IMPLEMENTED; + } + +#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME + if( (L && K) && ((K < D_BLIS_SMALL_MATRIX_K_THRES_ROME) || ((N < BLIS_SMALL_MATRIX_THRES_ROME) && (K < BLIS_SMALL_MATRIX_THRES_ROME)))) +#else + if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES)) + || ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0))) +#endif + { + guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. + guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. + guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C + guint_t row_idx, col_idx, k; + double *A = bli_obj_buffer_at_off(a); // pointer to elements of Matrix A + double *B = bli_obj_buffer_at_off(b); // pointer to elements of Matrix B + double *C = bli_obj_buffer_at_off(c); // pointer to elements of Matrix C double *tA = A, *tB = B, *tC = C;//, *tA_pack; double *tA_packed; // temprorary pointer to hold packed A memory pointer - int row_idx_packed; //packed A memory row index - int lda_packed; //lda of packed A - int col_idx_start; //starting index after A matrix is packed. + guint_t row_idx_packed; //packed A memory row index + guint_t lda_packed; //lda of packed A + guint_t col_idx_start; //starting index after A matrix is packed. dim_t tb_inc_row = 1; // row stride of matrix B dim_t tb_inc_col = ldb; // column stride of matrix B __m256d ymm4, ymm5, ymm6, ymm7; @@ -1604,27 +1787,17 @@ static err_t bli_dgemm_small __m256d ymm12, ymm13, ymm14, ymm15; __m256d ymm0, ymm1, ymm2, ymm3; - int n_remainder; // If the N is non multiple of 3.(N%3) - int m_remainder; // If the M is non multiple of 16.(M%16) + gint_t n_remainder; // If the N is non multiple of 3.(N%3) + gint_t m_remainder; // If the M is non multiple of 16.(M%16) double *alpha_cast, *beta_cast; // alpha, beta multiples - alpha_cast = (alpha->buffer); - beta_cast = (beta->buffer); - int required_packing_A = 1; + alpha_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, alpha); + beta_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, beta); - // when N is equal to 1 call GEMV instead of GEMM - if (N == 1) - { - bli_gemv - ( - alpha, - a, - b, - beta, - c - ); - return BLIS_SUCCESS; - } + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + double *D_A_pack = NULL; + rntm_t rntm; //update the pointer math if matrix B needs to be transposed. if (bli_obj_has_trans( b )) @@ -1633,10 +1806,71 @@ static err_t bli_dgemm_small tb_inc_row = ldb; } - if ((N <= 3) || ((D_MR * K) > D_SCRATCH_DIM)) + //checking whether beta value is zero. + //if true, we should perform C=alpha * A*B operation + //instead of C = beta * C + alpha * (A * B) + bool is_beta_non_zero = 0; + if(!bli_obj_equals(beta, &BLIS_ZERO)) + is_beta_non_zero = 1; + + /* + * This function was using global array to pack part of A input when needed. + * However, using this global array make the function non-reentrant. + * Instead of using a global array we should allocate buffer for each invocation. + * Since the buffer size is too big or stack and doing malloc every time will be too expensive, + * better approach is to get the buffer from the pre-allocated pool and return + * it the pool once we are doing. + * + * In order to get the buffer from pool, we need access to memory broker, + * currently this function is not invoked in such a way that it can receive + * the memory broker (via rntm). Following hack will get the global memory + * broker that can be use it to access the pool. + * + * Note there will be memory allocation at least on first innovation + * as there will not be any pool created for this size. + * Subsequent invocations will just reuse the buffer from the pool. + */ + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_pba_rntm_set_pba( &rntm ); + + // Get the current size of the buffer pool for A block packing. + // We will use the same size to avoid pool re-initliazaton + siz_t buffer_size = bli_pool_block_size( + bli_pba_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_pba(&rntm))); + + // + // This kernel assumes that "A" will be unpackged if N <= 3. + // Usually this range (N <= 3) is handled by SUP, however, + // if SUP is disabled or for any other condition if we do + // enter this kernel with N <= 3, we want to make sure that + // "A" remains unpacked. + // + // If this check is removed it will result in the crash as + // reported in CPUPL-587. + // + + if ((N <= 3) || ((D_MR * K) << 3) > buffer_size) { required_packing_A = 0; } + + if (required_packing_A == 1) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemm_small: Requesting mem pool block of size %lu\n", buffer_size); +#endif + // Get the buffer from the pool. + bli_pba_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + /* * The computation loop runs for D_MRxN columns of C matrix, thus * accessing the D_MRxK A matrix data and KxNR B matrix data. @@ -1756,45 +1990,56 @@ static err_t bli_dgemm_small ymm14 = _mm256_mul_pd(ymm14, ymm0); ymm15 = _mm256_mul_pd(ymm15, ymm0); - // multiply C by beta and accumulate col 1. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + double* ttC = tC + ldc; + + // multiply C by beta and accumulate, col 2. + ymm2 = _mm256_loadu_pd(ttC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_pd(ttC + 12); + ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11); + + ttC += ldc; + + // multiply C by beta and accumulate, col 3. + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_pd(ttC + 12); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + } _mm256_storeu_pd(tC, ymm4); _mm256_storeu_pd(tC + 4, ymm5); _mm256_storeu_pd(tC + 8, ymm6); _mm256_storeu_pd(tC + 12, ymm7); - // multiply C by beta and accumulate, col 2. tC += ldc; - ymm2 = _mm256_loadu_pd(tC); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11); + _mm256_storeu_pd(tC, ymm8); _mm256_storeu_pd(tC + 4, ymm9); _mm256_storeu_pd(tC + 8, ymm10); _mm256_storeu_pd(tC + 12, ymm11); - // multiply C by beta and accumulate, col 3. tC += ldc; - ymm2 = _mm256_loadu_pd(tC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + _mm256_storeu_pd(tC, ymm12); _mm256_storeu_pd(tC + 4, ymm13); _mm256_storeu_pd(tC + 8, ymm14); @@ -1900,45 +2145,54 @@ static err_t bli_dgemm_small ymm14 = _mm256_mul_pd(ymm14, ymm0); ymm15 = _mm256_mul_pd(ymm15, ymm0); - // multiply C by beta and accumulate col 1. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + // multiply C by beta and accumulate, col 2. + double* ttC = tC + ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_pd(ttC + 12); + ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11); + + // multiply C by beta and accumulate, col 3. + ttC += ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_pd(ttC + 12); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + } _mm256_storeu_pd(tC, ymm4); _mm256_storeu_pd(tC + 4, ymm5); _mm256_storeu_pd(tC + 8, ymm6); _mm256_storeu_pd(tC + 12, ymm7); - // multiply C by beta and accumulate, col 2. tC += ldc; - ymm2 = _mm256_loadu_pd(tC); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11); + _mm256_storeu_pd(tC, ymm8); _mm256_storeu_pd(tC + 4, ymm9); _mm256_storeu_pd(tC + 8, ymm10); _mm256_storeu_pd(tC + 12, ymm11); - // multiply C by beta and accumulate, col 3. tC += ldc; - ymm2 = _mm256_loadu_pd(tC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + _mm256_storeu_pd(tC, ymm12); _mm256_storeu_pd(tC + 4, ymm13); _mm256_storeu_pd(tC + 8, ymm14); @@ -2009,35 +2263,42 @@ static err_t bli_dgemm_small ymm14 = _mm256_mul_pd(ymm14, ymm0); ymm15 = _mm256_mul_pd(ymm15, ymm0); - // multiply C by beta and accumulate, col 1. - ymm2 = _mm256_loadu_pd(tC + 0); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate, col 1. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11); + + // multiply C by beta and accumulate, col 2. + double *ttC = tC + ldc; + + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_pd(ttC + 12); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + } + _mm256_storeu_pd(tC + 0, ymm8); _mm256_storeu_pd(tC + 4, ymm9); _mm256_storeu_pd(tC + 8, ymm10); _mm256_storeu_pd(tC + 12, ymm11); - // multiply C by beta and accumulate, col 2. tC += ldc; - ymm2 = _mm256_loadu_pd(tC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + _mm256_storeu_pd(tC, ymm12); _mm256_storeu_pd(tC + 4, ymm13); _mm256_storeu_pd(tC + 8, ymm14); _mm256_storeu_pd(tC + 12, ymm15); - col_idx += 2; } // if the N is not multiple of 3. @@ -2089,15 +2350,18 @@ static err_t bli_dgemm_small ymm14 = _mm256_mul_pd(ymm14, ymm0); ymm15 = _mm256_mul_pd(ymm15, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC + 0); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + } _mm256_storeu_pd(tC + 0, ymm12); _mm256_storeu_pd(tC + 4, ymm13); @@ -2182,41 +2446,50 @@ static err_t bli_dgemm_small ymm13 = _mm256_mul_pd(ymm13, ymm0); ymm14 = _mm256_mul_pd(ymm14, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + // multiply C by beta and accumulate. + double *ttC = tC +ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + + // multiply C by beta and accumulate. + ttC += ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + } _mm256_storeu_pd(tC, ymm4); _mm256_storeu_pd(tC + 4, ymm5); _mm256_storeu_pd(tC + 8, ymm6); - // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_pd(tC); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + _mm256_storeu_pd(tC, ymm8); _mm256_storeu_pd(tC + 4, ymm9); _mm256_storeu_pd(tC + 8, ymm10); - // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_pd(tC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + _mm256_storeu_pd(tC, ymm12); _mm256_storeu_pd(tC + 4, ymm13); _mm256_storeu_pd(tC + 8, ymm14); - } n_remainder = N - col_idx; // if the N is not multiple of 3. @@ -2273,25 +2546,34 @@ static err_t bli_dgemm_small ymm13 = _mm256_mul_pd(ymm13, ymm0); ymm14 = _mm256_mul_pd(ymm14, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC + 0); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + + double *ttC = tC + ldc; + + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + } _mm256_storeu_pd(tC + 0, ymm8); _mm256_storeu_pd(tC + 4, ymm9); _mm256_storeu_pd(tC + 8, ymm10); - // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_pd(tC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + _mm256_storeu_pd(tC, ymm12); _mm256_storeu_pd(tC + 4, ymm13); _mm256_storeu_pd(tC + 8, ymm14); @@ -2342,14 +2624,18 @@ static err_t bli_dgemm_small ymm13 = _mm256_mul_pd(ymm13, ymm0); ymm14 = _mm256_mul_pd(ymm14, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC + 0); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + } _mm256_storeu_pd(tC + 0, ymm12); _mm256_storeu_pd(tC + 4, ymm13); _mm256_storeu_pd(tC + 8, ymm14); @@ -2412,29 +2698,39 @@ static err_t bli_dgemm_small ymm8 = _mm256_mul_pd(ymm8, ymm0); ymm9 = _mm256_mul_pd(ymm9, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + + double* ttC = tC + ldc; + + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(ttC); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + ttC += ldc; + + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(ttC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + } + _mm256_storeu_pd(tC, ymm4); _mm256_storeu_pd(tC + 4, ymm5); - // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_pd(tC); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); _mm256_storeu_pd(tC, ymm6); _mm256_storeu_pd(tC + 4, ymm7); - // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_pd(tC); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); _mm256_storeu_pd(tC, ymm8); _mm256_storeu_pd(tC + 4, ymm9); @@ -2485,20 +2781,26 @@ static err_t bli_dgemm_small ymm6 = _mm256_mul_pd(ymm6, ymm0); ymm7 = _mm256_mul_pd(ymm7, ymm0); + if(is_beta_non_zero) + { // multiply C by beta and accumulate. ymm2 = _mm256_loadu_pd(tC); ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); ymm2 = _mm256_loadu_pd(tC + 4); ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); + + double* ttC = tC + ldc; // multiply C by beta and accumulate. - tC += ldc; - ymm2 = _mm256_loadu_pd(tC); + ymm2 = _mm256_loadu_pd(ttC); ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_pd(tC + 4); + ymm2 = _mm256_loadu_pd(ttC + 4); ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + } + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + + tC += ldc; _mm256_storeu_pd(tC, ymm6); _mm256_storeu_pd(tC + 4, ymm7); @@ -2541,11 +2843,14 @@ static err_t bli_dgemm_small ymm4 = _mm256_mul_pd(ymm4, ymm0); ymm5 = _mm256_mul_pd(ymm5, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + } _mm256_storeu_pd(tC, ymm4); _mm256_storeu_pd(tC + 4, ymm5); @@ -2598,21 +2903,30 @@ static err_t bli_dgemm_small ymm5 = _mm256_mul_pd(ymm5, ymm0); ymm6 = _mm256_mul_pd(ymm6, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + + double* ttC = tC + ldc; + + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(ttC); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + + ttC += ldc; + + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(ttC); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + } _mm256_storeu_pd(tC, ymm4); - // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_pd(tC); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); _mm256_storeu_pd(tC, ymm5); - // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_pd(tC); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); _mm256_storeu_pd(tC, ymm6); } n_remainder = N - col_idx; @@ -2652,15 +2966,21 @@ static err_t bli_dgemm_small ymm4 = _mm256_mul_pd(ymm4, ymm0); ymm5 = _mm256_mul_pd(ymm5, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + + double* ttC = tC + ldc; + + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(ttC); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + } _mm256_storeu_pd(tC, ymm4); - // multiply C by beta and accumulate. tC += ldc; - ymm2 = _mm256_loadu_pd(tC); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); _mm256_storeu_pd(tC, ymm5); col_idx += 2; @@ -2697,9 +3017,13 @@ static err_t bli_dgemm_small ymm4 = _mm256_mul_pd(ymm4, ymm0); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + + } _mm256_storeu_pd(tC, ymm4); } @@ -2712,7 +3036,7 @@ static err_t bli_dgemm_small // to handle this case. if ((m_remainder) && (lda > 3)) { - double f_temp[8]; + double f_temp[8] = {0.0}; for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) { @@ -2767,44 +3091,52 @@ static err_t bli_dgemm_small ymm7 = _mm256_mul_pd(ymm7, ymm0); ymm9 = _mm256_mul_pd(ymm9, ymm0); - - for (int i = 0; i < m_remainder; i++) + if(is_beta_non_zero) { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - _mm256_storeu_pd(f_temp, ymm5); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } + for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_pd(f_temp); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - tC += ldc; - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); - _mm256_storeu_pd(f_temp, ymm7); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - tC += ldc; - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - _mm256_storeu_pd(f_temp, ymm9); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; + double* ttC = tC + ldc; + + for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = ttC[i]; + } + ymm2 = _mm256_loadu_pd(f_temp); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + ttC += ldc; + for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = ttC[i]; + } + ymm2 = _mm256_loadu_pd(f_temp); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); } + _mm256_storeu_pd(f_temp, ymm5); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + + tC += ldc; + _mm256_storeu_pd(f_temp, ymm7); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + + tC += ldc; + _mm256_storeu_pd(f_temp, ymm9); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } } n_remainder = N - col_idx; // if the N is not multiple of 3. @@ -2852,12 +3184,25 @@ static err_t bli_dgemm_small ymm5 = _mm256_mul_pd(ymm5, ymm0); ymm7 = _mm256_mul_pd(ymm7, ymm0); - for (int i = 0; i < m_remainder; i++) + if(is_beta_non_zero) { - f_temp[i] = tC[i]; + for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_pd(f_temp); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + + double* ttC = tC + ldc; + + for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = ttC[i]; + } + ymm2 = _mm256_loadu_pd(f_temp); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + } - ymm2 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); _mm256_storeu_pd(f_temp, ymm5); for (int i = 0; i < m_remainder; i++) { @@ -2865,12 +3210,6 @@ static err_t bli_dgemm_small } tC += ldc; - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); _mm256_storeu_pd(f_temp, ymm7); for (int i = 0; i < m_remainder; i++) { @@ -2917,12 +3256,16 @@ static err_t bli_dgemm_small // multiply C by beta and accumulate. ymm5 = _mm256_mul_pd(ymm5, ymm0); - for (int i = 0; i < m_remainder; i++) + if(is_beta_non_zero) { - f_temp[i] = tC[i]; + + for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_pd(f_temp); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); } - ymm2 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); _mm256_storeu_pd(f_temp, ymm5); for (int i = 0; i < m_remainder; i++) { @@ -2953,16 +3296,33 @@ static err_t bli_dgemm_small } result *= (*alpha_cast); - (*tC) = (*tC) * (*beta_cast) + result; + if(is_beta_non_zero) + (*tC) = (*tC) * (*beta_cast) + result; + else + (*tC) = result; } } } + + // Return the buffer to pool + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemm_small(): releasing mem pool block\n" ); +#endif + bli_pba_release(&rntm, + &local_mem_buf_A_s); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); return BLIS_SUCCESS; } else + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); return BLIS_NONCONFORMAL_DIMENSIONS; - - + } }; static err_t bli_sgemm_small_atbn @@ -2976,16 +3336,21 @@ static err_t bli_sgemm_small_atbn cntl_t* cntl ) { - int M = bli_obj_length( c ); // number of rows of Matrix C - int N = bli_obj_width( c ); // number of columns of Matrix C - int K = bli_obj_length( b ); // number of rows of Matrix B - int lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. - int ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. - int ldc = bli_obj_col_stride( c ); // column stride of matrix C + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + + gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t N = bli_obj_width( c ); // number of columns of Matrix C + gint_t K = bli_obj_length( b ); // number of rows of Matrix B + + guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. + guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. + guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C + int row_idx = 0, col_idx = 0, k; - float *A = a->buffer; // pointer to matrix A elements, stored in row major format - float *B = b->buffer; // pointer to matrix B elements, stored in column major format - float *C = c->buffer; // pointer to matrix C elements, stored in column major format + + float *A = bli_obj_buffer_at_off(a); // pointer to matrix A elements, stored in row major format + float *B = bli_obj_buffer_at_off(b); // pointer to matrix B elements, stored in column major format + float *C = bli_obj_buffer_at_off(c); // pointer to matrix C elements, stored in column major format float *tA = A, *tB = B, *tC = C; @@ -2994,10 +3359,17 @@ static err_t bli_sgemm_small_atbn __m256 ymm12, ymm13, ymm14, ymm15; __m256 ymm0, ymm1, ymm2, ymm3; - float result, scratch[8]; - float *alpha_cast, *beta_cast; // alpha, beta multiples - alpha_cast = (alpha->buffer); - beta_cast = (beta->buffer); + float result; + float scratch[8] = {0.0}; + const num_t dt_exec = bli_obj_dt( c ); + float* restrict alpha_cast = bli_obj_buffer_for_1x1( dt_exec, alpha ); + float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); + + /*Beta Zero Check*/ + bool is_beta_non_zero=0; + if ( !bli_obj_equals( beta, &BLIS_ZERO ) ){ + is_beta_non_zero = 1; + } // The non-copy version of the A^T GEMM gives better performance for the small M cases. // The threshold is controlled by BLIS_ATBN_M_THRES @@ -3110,28 +3482,44 @@ static err_t bli_sgemm_small_atbn _mm256_storeu_ps(scratch, ymm4); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[0] = result + tC[0] * (*beta_cast); + if(is_beta_non_zero){ + tC[0] = result + tC[0] * (*beta_cast); + }else{ + tC[0] = result; + } ymm7 = _mm256_hadd_ps(ymm7, ymm7); ymm7 = _mm256_hadd_ps(ymm7, ymm7); _mm256_storeu_ps(scratch, ymm7); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[1] = result + tC[1] * (*beta_cast); + if(is_beta_non_zero){ + tC[1] = result + tC[1] * (*beta_cast); + }else{ + tC[1] = result; + } ymm10 = _mm256_hadd_ps(ymm10, ymm10); ymm10 = _mm256_hadd_ps(ymm10, ymm10); _mm256_storeu_ps(scratch, ymm10); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[2] = result + tC[2] * (*beta_cast); + if(is_beta_non_zero){ + tC[2] = result + tC[2] * (*beta_cast); + }else{ + tC[2] = result; + } ymm13 = _mm256_hadd_ps(ymm13, ymm13); ymm13 = _mm256_hadd_ps(ymm13, ymm13); _mm256_storeu_ps(scratch, ymm13); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[3] = result + tC[3] * (*beta_cast); + if(is_beta_non_zero){ + tC[3] = result + tC[3] * (*beta_cast); + }else{ + tC[3] = result; + } tC += ldc; ymm5 = _mm256_hadd_ps(ymm5, ymm5); @@ -3139,28 +3527,44 @@ static err_t bli_sgemm_small_atbn _mm256_storeu_ps(scratch, ymm5); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[0] = result + tC[0] * (*beta_cast); + if(is_beta_non_zero){ + tC[0] = result + tC[0] * (*beta_cast); + }else{ + tC[0] = result; + } ymm8 = _mm256_hadd_ps(ymm8, ymm8); ymm8 = _mm256_hadd_ps(ymm8, ymm8); _mm256_storeu_ps(scratch, ymm8); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[1] = result + tC[1] * (*beta_cast); + if(is_beta_non_zero){ + tC[1] = result + tC[1] * (*beta_cast); + }else{ + tC[1] = result; + } ymm11 = _mm256_hadd_ps(ymm11, ymm11); ymm11 = _mm256_hadd_ps(ymm11, ymm11); _mm256_storeu_ps(scratch, ymm11); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[2] = result + tC[2] * (*beta_cast); + if(is_beta_non_zero){ + tC[2] = result + tC[2] * (*beta_cast); + }else{ + tC[2] = result; + } ymm14 = _mm256_hadd_ps(ymm14, ymm14); ymm14 = _mm256_hadd_ps(ymm14, ymm14); _mm256_storeu_ps(scratch, ymm14); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[3] = result + tC[3] * (*beta_cast); + if(is_beta_non_zero){ + tC[3] = result + tC[3] * (*beta_cast); + }else{ + tC[3] = result; + } tC += ldc; ymm6 = _mm256_hadd_ps(ymm6, ymm6); @@ -3168,28 +3572,44 @@ static err_t bli_sgemm_small_atbn _mm256_storeu_ps(scratch, ymm6); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[0] = result + tC[0] * (*beta_cast); + if(is_beta_non_zero){ + tC[0] = result + tC[0] * (*beta_cast); + }else{ + tC[0] = result; + } ymm9 = _mm256_hadd_ps(ymm9, ymm9); ymm9 = _mm256_hadd_ps(ymm9, ymm9); _mm256_storeu_ps(scratch, ymm9); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[1] = result + tC[1] * (*beta_cast); + if(is_beta_non_zero){ + tC[1] = result + tC[1] * (*beta_cast); + }else{ + tC[1] = result; + } ymm12 = _mm256_hadd_ps(ymm12, ymm12); ymm12 = _mm256_hadd_ps(ymm12, ymm12); _mm256_storeu_ps(scratch, ymm12); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[2] = result + tC[2] * (*beta_cast); + if(is_beta_non_zero){ + tC[2] = result + tC[2] * (*beta_cast); + }else{ + tC[2] = result; + } ymm15 = _mm256_hadd_ps(ymm15, ymm15); ymm15 = _mm256_hadd_ps(ymm15, ymm15); _mm256_storeu_ps(scratch, ymm15); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[3] = result + tC[3] * (*beta_cast); + if(is_beta_non_zero){ + tC[3] = result + tC[3] * (*beta_cast); + }else{ + tC[3] = result; + } } } @@ -3273,29 +3693,44 @@ static err_t bli_sgemm_small_atbn _mm256_storeu_ps(scratch, ymm4); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[0] = result + tC[0] * (*beta_cast); + if(is_beta_non_zero){ + tC[0] = result + tC[0] * (*beta_cast); + }else{ + tC[0] = result; + } ymm7 = _mm256_hadd_ps(ymm7, ymm7); ymm7 = _mm256_hadd_ps(ymm7, ymm7); _mm256_storeu_ps(scratch, ymm7); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[1] = result + tC[1] * (*beta_cast); + if(is_beta_non_zero){ + tC[1] = result + tC[1] * (*beta_cast); + }else{ + tC[1] = result; + } ymm10 = _mm256_hadd_ps(ymm10, ymm10); ymm10 = _mm256_hadd_ps(ymm10, ymm10); _mm256_storeu_ps(scratch, ymm10); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[2] = result + tC[2] * (*beta_cast); + if(is_beta_non_zero){ + tC[2] = result + tC[2] * (*beta_cast); + }else{ + tC[2] = result; + } ymm13 = _mm256_hadd_ps(ymm13, ymm13); ymm13 = _mm256_hadd_ps(ymm13, ymm13); _mm256_storeu_ps(scratch, ymm13); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[3] = result + tC[3] * (*beta_cast); - + if(is_beta_non_zero){ + tC[3] = result + tC[3] * (*beta_cast); + }else{ + tC[3] = result; + } } } processed_row = row_idx; @@ -3345,16 +3780,26 @@ static err_t bli_sgemm_small_atbn _mm256_storeu_ps(scratch, ymm4); result = scratch[0] + scratch[4]; result *= (*alpha_cast); - tC[0] = result + tC[0] * (*beta_cast); + if(is_beta_non_zero){ + tC[0] = result + tC[0] * (*beta_cast); + }else{ + tC[0] = result; + } } } } - + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); return BLIS_SUCCESS; } else + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); return BLIS_NONCONFORMAL_DIMENSIONS; + } } static err_t bli_dgemm_small_atbn @@ -3368,16 +3813,23 @@ static err_t bli_dgemm_small_atbn cntl_t* cntl ) { - int M = bli_obj_length( c ); // number of rows of Matrix C - int N = bli_obj_width( c ); // number of columns of Matrix C - int K = bli_obj_length( b ); // number of rows of Matrix B - int lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. - int ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. - int ldc = bli_obj_col_stride( c ); // column stride of matrix C - int row_idx = 0, col_idx = 0, k; - double *A = a->buffer; // pointer to matrix A elements, stored in row major format - double *B = b->buffer; // pointer to matrix B elements, stored in column major format - double *C = c->buffer; // pointer to matrix C elements, stored in column major format + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + + gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t N = bli_obj_width( c ); // number of columns of Matrix C + gint_t K = bli_obj_length( b ); // number of rows of Matrix B + + // The non-copy version of the A^T GEMM gives better performance for the small M cases. + // The threshold is controlled by BLIS_ATBN_M_THRES + if (M <= BLIS_ATBN_M_THRES) + { + guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. + guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. + guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C + guint_t row_idx = 0, col_idx = 0, k; + double *A = bli_obj_buffer_at_off(a); // pointer to matrix A elements, stored in row major format + double *B = bli_obj_buffer_at_off(b); // pointer to matrix B elements, stored in column major format + double *C = bli_obj_buffer_at_off(c); // pointer to matrix C elements, stored in column major format double *tA = A, *tB = B, *tC = C; @@ -3386,19 +3838,23 @@ static err_t bli_dgemm_small_atbn __m256d ymm12, ymm13, ymm14, ymm15; __m256d ymm0, ymm1, ymm2, ymm3; - double result, scratch[8]; + double result; + double scratch[8] = {0.0}; double *alpha_cast, *beta_cast; // alpha, beta multiples - alpha_cast = (alpha->buffer); - beta_cast = (beta->buffer); + alpha_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, alpha); + beta_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, beta); - // The non-copy version of the A^T GEMM gives better performance for the small M cases. - // The threshold is controlled by BLIS_ATBN_M_THRES - if (M <= BLIS_ATBN_M_THRES) + //check if beta is zero + //if true, we need to perform C = alpha * (A * B) + //instead of C = beta * C + alpha * (A * B) + bool is_beta_non_zero = 0; + if(!bli_obj_equals(beta,&BLIS_ZERO)) + is_beta_non_zero = 1; + + for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR) { - for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR) + for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR) { - for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR) - { tA = A + row_idx * lda; tB = B + col_idx * ldb; tC = C + col_idx * ldc + row_idx; @@ -3501,77 +3957,111 @@ static err_t bli_dgemm_small_atbn _mm256_storeu_pd(scratch, ymm4); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[0] = result + tC[0] * (*beta_cast); + if(is_beta_non_zero) + tC[0] = result + tC[0] * (*beta_cast); + else + tC[0] = result; ymm7 = _mm256_hadd_pd(ymm7, ymm7); _mm256_storeu_pd(scratch, ymm7); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[1] = result + tC[1] * (*beta_cast); + if(is_beta_non_zero) + tC[1] = result + tC[1] * (*beta_cast); + else + tC[1] = result; ymm10 = _mm256_hadd_pd(ymm10, ymm10); _mm256_storeu_pd(scratch, ymm10); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[2] = result + tC[2] * (*beta_cast); + if(is_beta_non_zero) + tC[2] = result + tC[2] * (*beta_cast); + else + tC[2] = result; ymm13 = _mm256_hadd_pd(ymm13, ymm13); _mm256_storeu_pd(scratch, ymm13); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[3] = result + tC[3] * (*beta_cast); - + if(is_beta_non_zero) + tC[3] = result + tC[3] * (*beta_cast); + else + tC[3] = result; tC += ldc; ymm5 = _mm256_hadd_pd(ymm5, ymm5); _mm256_storeu_pd(scratch, ymm5); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[0] = result + tC[0] * (*beta_cast); + if(is_beta_non_zero) + tC[0] = result + tC[0] * (*beta_cast); + else + tC[0] = result; ymm8 = _mm256_hadd_pd(ymm8, ymm8); _mm256_storeu_pd(scratch, ymm8); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[1] = result + tC[1] * (*beta_cast); + if(is_beta_non_zero) + tC[1] = result + tC[1] * (*beta_cast); + else + tC[1] = result; ymm11 = _mm256_hadd_pd(ymm11, ymm11); _mm256_storeu_pd(scratch, ymm11); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[2] = result + tC[2] * (*beta_cast); + if(is_beta_non_zero) + tC[2] = result + tC[2] * (*beta_cast); + else + tC[2] = result; ymm14 = _mm256_hadd_pd(ymm14, ymm14); _mm256_storeu_pd(scratch, ymm14); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[3] = result + tC[3] * (*beta_cast); + if(is_beta_non_zero) + tC[3] = result + tC[3] * (*beta_cast); + else + tC[3] = result; - tC += ldc; ymm6 = _mm256_hadd_pd(ymm6, ymm6); _mm256_storeu_pd(scratch, ymm6); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[0] = result + tC[0] * (*beta_cast); + if(is_beta_non_zero) + tC[0] = result + tC[0] * (*beta_cast); + else + tC[0] = result; ymm9 = _mm256_hadd_pd(ymm9, ymm9); _mm256_storeu_pd(scratch, ymm9); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[1] = result + tC[1] * (*beta_cast); + if(is_beta_non_zero) + tC[1] = result + tC[1] * (*beta_cast); + else + tC[1] = result; ymm12 = _mm256_hadd_pd(ymm12, ymm12); _mm256_storeu_pd(scratch, ymm12); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[2] = result + tC[2] * (*beta_cast); + if(is_beta_non_zero) + tC[2] = result + tC[2] * (*beta_cast); + else + tC[2] = result; ymm15 = _mm256_hadd_pd(ymm15, ymm15); _mm256_storeu_pd(scratch, ymm15); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[3] = result + tC[3] * (*beta_cast); + if(is_beta_non_zero) + tC[3] = result + tC[3] * (*beta_cast); + else + tC[3] = result; } } @@ -3653,26 +4143,37 @@ static err_t bli_dgemm_small_atbn _mm256_storeu_pd(scratch, ymm4); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[0] = result + tC[0] * (*beta_cast); + if(is_beta_non_zero) + tC[0] = result + tC[0] * (*beta_cast); + else + tC[0] = result; ymm7 = _mm256_hadd_pd(ymm7, ymm7); _mm256_storeu_pd(scratch, ymm7); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[1] = result + tC[1] * (*beta_cast); + if(is_beta_non_zero) + tC[1] = result + tC[1] * (*beta_cast); + else + tC[1] = result; ymm10 = _mm256_hadd_pd(ymm10, ymm10); _mm256_storeu_pd(scratch, ymm10); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[2] = result + tC[2] * (*beta_cast); + if(is_beta_non_zero) + tC[2] = result + tC[2] * (*beta_cast); + else + tC[2] = result; ymm13 = _mm256_hadd_pd(ymm13, ymm13); _mm256_storeu_pd(scratch, ymm13); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[3] = result + tC[3] * (*beta_cast); - + if(is_beta_non_zero) + tC[3] = result + tC[3] * (*beta_cast); + else + tC[3] = result; } } processed_row = row_idx; @@ -3721,17 +4222,24 @@ static err_t bli_dgemm_small_atbn _mm256_storeu_pd(scratch, ymm4); result = scratch[0] + scratch[2]; result *= (*alpha_cast); - tC[0] = result + tC[0] * (*beta_cast); - + if(is_beta_non_zero) + tC[0] = result + tC[0] * (*beta_cast); + else + tC[0] = result; } } } - + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); return BLIS_SUCCESS; } else - return BLIS_NONCONFORMAL_DIMENSIONS; + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } } - #endif diff --git a/kernels/zen/3/bli_gemmt_small.c b/kernels/zen/3/bli_gemmt_small.c new file mode 100644 index 0000000000..f2fd88de7b --- /dev/null +++ b/kernels/zen/3/bli_gemmt_small.c @@ -0,0 +1,4210 @@ +/* + +BLIS +An object-based framework for developing high-performance BLAS-like +libraries. + +Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +- Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +- Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. +- Neither the name of The University of Texas at Austin nor the names +of its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ENABLE_SMALL_MATRIX + +#define MR 32 +#define D_MR (MR >> 1) +#define NR 3 + +#define BLIS_ENABLE_PREFETCH +#define F_SCRATCH_DIM (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES) +static float A_pack[F_SCRATCH_DIM] __attribute__((aligned(64))); +static float C_pack[F_SCRATCH_DIM] __attribute__((aligned(64))); +#define D_BLIS_SMALL_MATRIX_THRES (BLIS_SMALL_MATRIX_THRES / 2 ) +#define D_BLIS_SMALL_M_RECT_MATRIX_THRES (BLIS_SMALL_M_RECT_MATRIX_THRES / 2) +#define D_BLIS_SMALL_K_RECT_MATRIX_THRES (BLIS_SMALL_K_RECT_MATRIX_THRES / 2) +#define D_SCRATCH_DIM (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES) +static double D_A_pack[D_SCRATCH_DIM] __attribute__((aligned(64))); +static double D_C_pack[D_SCRATCH_DIM] __attribute__((aligned(64))); +#define BLIS_ATBN_M_THRES 40 // Threshold value of M for/below which small matrix code is called. +#define AT_MR 4 // The kernel dimension of the A transpose GEMMT kernel.(AT_MR * NR). +static err_t bli_sgemmt_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); + +static err_t bli_dgemmt_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); + +static err_t bli_sgemmt_small_atbn + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); + +static err_t bli_dgemmt_small_atbn + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); +/* +* The bli_gemmt_small function will use the +* custom MRxNR kernels, to perform the computation. +* The custom kernels are used if the [M * N] < 240 * 240 +*/ +err_t bli_gemmt_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ) +{ + // FGVZ: This code was originally in bli_gemmt_front(). However, it really + // fits more naturally here within the bli_gemmt_small() function. This + // becomes a bit more obvious now that the code is here, as it contains + // cpp macros such as BLIS_SMALL_MATRIX_A_THRES_M_GEMMT, which are specific + // to this implementation. + if ( bli_obj_has_trans( a ) ) + { + // Continue with small implementation. + ; + } + else if ( ( bli_obj_length( a ) <= BLIS_SMALL_MATRIX_A_THRES_M_GEMMT && + bli_obj_width( a ) < BLIS_SMALL_MATRIX_A_THRES_N_GEMMT ) || + ( bli_obj_length( a ) < BLIS_SMALL_MATRIX_A_THRES_M_GEMMT && + bli_obj_width( a ) <= BLIS_SMALL_MATRIX_A_THRES_N_GEMMT ) ) + { + // Continue with small implementation. + ; + } + else + { + // Reject the problem and return to large code path. + return BLIS_FAILURE; + } + +#ifdef BLIS_ENABLE_MULTITHREADING + return BLIS_NOT_YET_IMPLEMENTED; +#endif + // If alpha is zero, scale by beta and return. + if (bli_obj_equals(alpha, &BLIS_ZERO)) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + + // if row major format return. + if ((bli_obj_row_stride( a ) != 1) || + (bli_obj_row_stride( b ) != 1) || + (bli_obj_row_stride( c ) != 1)) + { + return BLIS_INVALID_ROW_STRIDE; + } + + num_t dt = ((*c).info & (0x7 << 0)); + + if (bli_obj_has_trans( a )) + { + if (bli_obj_has_notrans( b )) + { + if (dt == BLIS_FLOAT) + { + return bli_sgemmt_small_atbn(alpha, a, b, beta, c, cntx, cntl); + } + else if (dt == BLIS_DOUBLE) + { + return bli_dgemmt_small_atbn(alpha, a, b, beta, c, cntx, cntl); + } + } + + return BLIS_NOT_YET_IMPLEMENTED; + } + + if (dt == BLIS_DOUBLE) + { + return bli_dgemmt_small(alpha, a, b, beta, c, cntx, cntl); + } + + if (dt == BLIS_FLOAT) + { + return bli_sgemmt_small(alpha, a, b, beta, c, cntx, cntl); + } + + return BLIS_NOT_YET_IMPLEMENTED; +}; + + +static err_t bli_sgemmt_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ) +{ + + int M = bli_obj_length( c ); // number of rows of Matrix C + int N = bli_obj_width( c ); // number of columns of Matrix C + int K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . + int L = M * N; + + if ((((L) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES)) + || ((M < BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0))) + { + + int lda = bli_obj_col_stride(a); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. + int ldb = bli_obj_col_stride(b); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. + int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C + int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C + int row_idx, col_idx, k; + int rs_matC = bli_obj_row_stride( c ); + int rsc = 1; + float *A = a->buffer; // pointer to elements of Matrix A + float *B = b->buffer; // pointer to elements of Matrix B + float *C = C_pack; // pointer to elements of Matrix C + float *matCbuf = c->buffer; + + float *tA = A, *tB = B, *tC = C;//, *tA_pack; + float *tA_packed; // temprorary pointer to hold packed A memory pointer + int row_idx_packed; //packed A memory row index + int lda_packed; //lda of packed A + int col_idx_start; //starting index after A matrix is packed. + dim_t tb_inc_row = 1; // row stride of matrix B + dim_t tb_inc_col = ldb; // column stride of matrix B + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m256 ymm12, ymm13, ymm14, ymm15; + __m256 ymm0, ymm1, ymm2, ymm3; + + int n_remainder; // If the N is non multiple of 3.(N%3) + int m_remainder; // If the M is non multiple of 32.(M%32) + + float *alpha_cast, *beta_cast; // alpha, beta multiples + alpha_cast = (alpha->buffer); + beta_cast = (beta->buffer); + int required_packing_A = 1; + + // when N is equal to 1 call GEMV instead of GEMMT + if (N == 1) + { + bli_gemv + ( + alpha, + a, + b, + beta, + c + ); + return BLIS_SUCCESS; + } + + //update the pointer math if matrix B needs to be transposed. + if (bli_obj_has_trans( b )) + { + tb_inc_col = 1; //switch row and column strides + tb_inc_row = ldb; + } + + if ((N <= 3) || ((MR * K) > F_SCRATCH_DIM)) + { + required_packing_A = 0; + } + /* + * The computation loop runs for MRxN columns of C matrix, thus + * accessing the MRxK A matrix data and KxNR B matrix data. + * The computation is organized as inner loops of dimension MRxNR. + */ + // Process MR rows of C matrix at a time. + for (row_idx = 0; (row_idx + (MR - 1)) < M; row_idx += MR) + { + + col_idx_start = 0; + tA_packed = A; + row_idx_packed = row_idx; + lda_packed = lda; + + // This is the part of the pack and compute optimization. + // During the first column iteration, we store the accessed A matrix into + // contiguous static memory. This helps to keep te A matrix in Cache and + // aviods the TLB misses. + if (required_packing_A) + { + col_idx = 0; + + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + tA_packed = A_pack; + +#if 0//def BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 16), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0); +#endif + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); + ymm11 = _mm256_setzero_ps(); + ymm12 = _mm256_setzero_ps(); + ymm13 = _mm256_setzero_ps(); + ymm14 = _mm256_setzero_ps(); + ymm15 = _mm256_setzero_ps(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + // This loop is processing MR x K + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + _mm256_storeu_ps(tA_packed, ymm3); // the packing of matrix A + // ymm4 += ymm0 * ymm3; + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + // ymm8 += ymm1 * ymm3; + ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8); + // ymm12 += ymm2 * ymm3; + ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12); + + ymm3 = _mm256_loadu_ps(tA + 8); + _mm256_storeu_ps(tA_packed + 8, ymm3); // the packing of matrix A + // ymm5 += ymm0 * ymm3; + ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); + // ymm9 += ymm1 * ymm3; + ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9); + // ymm13 += ymm2 * ymm3; + ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13); + + ymm3 = _mm256_loadu_ps(tA + 16); + _mm256_storeu_ps(tA_packed + 16, ymm3); // the packing of matrix A + // ymm6 += ymm0 * ymm3; + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6); + // ymm10 += ymm1 * ymm3; + ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10); + // ymm14 += ymm2 * ymm3; + ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14); + + ymm3 = _mm256_loadu_ps(tA + 24); + _mm256_storeu_ps(tA_packed + 24, ymm3); // the packing of matrix A + // ymm7 += ymm0 * ymm3; + ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7); + // ymm11 += ymm1 * ymm3; + ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11); + // ymm15 += ymm2 * ymm3; + ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15); + + tA += lda; + tA_packed += MR; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_ps(ymm4, ymm0); + ymm5 = _mm256_mul_ps(ymm5, ymm0); + ymm6 = _mm256_mul_ps(ymm6, ymm0); + ymm7 = _mm256_mul_ps(ymm7, ymm0); + ymm8 = _mm256_mul_ps(ymm8, ymm0); + ymm9 = _mm256_mul_ps(ymm9, ymm0); + ymm10 = _mm256_mul_ps(ymm10, ymm0); + ymm11 = _mm256_mul_ps(ymm11, ymm0); + ymm12 = _mm256_mul_ps(ymm12, ymm0); + ymm13 = _mm256_mul_ps(ymm13, ymm0); + ymm14 = _mm256_mul_ps(ymm14, ymm0); + ymm15 = _mm256_mul_ps(ymm15, ymm0); + + // multiply C by beta and accumulate col 1. + /*ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_ps(tC + 24); + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/ + _mm256_storeu_ps(tC, ymm4); + _mm256_storeu_ps(tC + 8, ymm5); + _mm256_storeu_ps(tC + 16, ymm6); + _mm256_storeu_ps(tC + 24, ymm7); + + // multiply C by beta and accumulate, col 2. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_ps(tC + 24); + ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/ + _mm256_storeu_ps(tC, ymm8); + _mm256_storeu_ps(tC + 8, ymm9); + _mm256_storeu_ps(tC + 16, ymm10); + _mm256_storeu_ps(tC + 24, ymm11); + + // multiply C by beta and accumulate, col 3. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_ps(tC + 24); + ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/ + _mm256_storeu_ps(tC, ymm12); + _mm256_storeu_ps(tC + 8, ymm13); + _mm256_storeu_ps(tC + 16, ymm14); + _mm256_storeu_ps(tC + 24, ymm15); + + // modify the pointer arithematic to use packed A matrix. + col_idx_start = NR; + tA_packed = A_pack; + row_idx_packed = 0; + lda_packed = MR; + } + // Process NR columns of C matrix at a time. + for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + +#if 0//def BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 16), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0); +#endif + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); + ymm11 = _mm256_setzero_ps(); + ymm12 = _mm256_setzero_ps(); + ymm13 = _mm256_setzero_ps(); + ymm14 = _mm256_setzero_ps(); + ymm15 = _mm256_setzero_ps(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + // This loop is processing MR x K + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + // ymm4 += ymm0 * ymm3; + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + // ymm8 += ymm1 * ymm3; + ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8); + // ymm12 += ymm2 * ymm3; + ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12); + + ymm3 = _mm256_loadu_ps(tA + 8); + // ymm5 += ymm0 * ymm3; + ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); + // ymm9 += ymm1 * ymm3; + ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9); + // ymm13 += ymm2 * ymm3; + ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13); + + ymm3 = _mm256_loadu_ps(tA + 16); + // ymm6 += ymm0 * ymm3; + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6); + // ymm10 += ymm1 * ymm3; + ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10); + // ymm14 += ymm2 * ymm3; + ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14); + + ymm3 = _mm256_loadu_ps(tA + 24); + // ymm7 += ymm0 * ymm3; + ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7); + // ymm11 += ymm1 * ymm3; + ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11); + // ymm15 += ymm2 * ymm3; + ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15); + + tA += lda_packed; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_ps(ymm4, ymm0); + ymm5 = _mm256_mul_ps(ymm5, ymm0); + ymm6 = _mm256_mul_ps(ymm6, ymm0); + ymm7 = _mm256_mul_ps(ymm7, ymm0); + ymm8 = _mm256_mul_ps(ymm8, ymm0); + ymm9 = _mm256_mul_ps(ymm9, ymm0); + ymm10 = _mm256_mul_ps(ymm10, ymm0); + ymm11 = _mm256_mul_ps(ymm11, ymm0); + ymm12 = _mm256_mul_ps(ymm12, ymm0); + ymm13 = _mm256_mul_ps(ymm13, ymm0); + ymm14 = _mm256_mul_ps(ymm14, ymm0); + ymm15 = _mm256_mul_ps(ymm15, ymm0); + + // multiply C by beta and accumulate col 1. + /*ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_ps(tC + 24); + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/ + _mm256_storeu_ps(tC, ymm4); + _mm256_storeu_ps(tC + 8, ymm5); + _mm256_storeu_ps(tC + 16, ymm6); + _mm256_storeu_ps(tC + 24, ymm7); + + // multiply C by beta and accumulate, col 2. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_ps(tC + 24); + ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/ + _mm256_storeu_ps(tC, ymm8); + _mm256_storeu_ps(tC + 8, ymm9); + _mm256_storeu_ps(tC + 16, ymm10); + _mm256_storeu_ps(tC + 24, ymm11); + + // multiply C by beta and accumulate, col 3. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_ps(tC + 24); + ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/ + _mm256_storeu_ps(tC, ymm12); + _mm256_storeu_ps(tC + 8, ymm13); + _mm256_storeu_ps(tC + 16, ymm14); + _mm256_storeu_ps(tC + 24, ymm15); + + } + n_remainder = N - col_idx; + + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); + ymm11 = _mm256_setzero_ps(); + ymm12 = _mm256_setzero_ps(); + ymm13 = _mm256_setzero_ps(); + ymm14 = _mm256_setzero_ps(); + ymm15 = _mm256_setzero_ps(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8); + ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12); + + ymm3 = _mm256_loadu_ps(tA + 8); + ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9); + ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13); + + ymm3 = _mm256_loadu_ps(tA + 16); + ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10); + ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14); + + ymm3 = _mm256_loadu_ps(tA + 24); + ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11); + ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15); + + tA += lda; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + //multiply A*B by alpha. + ymm8 = _mm256_mul_ps(ymm8, ymm0); + ymm9 = _mm256_mul_ps(ymm9, ymm0); + ymm10 = _mm256_mul_ps(ymm10, ymm0); + ymm11 = _mm256_mul_ps(ymm11, ymm0); + ymm12 = _mm256_mul_ps(ymm12, ymm0); + ymm13 = _mm256_mul_ps(ymm13, ymm0); + ymm14 = _mm256_mul_ps(ymm14, ymm0); + ymm15 = _mm256_mul_ps(ymm15, ymm0); + + // multiply C by beta and accumulate, col 1. + /*ymm2 = _mm256_loadu_ps(tC + 0); + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_ps(tC + 24); + ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/ + _mm256_storeu_ps(tC + 0, ymm8); + _mm256_storeu_ps(tC + 8, ymm9); + _mm256_storeu_ps(tC + 16, ymm10); + _mm256_storeu_ps(tC + 24, ymm11); + + // multiply C by beta and accumulate, col 2. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_ps(tC + 24); + ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/ + _mm256_storeu_ps(tC, ymm12); + _mm256_storeu_ps(tC + 8, ymm13); + _mm256_storeu_ps(tC + 16, ymm14); + _mm256_storeu_ps(tC + 24, ymm15); + + col_idx += 2; + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm12 = _mm256_setzero_ps(); + ymm13 = _mm256_setzero_ps(); + ymm14 = _mm256_setzero_ps(); + ymm15 = _mm256_setzero_ps(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12); + + ymm3 = _mm256_loadu_ps(tA + 8); + ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13); + + ymm3 = _mm256_loadu_ps(tA + 16); + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14); + + ymm3 = _mm256_loadu_ps(tA + 24); + ymm15 = _mm256_fmadd_ps(ymm0, ymm3, ymm15); + + tA += lda; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + //multiply A*B by alpha. + ymm12 = _mm256_mul_ps(ymm12, ymm0); + ymm13 = _mm256_mul_ps(ymm13, ymm0); + ymm14 = _mm256_mul_ps(ymm14, ymm0); + ymm15 = _mm256_mul_ps(ymm15, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_ps(tC + 0); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_ps(tC + 24); + ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/ + + _mm256_storeu_ps(tC + 0, ymm12); + _mm256_storeu_ps(tC + 8, ymm13); + _mm256_storeu_ps(tC + 16, ymm14); + _mm256_storeu_ps(tC + 24, ymm15); + } + } + + m_remainder = M - row_idx; + + if (m_remainder >= 24) + { + m_remainder -= 24; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); + ymm12 = _mm256_setzero_ps(); + ymm13 = _mm256_setzero_ps(); + ymm14 = _mm256_setzero_ps(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + // ymm4 += ymm0 * ymm3; + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + // ymm8 += ymm1 * ymm3; + ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8); + // ymm12 += ymm2 * ymm3; + ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12); + + ymm3 = _mm256_loadu_ps(tA + 8); + // ymm5 += ymm0 * ymm3; + ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); + // ymm9 += ymm1 * ymm3; + ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9); + // ymm13 += ymm2 * ymm3; + ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13); + + ymm3 = _mm256_loadu_ps(tA + 16); + // ymm6 += ymm0 * ymm3; + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6); + // ymm10 += ymm1 * ymm3; + ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10); + // ymm14 += ymm2 * ymm3; + ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_ps(ymm4, ymm0); + ymm5 = _mm256_mul_ps(ymm5, ymm0); + ymm6 = _mm256_mul_ps(ymm6, ymm0); + ymm8 = _mm256_mul_ps(ymm8, ymm0); + ymm9 = _mm256_mul_ps(ymm9, ymm0); + ymm10 = _mm256_mul_ps(ymm10, ymm0); + ymm12 = _mm256_mul_ps(ymm12, ymm0); + ymm13 = _mm256_mul_ps(ymm13, ymm0); + ymm14 = _mm256_mul_ps(ymm14, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);*/ + _mm256_storeu_ps(tC, ymm4); + _mm256_storeu_ps(tC + 8, ymm5); + _mm256_storeu_ps(tC + 16, ymm6); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);*/ + _mm256_storeu_ps(tC, ymm8); + _mm256_storeu_ps(tC + 8, ymm9); + _mm256_storeu_ps(tC + 16, ymm10); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/ + _mm256_storeu_ps(tC, ymm12); + _mm256_storeu_ps(tC + 8, ymm13); + _mm256_storeu_ps(tC + 16, ymm14); + + } + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); + ymm12 = _mm256_setzero_ps(); + ymm13 = _mm256_setzero_ps(); + ymm14 = _mm256_setzero_ps(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8); + ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12); + + ymm3 = _mm256_loadu_ps(tA + 8); + ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9); + ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13); + + ymm3 = _mm256_loadu_ps(tA + 16); + ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10); + ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14); + + tA += lda; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + //multiply A*B by alpha. + ymm8 = _mm256_mul_ps(ymm8, ymm0); + ymm9 = _mm256_mul_ps(ymm9, ymm0); + ymm10 = _mm256_mul_ps(ymm10, ymm0); + ymm12 = _mm256_mul_ps(ymm12, ymm0); + ymm13 = _mm256_mul_ps(ymm13, ymm0); + ymm14 = _mm256_mul_ps(ymm14, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_ps(tC + 0); + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);*/ + _mm256_storeu_ps(tC + 0, ymm8); + _mm256_storeu_ps(tC + 8, ymm9); + _mm256_storeu_ps(tC + 16, ymm10); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/ + _mm256_storeu_ps(tC, ymm12); + _mm256_storeu_ps(tC + 8, ymm13); + _mm256_storeu_ps(tC + 16, ymm14); + + col_idx += 2; + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm12 = _mm256_setzero_ps(); + ymm13 = _mm256_setzero_ps(); + ymm14 = _mm256_setzero_ps(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12); + + ymm3 = _mm256_loadu_ps(tA + 8); + ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13); + + ymm3 = _mm256_loadu_ps(tA + 16); + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14); + + tA += lda; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + //multiply A*B by alpha. + ymm12 = _mm256_mul_ps(ymm12, ymm0); + ymm13 = _mm256_mul_ps(ymm13, ymm0); + ymm14 = _mm256_mul_ps(ymm14, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_ps(tC + 0); + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_ps(tC + 16); + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/ + + _mm256_storeu_ps(tC + 0, ymm12); + _mm256_storeu_ps(tC + 8, ymm13); + _mm256_storeu_ps(tC + 16, ymm14); + } + + row_idx += 24; + } + + if (m_remainder >= 16) + { + m_remainder -= 16; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6); + ymm8 = _mm256_fmadd_ps(ymm2, ymm3, ymm8); + + ymm3 = _mm256_loadu_ps(tA + 8); + ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7); + ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_ps(ymm4, ymm0); + ymm5 = _mm256_mul_ps(ymm5, ymm0); + ymm6 = _mm256_mul_ps(ymm6, ymm0); + ymm7 = _mm256_mul_ps(ymm7, ymm0); + ymm8 = _mm256_mul_ps(ymm8, ymm0); + ymm9 = _mm256_mul_ps(ymm9, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ + _mm256_storeu_ps(tC, ymm4); + _mm256_storeu_ps(tC + 8, ymm5); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/ + _mm256_storeu_ps(tC, ymm6); + _mm256_storeu_ps(tC + 8, ymm7); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);*/ + _mm256_storeu_ps(tC, ymm8); + _mm256_storeu_ps(tC + 8, ymm9); + + } + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6); + + ymm3 = _mm256_loadu_ps(tA + 8); + ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_ps(ymm4, ymm0); + ymm5 = _mm256_mul_ps(ymm5, ymm0); + ymm6 = _mm256_mul_ps(ymm6, ymm0); + ymm7 = _mm256_mul_ps(ymm7, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ + _mm256_storeu_ps(tC, ymm4); + _mm256_storeu_ps(tC + 8, ymm5); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/ + _mm256_storeu_ps(tC, ymm6); + _mm256_storeu_ps(tC + 8, ymm7); + + col_idx += 2; + + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + + ymm3 = _mm256_loadu_ps(tA + 8); + ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + ymm4 = _mm256_mul_ps(ymm4, ymm0); + ymm5 = _mm256_mul_ps(ymm5, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_ps(tC + 8); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ + _mm256_storeu_ps(tC, ymm4); + _mm256_storeu_ps(tC + 8, ymm5); + + } + + row_idx += 16; + } + + if (m_remainder >= 8) + { + m_remainder -= 8; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5); + ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_ps(ymm4, ymm0); + ymm5 = _mm256_mul_ps(ymm5, ymm0); + ymm6 = _mm256_mul_ps(ymm6, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/ + _mm256_storeu_ps(tC, ymm4); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ + _mm256_storeu_ps(tC, ymm5); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);*/ + _mm256_storeu_ps(tC, ymm6); + } + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_ps(ymm4, ymm0); + ymm5 = _mm256_mul_ps(ymm5, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/ + _mm256_storeu_ps(tC, ymm4); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_ps(tC); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ + _mm256_storeu_ps(tC, ymm5); + + col_idx += 2; + + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + ymm4 = _mm256_setzero_ps(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + ymm4 = _mm256_mul_ps(ymm4, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_ps(tC); + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/ + _mm256_storeu_ps(tC, ymm4); + + } + + row_idx += 8; + } + // M is not a multiple of 32. + // The handling of edge case where the remainder + // dimension is less than 8. The padding takes place + // to handle this case. + if ((m_remainder) && (lda > 7)) + { + float f_temp[8]; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm5 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + + for (k = 0; k < (K - 1); ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_ps(tA); + ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7); + ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); + tB += tb_inc_row; + + for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tA[i]; + } + ymm3 = _mm256_loadu_ps(f_temp); + ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7); + ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9); + + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + //multiply A*B by alpha. + ymm5 = _mm256_mul_ps(ymm5, ymm0); + ymm7 = _mm256_mul_ps(ymm7, ymm0); + ymm9 = _mm256_mul_ps(ymm9, ymm0); + + + /*for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_ps(f_temp); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ + _mm256_storeu_ps(f_temp, ymm5); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + + tC += ldc; + /*for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_ps(f_temp); + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/ + _mm256_storeu_ps(f_temp, ymm7); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + + tC += ldc; + /*for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_ps(f_temp); + ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);*/ + _mm256_storeu_ps(f_temp, ymm9); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + } + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + ymm5 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + + for (k = 0; k < (K - 1); ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); + tB += tb_inc_row; + + ymm3 = _mm256_loadu_ps(tA); + ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7); + + tA += lda; + } + + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); + tB += tb_inc_row; + + for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tA[i]; + } + ymm3 = _mm256_loadu_ps(f_temp); + ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7); + + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + ymm5 = _mm256_mul_ps(ymm5, ymm0); + ymm7 = _mm256_mul_ps(ymm7, ymm0); + + /*for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_ps(f_temp); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ + _mm256_storeu_ps(f_temp, ymm5); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + + tC += ldc; + /*for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_ps(f_temp); + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/ + _mm256_storeu_ps(f_temp, ymm7); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + ymm5 = _mm256_setzero_ps(); + + for (k = 0; k < (K - 1); ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + tB += tb_inc_row; + + ymm3 = _mm256_loadu_ps(tA); + ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); + + tA += lda; + } + + ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); + tB += tb_inc_row; + + for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tA[i]; + } + ymm3 = _mm256_loadu_ps(f_temp); + ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); + + ymm0 = _mm256_broadcast_ss(alpha_cast); + //ymm1 = _mm256_broadcast_ss(beta_cast); + + // multiply C by beta and accumulate. + ymm5 = _mm256_mul_ps(ymm5, ymm0); + + /*for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_ps(f_temp); + ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ + _mm256_storeu_ps(f_temp, ymm5); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + } + m_remainder = 0; + } + + if (m_remainder) + { + float result; + for (; row_idx < M; row_idx += 1) + { + for (col_idx = 0; col_idx < N; col_idx += 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + result = 0; + for (k = 0; k < K; ++k) + { + result += (*tA) * (*tB); + tA += lda; + tB += tb_inc_row; + } + + result *= (*alpha_cast); + (*tC) = /*(*tC) * (*beta_cast) + */result; + } + } + } + + //copy/compute sryk values back to C using SIMD + if ( bli_seq0( *beta_cast ) ) + {//just copy in case of beta = 0 + dim_t _i, _j, k, _l; + if(bli_obj_is_lower(c)) // c is lower + { + //first column + _j = 0; + k = M >> 3; + _i = 0; + for ( _l = 0; _l < k; _l++ ) + { + ymm0 = _mm256_loadu_ps((C + _i*rsc)); + _mm256_storeu_ps((matCbuf + _i*rs_matC), ymm0); + _i += 8; + } + while (_i < M ) + { + bli_sscopys( *(C + _i*rsc + _j*ldc), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + _i++; + } + _j++; + while ( _j < N ) //next column + { + //k = (_j + (8 - (_j & 7))); + _l = _j & 7; + k = (_l != 0) ? (_j + (8 - _l)) : _j; + k = (k <= M) ? k : M; + for ( _i = _j; _i < k; ++_i ) + { + bli_sscopys( *(C + _i*rsc + _j*ldc), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + } + k = (M - _i) >> 3; + _l = 0; + while ( _l < k ) + { + ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc)); + _mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); + + _i += 8; + _l++; + } + while (_i < M ) + { + bli_sscopys( *(C + _i*rsc + _j*ldc), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + _i++; + } + _j++; + } + } + else //c is upper + { + for ( _j = 0; _j < N; ++_j ) + { + k = (_j + 1) >> 3; + _i = 0; + _l = 0; + while ( _l < k ) + { + ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc)); + _mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); + _i += 8; + _l++; + } + while (_i <= _j ) + { + bli_sscopys( *(C + _i*rsc + _j*ldc), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + ++_i; + } + } + } + } + else + {//when beta is non-zero, fmadd and store the results + dim_t _i, _j, k, _l; + ymm1 = _mm256_broadcast_ss(beta_cast); + if(bli_obj_is_lower(c)) //c is lower + { + //first column + _j = 0; + k = M >> 3; + _i = 0; + for ( _l = 0; _l < k; _l++ ) + { + ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC)); + ymm0 = _mm256_loadu_ps((C + _i*rsc)); + ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0); + _mm256_storeu_ps((matCbuf + _i*rs_matC), ymm0); + _i += 8; + } + while (_i < M ) + { + bli_sssxpbys( *(C + _i*rsc + _j*ldc), + *(beta_cast), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + _i++; + } + _j++; + while ( _j < N ) //next column + { + //k = (_j + (8 - (_j & 7))); + _l = _j & 7; + k = (_l != 0) ? (_j + (8 - _l)) : _j; + k = (k <= M) ? k : M; + for ( _i = _j; _i < k; ++_i ) + { + bli_sssxpbys( *(C + _i*rsc + _j*ldc), + *(beta_cast), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + } + k = (M - _i) >> 3; + _l = 0; + while ( _l < k ) + { + ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC + _j*ldc_matC)); + ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc)); + ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0); + _mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); + + _i += 8; + _l++; + } + while (_i < M ) + { + bli_sssxpbys( *(C + _i*rsc + _j*ldc), + *(beta_cast), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + _i++; + } + _j++; + } + } + else //c is upper + { + for ( _j = 0; _j < N; ++_j ) + { + k = (_j + 1) >> 3; + _i = 0; + _l = 0; + while ( _l < k ) + { + ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC + _j*ldc_matC)); + ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc)); + ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0); + _mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); + _i += 8; + _l++; + } + while (_i <= _j ) + { + bli_sssxpbys( *(C + _i*rsc + _j*ldc), + *(beta_cast), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + ++_i; + } + } + } + } + + return BLIS_SUCCESS; + } + else + return BLIS_NONCONFORMAL_DIMENSIONS; + + +}; + +static err_t bli_dgemmt_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ) +{ + + int M = bli_obj_length( c ); // number of rows of Matrix C + int N = bli_obj_width( c ); // number of columns of Matrix C + int K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . + int L = M * N; + + // If alpha is zero, scale by beta and return. + if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES)) + || ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0))) + { + + int lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. + int ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. + int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C + int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C + int row_idx, col_idx, k; + int rs_matC = bli_obj_row_stride( c ); + int rsc = 1; + double *A = a->buffer; // pointer to elements of Matrix A + double *B = b->buffer; // pointer to elements of Matrix B + double *C = D_C_pack; // pointer to elements of Matrix C + double *matCbuf = c->buffer; + + double *tA = A, *tB = B, *tC = C;//, *tA_pack; + double *tA_packed; // temprorary pointer to hold packed A memory pointer + int row_idx_packed; //packed A memory row index + int lda_packed; //lda of packed A + int col_idx_start; //starting index after A matrix is packed. + dim_t tb_inc_row = 1; // row stride of matrix B + dim_t tb_inc_col = ldb; // column stride of matrix B + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm0, ymm1, ymm2, ymm3; + + int n_remainder; // If the N is non multiple of 3.(N%3) + int m_remainder; // If the M is non multiple of 16.(M%16) + + double *alpha_cast, *beta_cast; // alpha, beta multiples + alpha_cast = (alpha->buffer); + beta_cast = (beta->buffer); + int required_packing_A = 1; + + // when N is equal to 1 call GEMV instead of GEMMT + if (N == 1) + { + bli_gemv + ( + alpha, + a, + b, + beta, + c + ); + return BLIS_SUCCESS; + } + + //update the pointer math if matrix B needs to be transposed. + if (bli_obj_has_trans( b )) + { + tb_inc_col = 1; //switch row and column strides + tb_inc_row = ldb; + } + + if ((N <= 3) || ((D_MR * K) > D_SCRATCH_DIM)) + { + required_packing_A = 0; + } + /* + * The computation loop runs for D_MRxN columns of C matrix, thus + * accessing the D_MRxK A matrix data and KxNR B matrix data. + * The computation is organized as inner loops of dimension D_MRxNR. + */ + // Process D_MR rows of C matrix at a time. + for (row_idx = 0; (row_idx + (D_MR - 1)) < M; row_idx += D_MR) + { + + col_idx_start = 0; + tA_packed = A; + row_idx_packed = row_idx; + lda_packed = lda; + + // This is the part of the pack and compute optimization. + // During the first column iteration, we store the accessed A matrix into + // contiguous static memory. This helps to keep te A matrix in Cache and + // aviods the TLB misses. + if (required_packing_A) + { + col_idx = 0; + + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + tA_packed = D_A_pack; + +#if 0//def BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + _mm256_storeu_pd(tA_packed, ymm3); // the packing of matrix A + // ymm4 += ymm0 * ymm3; + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + // ymm8 += ymm1 * ymm3; + ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); + // ymm12 += ymm2 * ymm3; + ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + _mm256_storeu_pd(tA_packed + 4, ymm3); // the packing of matrix A + // ymm5 += ymm0 * ymm3; + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + // ymm9 += ymm1 * ymm3; + ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); + // ymm13 += ymm2 * ymm3; + ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + _mm256_storeu_pd(tA_packed + 8, ymm3); // the packing of matrix A + // ymm6 += ymm0 * ymm3; + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + // ymm10 += ymm1 * ymm3; + ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10); + // ymm14 += ymm2 * ymm3; + ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14); + + ymm3 = _mm256_loadu_pd(tA + 12); + _mm256_storeu_pd(tA_packed + 12, ymm3); // the packing of matrix A + // ymm7 += ymm0 * ymm3; + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + // ymm11 += ymm1 * ymm3; + ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11); + // ymm15 += ymm2 * ymm3; + ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15); + + tA += lda; + tA_packed += D_MR; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + ymm7 = _mm256_mul_pd(ymm7, ymm0); + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm11 = _mm256_mul_pd(ymm11, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + // multiply C by beta and accumulate col 1. + /*ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/ + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + _mm256_storeu_pd(tC + 8, ymm6); + _mm256_storeu_pd(tC + 12, ymm7); + + // multiply C by beta and accumulate, col 2. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/ + _mm256_storeu_pd(tC, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + _mm256_storeu_pd(tC + 8, ymm10); + _mm256_storeu_pd(tC + 12, ymm11); + + // multiply C by beta and accumulate, col 3. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/ + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + _mm256_storeu_pd(tC + 12, ymm15); + + // modify the pointer arithematic to use packed A matrix. + col_idx_start = NR; + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = D_MR; + } + // Process NR columns of C matrix at a time. + for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + +#if 0//def BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + // ymm4 += ymm0 * ymm3; + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + // ymm8 += ymm1 * ymm3; + ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); + // ymm12 += ymm2 * ymm3; + ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + // ymm5 += ymm0 * ymm3; + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + // ymm9 += ymm1 * ymm3; + ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); + // ymm13 += ymm2 * ymm3; + ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + // ymm6 += ymm0 * ymm3; + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + // ymm10 += ymm1 * ymm3; + ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10); + // ymm14 += ymm2 * ymm3; + ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14); + + ymm3 = _mm256_loadu_pd(tA + 12); + // ymm7 += ymm0 * ymm3; + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + // ymm11 += ymm1 * ymm3; + ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11); + // ymm15 += ymm2 * ymm3; + ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15); + + tA += lda_packed; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + ymm7 = _mm256_mul_pd(ymm7, ymm0); + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm11 = _mm256_mul_pd(ymm11, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + // multiply C by beta and accumulate col 1. + /*ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/ + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + _mm256_storeu_pd(tC + 8, ymm6); + _mm256_storeu_pd(tC + 12, ymm7); + + // multiply C by beta and accumulate, col 2. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/ + _mm256_storeu_pd(tC, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + _mm256_storeu_pd(tC + 8, ymm10); + _mm256_storeu_pd(tC + 12, ymm11); + + // multiply C by beta and accumulate, col 3. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/ + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + _mm256_storeu_pd(tC + 12, ymm15); + + } + n_remainder = N - col_idx; + + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8); + ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9); + ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); + ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); + + ymm3 = _mm256_loadu_pd(tA + 12); + ymm11 = _mm256_fmadd_pd(ymm0, ymm3, ymm11); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tA += lda; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm11 = _mm256_mul_pd(ymm11, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + // multiply C by beta and accumulate, col 1. + /*ymm2 = _mm256_loadu_pd(tC + 0); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/ + _mm256_storeu_pd(tC + 0, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + _mm256_storeu_pd(tC + 8, ymm10); + _mm256_storeu_pd(tC + 12, ymm11); + + // multiply C by beta and accumulate, col 2. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/ + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + _mm256_storeu_pd(tC + 12, ymm15); + + col_idx += 2; + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + ymm3 = _mm256_loadu_pd(tA + 12); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + tA += lda; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_pd(tC + 0); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/ + + _mm256_storeu_pd(tC + 0, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + _mm256_storeu_pd(tC + 12, ymm15); + } + } + + m_remainder = M - row_idx; + + if (m_remainder >= 12) + { + m_remainder -= 12; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + // ymm4 += ymm0 * ymm3; + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + // ymm8 += ymm1 * ymm3; + ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); + // ymm12 += ymm2 * ymm3; + ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + // ymm5 += ymm0 * ymm3; + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + // ymm9 += ymm1 * ymm3; + ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); + // ymm13 += ymm2 * ymm3; + ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + // ymm6 += ymm0 * ymm3; + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + // ymm10 += ymm1 * ymm3; + ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10); + // ymm14 += ymm2 * ymm3; + ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);*/ + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + _mm256_storeu_pd(tC + 8, ymm6); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);*/ + _mm256_storeu_pd(tC, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + _mm256_storeu_pd(tC + 8, ymm10); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/ + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + + } + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8); + ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9); + ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); + ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); + + tA += lda; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_pd(tC + 0); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);*/ + _mm256_storeu_pd(tC + 0, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + _mm256_storeu_pd(tC + 8, ymm10); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/ + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + + col_idx += 2; + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tA += lda; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_pd(tC + 0); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/ + + _mm256_storeu_pd(tC + 0, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + } + + row_idx += 12; + } + + if (m_remainder >= 8) + { + m_remainder -= 8; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6); + ymm8 = _mm256_fmadd_pd(ymm2, ymm3, ymm8); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + ymm7 = _mm256_mul_pd(ymm7, ymm0); + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/ + _mm256_storeu_pd(tC, ymm6); + _mm256_storeu_pd(tC + 4, ymm7); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);*/ + _mm256_storeu_pd(tC, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + + } + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + ymm7 = _mm256_mul_pd(ymm7, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/ + _mm256_storeu_pd(tC, ymm6); + _mm256_storeu_pd(tC + 4, ymm7); + + col_idx += 2; + + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + + } + + row_idx += 8; + } + + if (m_remainder >= 4) + { + m_remainder -= 4; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/ + _mm256_storeu_pd(tC, ymm4); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ + _mm256_storeu_pd(tC, ymm5); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);*/ + _mm256_storeu_pd(tC, ymm6); + } + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/ + _mm256_storeu_pd(tC, ymm4); + + // multiply C by beta and accumulate. + tC += ldc; + /*ymm2 = _mm256_loadu_pd(tC); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ + _mm256_storeu_pd(tC, ymm5); + + col_idx += 2; + + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + ymm4 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + ymm4 = _mm256_mul_pd(ymm4, ymm0); + + // multiply C by beta and accumulate. + /*ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/ + _mm256_storeu_pd(tC, ymm4); + + } + + row_idx += 4; + } + // M is not a multiple of 32. + // The handling of edge case where the remainder + // dimension is less than 8. The padding takes place + // to handle this case. + if ((m_remainder) && (lda > 3)) + { + double f_temp[8]; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + + for (k = 0; k < (K - 1); ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; + + for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tA[i]; + } + ymm3 = _mm256_loadu_pd(f_temp); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); + + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm7 = _mm256_mul_pd(ymm7, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + + + /*for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_pd(f_temp); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ + _mm256_storeu_pd(f_temp, ymm5); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + + tC += ldc; + /*for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_pd(f_temp); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/ + _mm256_storeu_pd(f_temp, ymm7); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + + tC += ldc; + /*for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_pd(f_temp); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);*/ + _mm256_storeu_pd(f_temp, ymm9); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + } + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + for (k = 0; k < (K - 1); ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; + + ymm3 = _mm256_loadu_pd(tA); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tA += lda; + } + + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; + + for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tA[i]; + } + ymm3 = _mm256_loadu_pd(f_temp); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm7 = _mm256_mul_pd(ymm7, ymm0); + + /*for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_pd(f_temp); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ + _mm256_storeu_pd(f_temp, ymm5); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + + tC += ldc; + /*for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_pd(f_temp); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/ + _mm256_storeu_pd(f_temp, ymm7); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + ymm5 = _mm256_setzero_pd(); + + for (k = 0; k < (K - 1); ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; + + ymm3 = _mm256_loadu_pd(tA); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + tA += lda; + } + + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; + + for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tA[i]; + } + ymm3 = _mm256_loadu_pd(f_temp); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_broadcast_sd(alpha_cast); + //ymm1 = _mm256_broadcast_sd(beta_cast); + + // multiply C by beta and accumulate. + ymm5 = _mm256_mul_pd(ymm5, ymm0); + + /*for (int i = 0; i < m_remainder; i++) + { + f_temp[i] = tC[i]; + } + ymm2 = _mm256_loadu_pd(f_temp); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ + _mm256_storeu_pd(f_temp, ymm5); + for (int i = 0; i < m_remainder; i++) + { + tC[i] = f_temp[i]; + } + } + m_remainder = 0; + } + + if (m_remainder) + { + double result; + for (; row_idx < M; row_idx += 1) + { + for (col_idx = 0; col_idx < N; col_idx += 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + result = 0; + for (k = 0; k < K; ++k) + { + result += (*tA) * (*tB); + tA += lda; + tB += tb_inc_row; + } + + result *= (*alpha_cast); + (*tC) = /*(*tC) * (*beta_cast) + */result; + } + } + } + + //copy/compute sryk values back to C using SIMD + if ( bli_seq0( *beta_cast ) ) + {//just copy for beta = 0 + dim_t _i, _j, k, _l; + if(bli_obj_is_lower(c)) //c is lower + { + //first column + _j = 0; + k = M >> 2; + _i = 0; + for ( _l = 0; _l < k; _l++ ) + { + ymm0 = _mm256_loadu_pd((C + _i*rsc)); + _mm256_storeu_pd((matCbuf + _i*rs_matC), ymm0); + _i += 4; + } + while (_i < M ) + { + bli_ddcopys( *(C + _i*rsc + _j*ldc), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + _i++; + } + _j++; + while ( _j < N ) //next column + { + //k = (_j + (4 - (_j & 3))); + _l = _j & 3; + k = (_l != 0) ? (_j + (4 - _l)) : _j; + k = (k <= M) ? k : M; + for ( _i = _j; _i < k; ++_i ) + { + bli_ddcopys( *(C + _i*rsc + _j*ldc), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + } + k = (M - _i) >> 2; + _l = 0; + while ( _l < k ) + { + ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc)); + _mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); + + _i += 4; + _l++; + } + while (_i < M ) + { + bli_ddcopys( *(C + _i*rsc + _j*ldc), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + _i++; + } + _j++; + } + } + else //c is upper + { + for ( _j = 0; _j < N; ++_j ) + { + k = (_j + 1) >> 2; + _i = 0; + _l = 0; + while ( _l < k ) + { + ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc)); + _mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); + _i += 4; + _l++; + } + while (_i <= _j ) + { + bli_ddcopys( *(C + _i*rsc + _j*ldc), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + ++_i; + } + } + } + } + else + {//when beta is non-zero, fmadd and store the results + dim_t _i, _j, k, _l; + ymm1 = _mm256_broadcast_sd(beta_cast); + if(bli_obj_is_lower(c)) //c is lower + { + //first column + _j = 0; + k = M >> 2; + _i = 0; + for ( _l = 0; _l < k; _l++ ) + { + ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC)); + ymm0 = _mm256_loadu_pd((C + _i*rsc)); + ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0); + _mm256_storeu_pd((matCbuf + _i*rs_matC), ymm0); + _i += 4; + } + while (_i < M ) + { + bli_dddxpbys( *(C + _i*rsc + _j*ldc), + *(beta_cast), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + _i++; + } + _j++; + while ( _j < N ) //next column + { + //k = (_j + (4 - (_j & 3))); + _l = _j & 3; + k = (_l != 0) ? (_j + (4 - _l)) : _j; + k = (k <= M) ? k : M; + for ( _i = _j; _i < k; ++_i ) + { + bli_dddxpbys( *(C + _i*rsc + _j*ldc), + *(beta_cast), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + } + k = (M - _i) >> 2; + _l = 0; + while ( _l < k ) + { + ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC + _j*ldc_matC)); + ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc)); + ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0); + _mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); + + _i += 4; + _l++; + } + while (_i < M ) + { + bli_dddxpbys( *(C + _i*rsc + _j*ldc), + *(beta_cast), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + _i++; + } + _j++; + } + } + else //c is upper + { + for ( _j = 0; _j < N; ++_j ) + { + k = (_j + 1) >> 2; + _i = 0; + _l = 0; + while ( _l < k ) + { + ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC + _j*ldc_matC)); + ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc)); + ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0); + _mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); + _i += 4; + _l++; + } + while (_i <= _j ) + { + bli_dddxpbys( *(C + _i*rsc + _j*ldc), + *(beta_cast), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + ++_i; + } + } + } + } + + return BLIS_SUCCESS; + } + else + return BLIS_NONCONFORMAL_DIMENSIONS; + + +}; + +static err_t bli_sgemmt_small_atbn + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ) +{ + int M = bli_obj_length(c); // number of rows of Matrix C + int N = bli_obj_width(c); // number of columns of Matrix C + int K = bli_obj_length(b); // number of rows of Matrix B + int lda = bli_obj_col_stride(a); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. + int ldb = bli_obj_col_stride(b); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. + int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C + int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C + int row_idx = 0, col_idx = 0, k; + int rs_matC = bli_obj_row_stride( c ); + int rsc = 1; + float *A = a->buffer; // pointer to matrix A elements, stored in row major format + float *B = b->buffer; // pointer to matrix B elements, stored in column major format + float *C = C_pack; // pointer to matrix C elements, stored in column major format + float *matCbuf = c->buffer; + + float *tA = A, *tB = B, *tC = C; + + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m256 ymm12, ymm13, ymm14, ymm15; + __m256 ymm0, ymm1, ymm2, ymm3; + + float result, scratch[8]; + float *alpha_cast, *beta_cast; // alpha, beta multiples + alpha_cast = (alpha->buffer); + beta_cast = (beta->buffer); + + // The non-copy version of the A^T GEMMT gives better performance for the small M cases. + // The threshold is controlled by BLIS_ATBN_M_THRES + if (M <= BLIS_ATBN_M_THRES) + { + for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR) + { + for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR) + { + tA = A + row_idx * lda; + tB = B + col_idx * ldb; + tC = C + col_idx * ldc + row_idx; + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); + ymm11 = _mm256_setzero_ps(); + ymm12 = _mm256_setzero_ps(); + ymm13 = _mm256_setzero_ps(); + ymm14 = _mm256_setzero_ps(); + ymm15 = _mm256_setzero_ps(); + + //The inner loop computes the 4x3 values of the matrix. + //The computation pattern is: + // ymm4 ymm5 ymm6 + // ymm7 ymm8 ymm9 + // ymm10 ymm11 ymm12 + // ymm13 ymm14 ymm15 + + //The Dot operation is performed in the inner loop, 8 float elements fit + //in the YMM register hence loop count incremented by 8 + for (k = 0; (k + 7) < K; k += 8) + { + ymm0 = _mm256_loadu_ps(tB + 0); + ymm1 = _mm256_loadu_ps(tB + ldb); + ymm2 = _mm256_loadu_ps(tB + 2 * ldb); + + ymm3 = _mm256_loadu_ps(tA); + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5); + ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6); + + ymm3 = _mm256_loadu_ps(tA + lda); + ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7); + ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9); + + ymm3 = _mm256_loadu_ps(tA + 2 * lda); + ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10); + ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11); + ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12); + + ymm3 = _mm256_loadu_ps(tA + 3 * lda); + ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13); + ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14); + ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15); + + tA += 8; + tB += 8; + + } + + // if K is not a multiple of 8, padding is done before load using temproary array. + if (k < K) + { + int iter; + float data_feeder[8] = { 0.0 }; + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter]; + ymm0 = _mm256_loadu_ps(data_feeder); + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb]; + ymm1 = _mm256_loadu_ps(data_feeder); + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb]; + ymm2 = _mm256_loadu_ps(data_feeder); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter]; + ymm3 = _mm256_loadu_ps(data_feeder); + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5); + ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter]; + ymm3 = _mm256_loadu_ps(data_feeder); + ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7); + ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8); + ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter]; + ymm3 = _mm256_loadu_ps(data_feeder); + ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10); + ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11); + ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter]; + ymm3 = _mm256_loadu_ps(data_feeder); + ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13); + ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14); + ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15); + + } + + //horizontal addition and storage of the data. + //Results for 4x3 blocks of C is stored here + ymm4 = _mm256_hadd_ps(ymm4, ymm4); + ymm4 = _mm256_hadd_ps(ymm4, ymm4); + _mm256_storeu_ps(scratch, ymm4); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[0] = result/* + tC[0] * (*beta_cast)*/; + + ymm7 = _mm256_hadd_ps(ymm7, ymm7); + ymm7 = _mm256_hadd_ps(ymm7, ymm7); + _mm256_storeu_ps(scratch, ymm7); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[1] = result/* + tC[1] * (*beta_cast)*/; + + ymm10 = _mm256_hadd_ps(ymm10, ymm10); + ymm10 = _mm256_hadd_ps(ymm10, ymm10); + _mm256_storeu_ps(scratch, ymm10); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[2] = result/* + tC[2] * (*beta_cast)*/; + + ymm13 = _mm256_hadd_ps(ymm13, ymm13); + ymm13 = _mm256_hadd_ps(ymm13, ymm13); + _mm256_storeu_ps(scratch, ymm13); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[3] = result/* + tC[3] * (*beta_cast)*/; + + tC += ldc; + ymm5 = _mm256_hadd_ps(ymm5, ymm5); + ymm5 = _mm256_hadd_ps(ymm5, ymm5); + _mm256_storeu_ps(scratch, ymm5); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[0] = result/* + tC[0] * (*beta_cast)*/; + + ymm8 = _mm256_hadd_ps(ymm8, ymm8); + ymm8 = _mm256_hadd_ps(ymm8, ymm8); + _mm256_storeu_ps(scratch, ymm8); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[1] = result/* + tC[1] * (*beta_cast)*/; + + ymm11 = _mm256_hadd_ps(ymm11, ymm11); + ymm11 = _mm256_hadd_ps(ymm11, ymm11); + _mm256_storeu_ps(scratch, ymm11); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[2] = result/* + tC[2] * (*beta_cast)*/; + + ymm14 = _mm256_hadd_ps(ymm14, ymm14); + ymm14 = _mm256_hadd_ps(ymm14, ymm14); + _mm256_storeu_ps(scratch, ymm14); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[3] = result/* + tC[3] * (*beta_cast)*/; + + tC += ldc; + ymm6 = _mm256_hadd_ps(ymm6, ymm6); + ymm6 = _mm256_hadd_ps(ymm6, ymm6); + _mm256_storeu_ps(scratch, ymm6); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[0] = result/* + tC[0] * (*beta_cast)*/; + + ymm9 = _mm256_hadd_ps(ymm9, ymm9); + ymm9 = _mm256_hadd_ps(ymm9, ymm9); + _mm256_storeu_ps(scratch, ymm9); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[1] = result/* + tC[1] * (*beta_cast)*/; + + ymm12 = _mm256_hadd_ps(ymm12, ymm12); + ymm12 = _mm256_hadd_ps(ymm12, ymm12); + _mm256_storeu_ps(scratch, ymm12); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[2] = result/* + tC[2] * (*beta_cast)*/; + + ymm15 = _mm256_hadd_ps(ymm15, ymm15); + ymm15 = _mm256_hadd_ps(ymm15, ymm15); + _mm256_storeu_ps(scratch, ymm15); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[3] = result/* + tC[3] * (*beta_cast)*/; + } + } + + int processed_col = col_idx; + int processed_row = row_idx; + + //The edge case handling where N is not a multiple of 3 + if (processed_col < N) + { + for (col_idx = processed_col; col_idx < N; col_idx += 1) + { + for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR) + { + tA = A + row_idx * lda; + tB = B + col_idx * ldb; + tC = C + col_idx * ldc + row_idx; + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); + ymm13 = _mm256_setzero_ps(); + + //The inner loop computes the 4x1 values of the matrix. + //The computation pattern is: + // ymm4 + // ymm7 + // ymm10 + // ymm13 + + for (k = 0; (k + 7) < K; k += 8) + { + ymm0 = _mm256_loadu_ps(tB + 0); + + ymm3 = _mm256_loadu_ps(tA); + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + + ymm3 = _mm256_loadu_ps(tA + lda); + ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7); + + ymm3 = _mm256_loadu_ps(tA + 2 * lda); + ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10); + + ymm3 = _mm256_loadu_ps(tA + 3 * lda); + ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13); + + tA += 8; + tB += 8; + } + + // if K is not a multiple of 8, padding is done before load using temproary array. + if (k < K) + { + int iter; + float data_feeder[8] = { 0.0 }; + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter]; + ymm0 = _mm256_loadu_ps(data_feeder); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter]; + ymm3 = _mm256_loadu_ps(data_feeder); + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter]; + ymm3 = _mm256_loadu_ps(data_feeder); + ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter]; + ymm3 = _mm256_loadu_ps(data_feeder); + ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter]; + ymm3 = _mm256_loadu_ps(data_feeder); + ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13); + + } + + //horizontal addition and storage of the data. + //Results for 4x1 blocks of C is stored here + ymm4 = _mm256_hadd_ps(ymm4, ymm4); + ymm4 = _mm256_hadd_ps(ymm4, ymm4); + _mm256_storeu_ps(scratch, ymm4); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[0] = result/* + tC[0] * (*beta_cast)*/; + + ymm7 = _mm256_hadd_ps(ymm7, ymm7); + ymm7 = _mm256_hadd_ps(ymm7, ymm7); + _mm256_storeu_ps(scratch, ymm7); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[1] = result/* + tC[1] * (*beta_cast)*/; + + ymm10 = _mm256_hadd_ps(ymm10, ymm10); + ymm10 = _mm256_hadd_ps(ymm10, ymm10); + _mm256_storeu_ps(scratch, ymm10); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[2] = result/* + tC[2] * (*beta_cast)*/; + + ymm13 = _mm256_hadd_ps(ymm13, ymm13); + ymm13 = _mm256_hadd_ps(ymm13, ymm13); + _mm256_storeu_ps(scratch, ymm13); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[3] = result/* + tC[3] * (*beta_cast)*/; + + } + } + processed_row = row_idx; + } + + //The edge case handling where M is not a multiple of 4 + if (processed_row < M) + { + for (row_idx = processed_row; row_idx < M; row_idx += 1) + { + for (col_idx = 0; col_idx < N; col_idx += 1) + { + tA = A + row_idx * lda; + tB = B + col_idx * ldb; + tC = C + col_idx * ldc + row_idx; + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + + for (k = 0; (k + 7) < K; k += 8) + { + ymm0 = _mm256_loadu_ps(tB + 0); + ymm3 = _mm256_loadu_ps(tA); + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + + tA += 8; + tB += 8; + } + + // if K is not a multiple of 8, padding is done before load using temproary array. + if (k < K) + { + int iter; + float data_feeder[8] = { 0.0 }; + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter]; + ymm0 = _mm256_loadu_ps(data_feeder); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter]; + ymm3 = _mm256_loadu_ps(data_feeder); + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); + + } + + //horizontal addition and storage of the data. + ymm4 = _mm256_hadd_ps(ymm4, ymm4); + ymm4 = _mm256_hadd_ps(ymm4, ymm4); + _mm256_storeu_ps(scratch, ymm4); + result = scratch[0] + scratch[4]; + result *= (*alpha_cast); + tC[0] = result/* + tC[0] * (*beta_cast)*/; + + } + } + } + + //copy/compute sryk values back to C + if ( bli_seq0( *beta_cast ) ) //when beta is 0, just copy result to C + { + dim_t _i, _j; + if(bli_obj_is_lower(c)) //c is lower + { + for ( _j = 0; _j < N; ++_j ) + for ( _i = 0; _i < M; ++_i ) + if ( (doff_t)_j - (doff_t)_i <= 0 ) + { + bli_sscopys( *(C + _i*rsc + _j*ldc), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + } + } + else //c is upper + { + for ( _j = 0; _j < N; ++_j ) + for ( _i = 0; _i < M; ++_i ) + if ( (doff_t)_j - (doff_t)_i >= 0 ) + { + bli_sscopys( *(C + _i*rsc + _j*ldc), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + } + } + } + else //when beta is non-zero, multiply and store result to C + { + dim_t _i, _j; + if(bli_obj_is_lower(c)) //c is lower + { + for ( _j = 0; _j < N; ++_j ) + for ( _i = 0; _i < M; ++_i ) + if ( (doff_t)_j - (doff_t)_i <= 0 ) + { + bli_sssxpbys( *(C + _i*rsc + _j*ldc), + *(beta_cast), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + } + } + else //c is upper + { + for ( _j = 0; _j < N; ++_j ) + for ( _i = 0; _i < M; ++_i ) + if ( (doff_t)_j - (doff_t)_i >= 0 ) + { + bli_sssxpbys( *(C + _i*rsc + _j*ldc), + *(beta_cast), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + } + } + } + + return BLIS_SUCCESS; + } + else + return BLIS_NONCONFORMAL_DIMENSIONS; +} + +static err_t bli_dgemmt_small_atbn + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ) +{ + int M = bli_obj_length( c ); // number of rows of Matrix C + int N = bli_obj_width( c ); // number of columns of Matrix C + int K = bli_obj_length( b ); // number of rows of Matrix B + int lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. + int ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. + int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C + int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C + int row_idx = 0, col_idx = 0, k; + int rs_matC = bli_obj_row_stride( c ); + int rsc = 1; + double *A = a->buffer; // pointer to matrix A elements, stored in row major format + double *B = b->buffer; // pointer to matrix B elements, stored in column major format + double *C = D_C_pack; // pointer to matrix C elements, stored in column major format + double *matCbuf = c->buffer; + + double *tA = A, *tB = B, *tC = C; + + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm0, ymm1, ymm2, ymm3; + + double result, scratch[8]; + double *alpha_cast, *beta_cast; // alpha, beta multiples + alpha_cast = (alpha->buffer); + beta_cast = (beta->buffer); + + // The non-copy version of the A^T GEMMT gives better performance for the small M cases. + // The threshold is controlled by BLIS_ATBN_M_THRES + if (M <= BLIS_ATBN_M_THRES) + { + for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR) + { + for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR) + { + tA = A + row_idx * lda; + tB = B + col_idx * ldb; + tC = C + col_idx * ldc + row_idx; + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + //The inner loop computes the 4x3 values of the matrix. + //The computation pattern is: + // ymm4 ymm5 ymm6 + // ymm7 ymm8 ymm9 + // ymm10 ymm11 ymm12 + // ymm13 ymm14 ymm15 + + //The Dot operation is performed in the inner loop, 4 double elements fit + //in the YMM register hence loop count incremented by 4 + for (k = 0; (k + 3) < K; k += 4) + { + ymm0 = _mm256_loadu_pd(tB + 0); + ymm1 = _mm256_loadu_pd(tB + ldb); + ymm2 = _mm256_loadu_pd(tB + 2 * ldb); + + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6); + + ymm3 = _mm256_loadu_pd(tA + lda); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); + ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); + + ymm3 = _mm256_loadu_pd(tA + 2 * lda); + ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); + ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11); + ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 3 * lda); + ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); + ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15); + + tA += 4; + tB += 4; + + } + + // if K is not a multiple of 4, padding is done before load using temproary array. + if (k < K) + { + int iter; + double data_feeder[4] = { 0.0 }; + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter]; + ymm0 = _mm256_loadu_pd(data_feeder); + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb]; + ymm1 = _mm256_loadu_pd(data_feeder); + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb]; + ymm2 = _mm256_loadu_pd(data_feeder); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter]; + ymm3 = _mm256_loadu_pd(data_feeder); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter]; + ymm3 = _mm256_loadu_pd(data_feeder); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); + ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter]; + ymm3 = _mm256_loadu_pd(data_feeder); + ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); + ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11); + ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter]; + ymm3 = _mm256_loadu_pd(data_feeder); + ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); + ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15); + + } + + //horizontal addition and storage of the data. + //Results for 4x3 blocks of C is stored here + ymm4 = _mm256_hadd_pd(ymm4, ymm4); + _mm256_storeu_pd(scratch, ymm4); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[0] = result/* + tC[0] * (*beta_cast)*/; + + ymm7 = _mm256_hadd_pd(ymm7, ymm7); + _mm256_storeu_pd(scratch, ymm7); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[1] = result/* + tC[1] * (*beta_cast)*/; + + ymm10 = _mm256_hadd_pd(ymm10, ymm10); + _mm256_storeu_pd(scratch, ymm10); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[2] = result/* + tC[2] * (*beta_cast)*/; + + ymm13 = _mm256_hadd_pd(ymm13, ymm13); + _mm256_storeu_pd(scratch, ymm13); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[3] = result/* + tC[3] * (*beta_cast)*/; + + + tC += ldc; + ymm5 = _mm256_hadd_pd(ymm5, ymm5); + _mm256_storeu_pd(scratch, ymm5); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[0] = result/* + tC[0] * (*beta_cast)*/; + + ymm8 = _mm256_hadd_pd(ymm8, ymm8); + _mm256_storeu_pd(scratch, ymm8); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[1] = result/* + tC[1] * (*beta_cast)*/; + + ymm11 = _mm256_hadd_pd(ymm11, ymm11); + _mm256_storeu_pd(scratch, ymm11); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[2] = result/* + tC[2] * (*beta_cast)*/; + + ymm14 = _mm256_hadd_pd(ymm14, ymm14); + _mm256_storeu_pd(scratch, ymm14); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[3] = result/* + tC[3] * (*beta_cast)*/; + + + tC += ldc; + ymm6 = _mm256_hadd_pd(ymm6, ymm6); + _mm256_storeu_pd(scratch, ymm6); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[0] = result/* + tC[0] * (*beta_cast)*/; + + ymm9 = _mm256_hadd_pd(ymm9, ymm9); + _mm256_storeu_pd(scratch, ymm9); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[1] = result/* + tC[1] * (*beta_cast)*/; + + ymm12 = _mm256_hadd_pd(ymm12, ymm12); + _mm256_storeu_pd(scratch, ymm12); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[2] = result/* + tC[2] * (*beta_cast)*/; + + ymm15 = _mm256_hadd_pd(ymm15, ymm15); + _mm256_storeu_pd(scratch, ymm15); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[3] = result/* + tC[3] * (*beta_cast)*/; + } + } + + int processed_col = col_idx; + int processed_row = row_idx; + + //The edge case handling where N is not a multiple of 3 + if (processed_col < N) + { + for (col_idx = processed_col; col_idx < N; col_idx += 1) + { + for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR) + { + tA = A + row_idx * lda; + tB = B + col_idx * ldb; + tC = C + col_idx * ldc + row_idx; + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + + //The inner loop computes the 4x1 values of the matrix. + //The computation pattern is: + // ymm4 + // ymm7 + // ymm10 + // ymm13 + + for (k = 0; (k + 3) < K; k += 4) + { + ymm0 = _mm256_loadu_pd(tB + 0); + + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm3 = _mm256_loadu_pd(tA + lda); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm3 = _mm256_loadu_pd(tA + 2 * lda); + ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); + + ymm3 = _mm256_loadu_pd(tA + 3 * lda); + ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); + + tA += 4; + tB += 4; + } + // if K is not a multiple of 4, padding is done before load using temproary array. + if (k < K) + { + int iter; + double data_feeder[4] = { 0.0 }; + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter]; + ymm0 = _mm256_loadu_pd(data_feeder); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter]; + ymm3 = _mm256_loadu_pd(data_feeder); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter]; + ymm3 = _mm256_loadu_pd(data_feeder); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter]; + ymm3 = _mm256_loadu_pd(data_feeder); + ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter]; + ymm3 = _mm256_loadu_pd(data_feeder); + ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); + + } + + //horizontal addition and storage of the data. + //Results for 4x1 blocks of C is stored here + ymm4 = _mm256_hadd_pd(ymm4, ymm4); + _mm256_storeu_pd(scratch, ymm4); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[0] = result/* + tC[0] * (*beta_cast)*/; + + ymm7 = _mm256_hadd_pd(ymm7, ymm7); + _mm256_storeu_pd(scratch, ymm7); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[1] = result/* + tC[1] * (*beta_cast)*/; + + ymm10 = _mm256_hadd_pd(ymm10, ymm10); + _mm256_storeu_pd(scratch, ymm10); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[2] = result/* + tC[2] * (*beta_cast)*/; + + ymm13 = _mm256_hadd_pd(ymm13, ymm13); + _mm256_storeu_pd(scratch, ymm13); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[3] = result/* + tC[3] * (*beta_cast)*/; + + } + } + processed_row = row_idx; + } + + // The edge case handling where M is not a multiple of 4 + if (processed_row < M) + { + for (row_idx = processed_row; row_idx < M; row_idx += 1) + { + for (col_idx = 0; col_idx < N; col_idx += 1) + { + tA = A + row_idx * lda; + tB = B + col_idx * ldb; + tC = C + col_idx * ldc + row_idx; + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + + for (k = 0; (k + 3) < K; k += 4) + { + ymm0 = _mm256_loadu_pd(tB + 0); + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tA += 4; + tB += 4; + } + + // if K is not a multiple of 4, padding is done before load using temproary array. + if (k < K) + { + int iter; + double data_feeder[4] = { 0.0 }; + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter]; + ymm0 = _mm256_loadu_pd(data_feeder); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter]; + ymm3 = _mm256_loadu_pd(data_feeder); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + } + + //horizontal addition and storage of the data. + ymm4 = _mm256_hadd_pd(ymm4, ymm4); + _mm256_storeu_pd(scratch, ymm4); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + tC[0] = result/* + tC[0] * (*beta_cast)*/; + + } + } + } + + //copy/compute sryk values back to C + if ( bli_seq0( *beta_cast ) ) //when beta is 0, just copy result to C + { + dim_t _i, _j; + if(bli_obj_is_lower(c)) //c is lower + { + for ( _j = 0; _j < N; ++_j ) + for ( _i = 0; _i < M; ++_i ) + if ( (doff_t)_j - (doff_t)_i <= 0 ) + { + bli_ddcopys( *(C + _i*rsc + _j*ldc), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + } + } + else //c is upper + { + for ( _j = 0; _j < N; ++_j ) + for ( _i = 0; _i < M; ++_i ) + if ( (doff_t)_j - (doff_t)_i >= 0 ) + { + bli_ddcopys( *(C + _i*rsc + _j*ldc), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + } + } + } + else //when beta is non-zero, multiply and store result to C + { + dim_t _i, _j; + if(bli_obj_is_lower(c)) //c is lower + { + for ( _j = 0; _j < N; ++_j ) + for ( _i = 0; _i < M; ++_i ) + if ( (doff_t)_j - (doff_t)_i <= 0 ) + { + bli_dddxpbys( *(C + _i*rsc + _j*ldc), + *(beta_cast), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + } + } + else //c is upper + { + for ( _j = 0; _j < N; ++_j ) + for ( _i = 0; _i < M; ++_i ) + if ( (doff_t)_j - (doff_t)_i >= 0 ) + { + bli_dddxpbys( *(C + _i*rsc + _j*ldc), + *(beta_cast), + *(matCbuf + _i*rs_matC + _j*ldc_matC) ); + } + } + } + + return BLIS_SUCCESS; + } + else + return BLIS_NONCONFORMAL_DIMENSIONS; +} + +#endif + diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c new file mode 100644 index 0000000000..c0b241aa85 --- /dev/null +++ b/kernels/zen/3/bli_trsm_small.c @@ -0,0 +1,27821 @@ +/* + +BLIS +An object-based framework for developing high-performance BLAS-like +libraries. + +Copyright (C) 2018-2019, Advanced Micro Devices, Inc. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +- Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +- Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. +- Neither the name of The University of Texas at Austin nor the names +of its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM +#include "immintrin.h" +#define GEMM_BLK_V1 8 //Block size to perform gemm and apply trsm +#define GEMM_ACCUM_A 1 //Peform B1=B1-(B0*A0) operation instead of B1'=(B0*A0) and then B1=B1-B1' +#define OPT_CACHE_BLOCKING_L1 1 //Perform trsm block-wise in blocks of GEMM_BLK_V1 instead of all columns of B together. +#define REARRANGE_SHFL 0 //Rearrange operations using blend or shuffle +#define BLI_AlXB_M_SP 16 +#define BLI_XAltB_N_SP 128 +#define BLI_AutXB_M_SP 64 +#define BLI_AutXB_N_SP 128 + +// XA = B; A is lower-traingular; No transpose; double precision; non-unit diagonal +static err_t bli_dtrsm_small_XAlB( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +//XA = B; A is lower triabgular; No transpose; double precision; unit-diagonal +static err_t bli_dtrsm_small_XAlB_unitDiag( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +//XA = B; A is lower-triangular; A is transposed; double precision; non-unit-diagonal +static err_t bli_dtrsm_small_XAltB( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +//XA = B; A is lower-triangular; A is transposed; double precision; unit-diagonal +static err_t bli_dtrsm_small_XAltB_unitDiag( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +// XA = B; A is upper triangular; No transpose; double presicion; non-unit diagonal +static err_t bli_dtrsm_small_XAuB + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +//XA = B; A is upper triangular; No transpose; double precision; unit-diagonal +static err_t bli_dtrsm_small_XAuB_unitDiag( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +//XA = B; A is upper-triangular; A is transposed; double precision; non-unit diagonal +static err_t bli_dtrsm_small_XAutB( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +//XA = B; A is upper-triangular; A is transposed; double precision; unit diagonal +static err_t bli_dtrsm_small_XAutB_unitDiag( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +//AX = B; A is lower triangular; No transpose; double precision; non-unit diagonal +static err_t bli_dtrsm_small_AlXB( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +//AX = B; A is lower triangular; No transpose; double precision; unit diagonal +static err_t bli_dtrsm_small_AlXB_unitDiag( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + + + +static void (*fp_blis_strsm_microkernel)( float *ptr_l, + float *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b + ); +static void blis_strsm_microkernel( float *ptr_l, + float *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b + ); +static void blis_strsm_microkernel_alpha( float *ptr_l, + float *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b, + float alphaVal + ); +static void blis_strsm_microkernel_unitDiag( float *ptr_l, + float *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b + ); +static void blis_strsm_microkernel_alpha_unitDiag( float *ptr_l, + float *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b, + float alphaVal + ); +static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, + float *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b); +static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, + float *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b, + float alphaVal); +static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, + float *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b); +static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, + float *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b, + float alphaVal); + + +static void blis_dtrsm_microkernel( double *ptr_l, + double *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b + ); + +static void blis_dtrsm_microkernel_alpha( double *ptr_l, + double *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b, + double alphaVal + ); + +static void blis_dtrsm_microkernel_unitDiag( double *ptr_l, + double *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b + ); + +static void blis_dtrsm_microkernel_alpha_unitDiag( double *ptr_l, + double *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b, + double alphaVal + ); + +static void dtrsm_XAtB_block_allSmallSizedMatrices(double *ptr_l, + double *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b); +static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha(double *ptr_l, + double *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b, + double alphaVal); +static void dtrsm_XAtB_block_allSmallSizedMatrices_unitDiag(double *ptr_l, + double *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b); +static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(double *ptr_l, + double *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b, + double alphaVal); +static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l, + float *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b); +static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l, + float *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b, + float alpha); +static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, + float *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b); +static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, + float *ptr_b, + int numRows_lb, + int numCols_b, + int rs_l, + int rs_b, + int cs_l, + int cs_b, + float alpha); + +//AX = B; A is lower triangular; No transpose; single precision +static err_t bli_strsm_small_AlXB + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); +//A.'X = B; A is upper triangular; A has to be transposed; single precision +static err_t bli_strsm_small_AutXB + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +//XA.' = B; A is lower triangular; A has to be transposed; single precision +static err_t bli_strsm_small_XAltB + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +//A.'X = B; A is upper triangular; A has to be transposed; double precision +static err_t bli_dtrsm_small_AutXB + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +/* +* The bli_trsm_small implements unpacked version of TRSM +* Currently only column-major is supported, A & B are column-major +* Input: A: MxM (triangular matrix) +* B: MxN matrix +* Output: X: MxN matrix such that AX = alpha*B or XA = alpha*B or A'X = alpha*B or XA' = alpha*B +* Here the output X is stored in B +* The custom-kernel will be called only when M*(M+N)* sizeof(Matrix Elements) < L3 cache +*/ +err_t bli_trsm_small + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ +#ifdef BLIS_ENABLE_MULTITHREADING + return BLIS_NOT_YET_IMPLEMENTED; +#endif + + dim_t m = bli_obj_length(b); + dim_t n = bli_obj_width(b); + + if(!(m && n)) + return BLIS_SUCCESS; + + + // If alpha is zero, B matrix will become zero after scaling & hence solution is also zero matrix + if (bli_obj_equals(alpha, &BLIS_ZERO)) + { + return BLIS_NOT_YET_IMPLEMENTED; // scale B by alpha + } + // We have to call matrix scaling if alpha != 1.0 + + // if row major format return. Check this again. + if ((bli_obj_row_stride(a) != 1) || + (bli_obj_row_stride(b) != 1)) + { + return BLIS_INVALID_ROW_STRIDE; + } + + num_t dt = ((*b).info & (0x7 << 0)); + + // only float and double datatypes are supported as of now. + if (dt != BLIS_DOUBLE && dt != BLIS_FLOAT) + { + return BLIS_EXPECTED_REAL_DATATYPE; + } + + // A is expected to be triangular in trsm + if (!bli_obj_is_upper_or_lower (a)) + { + return BLIS_EXPECTED_TRIANGULAR_OBJECT; + } + + // can use other control structs - even can use array of function pointers, + // indexed by a number with bits formed by f('side', 'uplo', 'transa', dt). + // In the below implementation, based on the number of finally implemented + // cases, can move the checks with more cases higher up. + + if(side == BLIS_LEFT) + { + if(bli_obj_has_trans(a)) + { + if(dt == BLIS_DOUBLE) + { + if(bli_obj_is_upper(a)) + { + //return bli_dtrsm_small_AutXB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + else + { + //return bli_dtrsm_small_AltXB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + } + else + { + if(bli_obj_is_upper(a)) + { + return bli_strsm_small_AutXB(side, alpha, a, b, cntx, cntl); + } + else + { + //return bli_strsm_small_AltXB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + + } + } + else + { + if(dt == BLIS_DOUBLE) + { + if(bli_obj_is_upper(a)) + { + //return bli_dtrsm_small_AuXB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + else + { + if(bli_obj_has_unit_diag(a)) + return bli_dtrsm_small_AlXB_unitDiag(side, alpha, a, b, cntx, cntl); + else + return bli_dtrsm_small_AlXB(side, alpha, a, b, cntx, cntl); + } + } + else + { + if(bli_obj_is_upper(a)) + { + //return bli_strsm_small_AuXB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + else + { + return bli_strsm_small_AlXB(side, alpha, a, b, cntx, cntl); + } + + } + + } + } + else + { + if(bli_obj_has_trans(a)) + { + if(dt == BLIS_DOUBLE) + { + if(bli_obj_is_upper(a)) + { + if(bli_obj_has_unit_diag(a)) + return bli_dtrsm_small_XAutB_unitDiag(side, alpha, a, b, cntx, cntl); + else + return bli_dtrsm_small_XAutB(side, alpha, a, b, cntx, cntl); + } + else + { + if(bli_obj_has_unit_diag(a)) + return bli_dtrsm_small_XAltB_unitDiag(side, alpha, a, b, cntx, cntl); + else + return bli_dtrsm_small_XAltB(side, alpha, a, b, cntx, cntl); + } + } + else + { + if(bli_obj_is_upper(a)) + { + //return bli_strsm_small_XAutB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + else + { + return bli_strsm_small_XAltB(side, alpha, a, b, cntx, cntl); + } + + } + } + else + { + if(dt == BLIS_DOUBLE) + { + if(bli_obj_is_upper(a)) + { + if(bli_obj_has_unit_diag(a)) + return bli_dtrsm_small_XAuB_unitDiag(side, alpha, a, b, cntx, cntl); + else + return bli_dtrsm_small_XAuB(side, alpha, a, b, cntx, cntl); + } + else + { + if(bli_obj_has_unit_diag(a)) + return bli_dtrsm_small_XAlB_unitDiag(side, alpha, a, b, cntx, cntl); + else + return bli_dtrsm_small_XAlB(side, alpha, a, b, cntx, cntl); + } + } + else + { + if(bli_obj_is_upper(a)) + { + //return bli_strsm_small_XAuB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + else + { + //return bli_strsm_small_XAlB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + + } + + } + } + return BLIS_NOT_YET_IMPLEMENTED; +}; + +/* TRSM scalar code for the case AX = alpha * B + * A is lower-triangular, non-unit-diagonal, no transpose + * Dimensions: A: mxm X: mxn B:mxn + */ + +static err_t dtrsm_small_AlXB ( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb + ) +{ + + dim_t i, j, k; + + for (k = 0; k < M; k++) + { + double lkk_inv = 1.0/A[k+k*lda]; + for (j = 0; j < N; j++) + { + B[k + j*ldb] *= lkk_inv; + for (i = k+1; i < M; i++) + { + B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + } + } + }// k -loop + return BLIS_SUCCESS; +}// end of function + +/* TRSM scalar code for the case AX = alpha * B + * A is lower-triangular, unit-diagonal, no transpose + * Dimensions: A: mxm X: mxn B:mxn + */ + +static err_t dtrsm_small_AlXB_unitDiag ( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb + ) +{ + + dim_t i, j, k; + + for (k = 0; k < M; k++) + { + for (j = 0; j < N; j++) + { + for (i = k+1; i < M; i++) + { + B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + } + } + } + return BLIS_SUCCESS; +}// end of function + +/* TRSM scalar code for the case XA = alpha * B + * A is upper-triangular, non-unit-diagonal no transpose + * Dimensions: X:mxn A:nxn B:mxn + */ +static err_t dtrsm_small_XAuB ( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb +) +{ + + dim_t i, j, k; + for(k = 0; k < N; k++) + { + double lkk_inv = 1.0/A[k+k*lda]; + for(i = 0; i < M; i++) + { + B[i+k*ldb] *= lkk_inv; + for(j = k+1; j < N; j++) + { + B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; + } + } + + } +return BLIS_SUCCESS; +} + +/* TRSM scalar code for the case XA = alpha * B + * A is lower-triangular, non-unit triangular, no transpose + * Dimensions: X:mxn A:nxn B:mxn + */ + +static err_t dtrsm_small_XAlB ( + double *A, + double *B, + double alpha, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb +) +{ + + dim_t i, j, k; + for(j = 0; j < N; j++) + for(i = 0; i < M; i++) + B[i+j*ldb] *= alpha; + + for(k = N;k--;) + { + double lkk_inv = 1.0/A[(k)+(k)*lda]; + for(i = M;i--;) + { + B[(i)+(k)*ldb] *= lkk_inv; + for(j = k;j--;) + { + B[(i)+(j)*ldb] -= B[(i)+(k)*ldb] * A[(k)+(j)*lda]; + } + } + } +return BLIS_SUCCESS; +} + +/* TRSM scalar code for the case XA = alpha * B + * A is lower-triangular, unit-diagonal, no transpose + *Dimensions: X:mxn A:nxn B:mxn + */ +static err_t dtrsm_small_XAlB_unitDiag( + double *A, + double *B, + double alpha, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb +) +{ + + dim_t i, j, k; + + for(j = 0 ; j < N; j++) + for(i = 0; i < M; i++) + B[i+j*ldb] *= alpha; + double A_k_j; + for(k = N; k--;) + { + for(j = k; j--;) + { + A_k_j = A[(k)+(j)*lda]; + for(i = M; i--;) + { + B[(i)+(j)*ldb] -= B[(i)+(k)*ldb] * A_k_j; + } + } + } + + +return BLIS_SUCCESS; +} + +/* TRSM scalar code for the case XA = alpha * B + *A is upper-triangular, non-unit-diagonal, A is transposed + * Dimensions: X:mxn A:nxn B:mxn + */ +static err_t dtrsm_small_XAutB ( + double *A, + double *B, + double alpha, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb +) +{ + + dim_t i, j, k; + + for(j = 0; j < N; j++) + for(i = 0; i < M; i++) + B[i+j*ldb] *=alpha; + + for(k = N; k--;) + { + double lkk_inv = 1.0/A[(k)+(k)*lda]; + for(i = M; i--;) + { + B[(i)+(k)*ldb] *= lkk_inv; + for(j = k; j--;) + { + B[(i)+(j)*ldb] -= B[(i)+(k)*ldb] * A[(j)+(k)*lda]; + } + } + } +return BLIS_SUCCESS; +} + +/* TRSM scalar code for the case XA = alpha * B + * A is upper-triangular, unit-diagonal, A has to be transposed + * Dimensions: X:mxn A:nxn B:mxn + */ +static err_t dtrsm_small_XAutB_unitDiag( + double *A, + double *B, + double alpha, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb +) +{ + + dim_t i, j, k; + double A_k_j; + + for(j = 0; j< N; j++) + for(i = 0; i< M; i++) + B[i+j*ldb] *= alpha; + + for(k = N; k--;) + { + for(j = k; j--;) + { + A_k_j = A[(j)+(k)*lda]; + for(i = M; i--;) + { + B[(i)+(j)*ldb] -= B[(i)+(k)*ldb] * A_k_j; + + } + } + } +return BLIS_SUCCESS; +} + +/* TRSM scalar code for the case XA = alpha * B + * A is lower-triangular, non-unit-diagonal, A has to be transposed + * Dimensions: X:mxn A:nxn B:mxn + */ +static err_t dtrsm_small_XAltB ( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb +) +{ + + dim_t i, j, k; + + for(k = 0; k < N; k++) + { + double lkk_inv = 1.0/A[k+k*lda]; + for(i = 0; i < M; i++) + { + B[i+k*ldb] *= lkk_inv; + for(j = k+1; j < N; j++) + { + B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda]; + } + } + } +return BLIS_SUCCESS; +} + +/* TRSM scalar code for XA = alpha * B + * A is lower-triangular, unit-diagonal, A has to be transposed + * Dimensions: X:mxn A:nxn B:mxn + */ +static err_t dtrsm_small_XAltB_unitDiag( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb +) +{ + + dim_t i, j, k; + + for(k = 0; k < N; k++) + { + for(i = 0; i < M; i++) + { + for(j = k+1; j < N; j++) + { + B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda]; + } + } + } +return BLIS_SUCCESS; +} + +/* TRSM scalar code for the case XA = alpha * B + * A is upper-triangular, unit-diagonal, no transpose + * Dimensions: X:mxn A:nxn B:mxn + */ +static err_t dtrsm_small_XAuB_unitDiag ( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb +) +{ + + dim_t i, j, k; + + for(k = 0; k < N; k++) + { + for(i = 0; i < M; i++) + { + for(j = k+1; j < N; j++) + { + B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; + } + } + } +return BLIS_SUCCESS; +} + +/* TRSM for the case AX = alpha * B, Double precision + * A is lower-triangular, no-transpose, non-unit diagonal + * dimensions A: mxm X: mxn B: mxn + + b01---> + * ***************** + ** * * * * * + * * * * * * * + * * *b01* * * * + * * * * * * * +a10 ****** b11 ***************** + | * * * | * * * * * + | * * * | * * * * * + | *a10*a11* | *b11* * * * + v * * * v * * * * * + *********** ***************** + * * * * * * * * * + * * * * * * * * * + * * * * * * * * * + * * * * * * * * * + **************** ***************** + a11---> +*/ +static err_t bli_dtrsm_small_AlXB( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + + dim_t D_MR = 4; //size of block along 'M' dimpension + dim_t D_NR = 8; //size of block along 'N' dimension + + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B + + +#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME + if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME) + || (m> D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N) + || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M && n D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES) + { + return BLIS_NOT_YET_IMPLEMENTED; + } +#endif + + dim_t m_remainder = m & 3; //number of remainder rows + dim_t n_remainder = n & 7; //number of remainder columns + + dim_t cs_a = bli_obj_col_stride(a); // column stride of A + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed + + double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B + + double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM + double *ptr_b01_dup; + + double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 + double* f_temp; + + double ones = 1.0; + + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16; + + + + for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension + { + for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension + { + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] + + b01 += 1; //mobe to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] + + b01 += 1; //mobe to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) + + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] + + b01 += 1; //mobe to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2]) + + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] + + b01 += 1; //mobe to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //broadcast diagonal elements of A11 + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); //A11[1][1] + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); //A11[2][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); //A11[3][3] + + ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] + ymm6 = _mm256_unpacklo_pd(ymm3, ymm4); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] + + ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2] + + //extract a00 + ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + + //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] + ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] + + //extract a11 + ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + + ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] + ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] + ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0] + + a11 += cs_a; + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0] + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0] + + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4] + ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * B11[0-3][4] + + ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1] + ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1] + + ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] + ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1] + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A110[][0] 1/A11[2][2] 1/A11[2][2] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1] + + ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] + ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5] + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /= A11[2][2] + ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2] + + ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2] + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_permute_pd(ymm0, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11);//1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] + + //(ROw2): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2] + + ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6] + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); //B11[0-3][3] /= A11[3][3] + ymm15 = _mm256_mul_pd(ymm15, ymm1); //B11[0-3][7] /= A11[3][3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3] + } + + if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) + { + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) + + int iter; + + if((j+D_NR) == n) + { + for(iter = 0; iter < m_remainder; iter++) + f_t[iter] = (b11 + cs_b * 7)[iter]; + f_temp = f_t; + } + else + f_temp = (b11 + cs_b * 7); + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code Begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] + + b01 += 1; //move to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] + + b01 += 1; //move to next row of B01 + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) + + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] + + b01 += 1; //move to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) + + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] + + b01 += 1; //move to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + + ///GEMM code ends/// + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7] + + if(3 == m_remainder) + { + ///implement TRSM/// + + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //broadcast diagonal elements of A11 + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); //A11[1][1] + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); //A11[2][2] + + ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] + ymm6 = _mm256_unpacklo_pd(ymm3, ymm0); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] + + ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a00 + ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + + //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] + ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] + + //extract a11 + ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + + ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] + ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] + + a11 += cs_a; + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0] + + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4] + + ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1] + ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1] + + ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] + + ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /=A11[2][2] + ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2] + + ymm11 = _mm256_broadcast_sd((double const *)(&ones)); + ymm15 = _mm256_broadcast_sd((double const *)(&ones)); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] + ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] + ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + + //determine correct values to store + ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08); + ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08); + ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08); + ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08); + ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08); + ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08); + ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08); + ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08); + + } + else if(2 == m_remainder) + { + ///implement TRSM/// + + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //broadcast diagonal elements of A11 + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); //A11[1][1] + + ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] + + ymm5 = _mm256_blend_pd(ymm5, ymm0, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a00 + ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + + //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] + ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] + + //extract a11 + ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + + ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] + + a11 += cs_a; + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] + + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] + + ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1] + ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1] + + ymm10 = _mm256_broadcast_sd((double const *)&ones); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm10, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm10, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm10, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm10, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm10, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm10, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm10, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm10, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] + ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] + ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + + //determine correct values to store + ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30); + ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30); + ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30); + ymm3 = _mm256_permute2f128_pd(ymm3, ymm11, 0x30); + ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30); + ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30); + ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30); + ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30); + + } + else if(1 == m_remainder) + { + ///implement TRSM/// + + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //broadcast diagonal elements of A11 + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + ymm0 = _mm256_div_pd(ymm0, ymm1); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a00 + ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + + //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] + ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] + + ymm9 = _mm256_broadcast_sd((double const *)(&ones)); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm9); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm9, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm9, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm9, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm9, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm9); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm9, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm9, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] + ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] + ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + + //determine correct values to store + ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E); + ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E); + ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E); + ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E); + ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E); + ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E); + ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E); + ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E); + } + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5]) + _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6]) + _mm256_storeu_pd((double *)(f_temp), ymm7); //store(B11[0-3][7]) + + if((j+D_NR) == n) + { + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * 7)[iter] = f_t[iter]; + } + } + } + + if((n & 4)) //implementation for remainder columns(when 'n_remainder' is greater than 4) + { + for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction + { + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) + ///GEMM for previously calculated values /// + + //load 4x4 block from b11 + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + + } + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B01[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B01[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B01[0-3][3] *alpha -= ymm7 + + ///implement TRSM/// + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] + ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] + + //2nd col + a11 += cs_a; + ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] + + //3rd col + a11 += cs_a; + ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] + ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] + + //4th col + a11 += cs_a; + ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] + //compute reciprocals of L(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[2][2] A11[2][2] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[1][1] A11[1][1] A11[3][3] A11[3][3] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] + + //extract diag a11 from a + ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] + ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] + ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] + + + //extract diag a22 from a + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] + ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] + + //extract diag a33 from a + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3] + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) + + } + if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) + { + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) + + dim_t iter; + + if((j+4) == n) + { + f_temp = f_t; + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b * 3)[iter]; + } + else + f_temp = (b11 + cs_b * 3); + ///GEMM for previously calculated values /// + + //load 4x4 block from b11 + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + for(k = 0; k < k_iter; k++) //looop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[0-3][3] *alpha -= ymm7 + + + if(3 == m_remainder) + { + ///implement TRSM/// + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] + + //2nd col + a11 += cs_a; + ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] + + //3rd col + a11 += cs_a; + ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] + + //4th col + a11 += cs_a; + ymm13 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] + + //extract diag a11 from a + ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] + ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] + + //extract diag a22 from a + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] + + ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + //load 4x4 block from b11 + ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + + //determine correct values to store + ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); + ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); + ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); + ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08); + } + else if( 2 == m_remainder ) + { + ///implement TRSM/// + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + + //2nd col + a11 += cs_a; + ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_blend_pd(ymm4, ymm14, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] + + //extract diag a11 from a + ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] + + ymm11 = _mm256_broadcast_sd((double const *)(&ones)); + ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + //load 4x4 block from b11 + ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + + //determine correct values to store + ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30); + ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30); + ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30); + ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30); + + } + else if(1 == m_remainder) + { + ///implement TRSM/// + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] + + ymm8 = _mm256_broadcast_sd((double const *)(&ones)); + ymm11 = _mm256_broadcast_sd((double const *)(&ones)); + ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + //load 4x4 block from b11 + ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + + //determine correct values to store + ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E); + ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E); + ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E); + ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E); + } + + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3]) + + if((j+4) == n) + { + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * 3)[iter] = f_temp[iter]; + } + } + + n_remainder -= 4; + j += 4; + + } + + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR) + { + for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction + { + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM for previously calculated values /// + + //load 4x4 block from b11 + if(3 == n_remainder) + { + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + for(k = 0; k < k_iter; k++) + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + } + else if(2 == n_remainder) + { + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + for(k = 0; k < k_iter; k++) + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + } + else if(1 == n_remainder) + { + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + for(k = 0; k < k_iter; k++) + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + + } + + ///implement TRSM/// + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] + ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] + + //2nd col + a11 += cs_a; + ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] + + //3rd col + a11 += cs_a; + ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] + ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] + + //4th col + a11 += cs_a; + ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] + //compute reciprocals of L(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] + + //extract diag a11 from a + ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3] + ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3] + ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3] + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] + + + //extract diag a22 from a + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3] + ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3] + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] + + //extract diag a33 from a + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3] + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3] + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if(3 == n_remainder) + { + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + } + + } + if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) + { + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + + k_iter = i / D_MR; //number of times GEMM operations to be performed + + dim_t iter; + if((j+n_remainder) == n) + { + f_temp = f_t; + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter]; + } + else + f_temp = (b11 + cs_b * (n_remainder -1)); + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM for previously calculated values /// + + + //load 4x4 block from b11 + if(3 == n_remainder) + { + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + + ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 + ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6 + + ///implement TRSM/// + //determine correct values to store + if(3 == m_remainder) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + } + else if(2 == m_remainder) + { + ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); + ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); + ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30); + } + else if(1 == m_remainder) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + } + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(f_temp), ymm2); //store(B11[0-3][2]) + } + if(2 == n_remainder) + { + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + + ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 + + ///implement TRSM/// + //determine correct values to store + if(3 == m_remainder) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + } + else if(2 == m_remainder) + { + ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); + ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); + } + else if(1 == m_remainder) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + } + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1]) + } + if(n_remainder == 1) + { + ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + + ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + + ///implement TRSM/// + //determine correct values to store + if(3 == m_remainder) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + } + else if(2 == m_remainder) + { + ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); + } + else if(1 == m_remainder) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + } + _mm256_storeu_pd((double *)(f_temp), ymm0); //store(B11[0-3][0]) + } + + if((j+n_remainder) == n) + { + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; + } + + ///scalar code for trsm without alpha/// + dtrsm_small_AlXB(a11, b11, m_remainder, n_remainder, cs_a, cs_b); + } + } + return BLIS_SUCCESS; +} + +/* TRSM for the case AX = alpha * B, Double precision + * A is lower-triangular, no-transpose, unit diagonal + * dimensions A: mxm X: mxn B: mxn + + b01---> + * ***************** + ** * * * * * + * * * * * * * + * * *b01* * * * + * * * * * * * +a10 ****** b11 ***************** + | * * * | * * * * * + | * * * | * * * * * + | *a10*a11* | *b11* * * * + v * * * v * * * * * + *********** ***************** + * * * * * * * * * + * * * * * * * * * + * * * * * * * * * + * * * * * * * * * + **************** ***************** + a11---> +*/ + +static err_t bli_dtrsm_small_AlXB_unitDiag( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + + dim_t D_MR = 4; //size of block along 'M' dimpension + dim_t D_NR = 8; //size of block along 'N' dimension + + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B + +#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME + if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME) + || (m> D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N) + || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M && n D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES) + { + return BLIS_NOT_YET_IMPLEMENTED; + } +#endif + + dim_t m_remainder = m & (3); //number of remainder rows + dim_t n_remainder = n & (7); //number of remainder columns + + dim_t cs_a = bli_obj_col_stride(a); // column stride of A + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed + + double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B + + double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM + double *ptr_b01_dup; + + double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 + double* f_temp; + + double ones = 1.0; + + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16; + + + + for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension + { + for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension + { + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] + + b01 += 1; //mobe to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] + + b01 += 1; //mobe to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) + + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] + + b01 += 1; //mobe to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2]) + + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] + + b01 += 1; //mobe to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] + ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] + ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0] + + a11 += cs_a; + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0] + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0] + + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4] + ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * B11[0-3][4] + + ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] + ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1] + + a11 += cs_a; + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1] + + ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] + ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5] + + ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2] + + a11 += cs_a; + + //(ROw1): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2] + + ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3] + } + + if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) + { + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) + + dim_t iter; + + if((j+D_NR) == n) + { + f_temp = f_t; + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b * 7)[iter]; + } + else + f_temp = (b11 + cs_b * 7); + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code Begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] + + b01 += 1; //move to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] + + b01 += 1; //move to next row of B01 + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) + + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] + + b01 += 1; //move to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) + + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] + + b01 += 1; //move to next row of B + + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + + ///GEMM code ends/// + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7] + + if(3 == m_remainder) + { + ///implement TRSM/// + + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] + ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] + + a11 += cs_a; + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0] + + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4] + + ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] + + a11 += cs_a; + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] + + ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] + + ymm11 = _mm256_broadcast_sd((double const *)(&ones)); + ymm15 = _mm256_broadcast_sd((double const *)(&ones)); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] + ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] + ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + + //determine correct values to store + ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08); + ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08); + ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08); + ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08); + ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08); + ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08); + ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08); + ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08); + + } + else if(2 == m_remainder) + { + ///implement TRSM/// + + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] + + a11 += cs_a; + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] + + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] + + ymm10 = _mm256_broadcast_sd((double const *)&ones); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm10, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm10, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm10, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm10, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm10, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm10, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm10, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm10, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] + ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] + ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + + //determine correct values to store + ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30); + ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30); + ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30); + ymm3 = _mm256_permute2f128_pd(ymm3, ymm11, 0x30); + ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30); + ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30); + ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30); + ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30); + + } + else if(1 == m_remainder) + { + ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] + ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] + ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + + //determine correct values to store + ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E); + ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E); + ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E); + ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E); + ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E); + ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E); + ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E); + ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E); + } + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5]) + _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6]) + _mm256_storeu_pd((double *)(f_temp), ymm7); //store(B11[0-3][7]) + + if((j+D_NR) == n) + { + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * 7)[iter] = f_temp[iter]; + } + } + } + + if((n & 4)) //implementation for remainder columns(when 'n_remainder' is greater than 4) + { + for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction + { + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) + ///GEMM for previously calculated values /// + + //load 4x4 block from b11 + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B01[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B01[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B01[0-3][3] *alpha -= ymm7 + + ///implement TRSM/// + //1st col + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] + ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] + + //2nd col + a11 += cs_a; + ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] + + //3rd col + a11 += cs_a; + ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] + ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] + ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] + ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) + + } + if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) + { + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) + + dim_t iter; + + if((j+4) == n) + { + f_temp = f_t; + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b * 3)[iter]; + } + else + f_temp = (b11 + cs_b * 3); + ///GEMM for previously calculated values /// + + //load 4x4 block from b11 + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + for(k = 0; k < k_iter; k++) //looop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[0-3][3] *alpha -= ymm7 + + + if(3 == m_remainder) + { + ///implement TRSM/// + //1st col + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] + + //2nd col + a11 += cs_a; + ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] + ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] + + ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + //load 4x4 block from b11 + ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + + //determine correct values to store + ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); + ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); + ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); + ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08); + } + else if(2 == m_remainder) + { + ///implement TRSM/// + //1st col + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] + + ymm11 = _mm256_broadcast_sd((double const *)(&ones)); + ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + //load 4x4 block from b11 + ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + + //determine correct values to store + ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30); + ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30); + ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30); + ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30); + + } + else if(1 == m_remainder) + { + //load 4x4 block from b11 + ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + + //determine correct values to store + ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E); + ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E); + ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E); + ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E); + } + + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3]) + + if((j+4) == n) + { + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * 3)[iter] = f_temp[iter]; + } + } + + n_remainder -= 4; + j += 4; + + } + + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR) + { + for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction + { + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM for previously calculated values /// + + //load 4x4 block from b11 + if(3 == n_remainder) + { + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + for(k = 0; k < k_iter; k++) + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + } + else if(2 == n_remainder) + { + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + for(k = 0; k < k_iter; k++) + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + } + else if(1 == n_remainder) + { + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + for(k = 0; k < k_iter; k++) + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + + } + + ///implement TRSM/// + //1st col + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] + ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] + + //2nd col + a11 += cs_a; + ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] + + //3rd col + a11 += cs_a; + ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3] + ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3] + ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3] + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3] + ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3] + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3] + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if(3 == n_remainder) + { + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + } + + } + if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) + { + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + + k_iter = i / D_MR; //number of times GEMM operations to be performed + + dim_t iter; + + if((j+n_remainder) == n) + { + f_temp = f_t; + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter]; + } + else + f_temp = (b11 + cs_b * (n_remainder -1)); + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM for previously calculated values /// + + + //load 4x4 block from b11 + if(3 == n_remainder) + { + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + + ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 + ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6 + + ///implement TRSM/// + //determine correct values to store + if(3 == m_remainder) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + } + else if(2 == m_remainder) + { + ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); + ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); + ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30); + } + else if(1 == m_remainder) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + } + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(f_temp), ymm2); //store(B11[0-3][2]) + } + else if(2 == n_remainder) + { + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + + ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 + + ///implement TRSM/// + //determine correct values to store + if(3 == m_remainder) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + } + else if(2 == m_remainder) + { + ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); + ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); + } + else if(1 == m_remainder) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + } + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1]) + } + else if(1 == n_remainder) + { + ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + + ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + + ///implement TRSM/// + //determine correct values to store + if(3 == m_remainder) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + } + else if(2 == m_remainder) + { + ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); + } + else if(1 == m_remainder) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + } + _mm256_storeu_pd((double *)(f_temp), ymm0); //store(B11[0-3][0]) + } + + if((j+n_remainder) == n) + { + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; + } + ///scalar code for trsm without alpha/// + dtrsm_small_AlXB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b); + } + } + return BLIS_SUCCESS; +} + + +/*implements TRSM for the case XA = alpha * B + *A is upper triangular, non-unit diagonal, no transpose + *dimensions: X:mxn A:nxn B: mxn + */ + +/* b11---> a01 ----> + ***************** *********** + *b01*b11* * * * * * * +b11 * * * * * **a01 * * a11 + | ***************** ********* | + | * * * * * *a11* * | + | * * * * * * * * | + v ***************** ****** v + * * * * * * * + * * * * * * * + ***************** * * + * + +*/ +static err_t bli_dtrsm_small_XAuB( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + dim_t D_MR = 8; //block dimension along the rows + dim_t D_NR = 4; //block dimension along the columns + + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + + dim_t m_remainder = m & 7; //number of corner rows + dim_t n_remainder = n & 3; //number of corner columns + + dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + +#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME + if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME) + || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N) + ) + return BLIS_NOT_YET_IMPLEMENTED; +#else + if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) + { + return BLIS_NOT_YET_IMPLEMENTED; + } +#endif + + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done + dim_t cs_b_offset[2]; //pre-calculated strides + + double ones = 1.0; + + double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B + + double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + double *ptr_a01_dup; + + double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 + double* f_temp; + + cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; + cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; + + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16; + + for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction + { + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction + { + a01 = L + j*cs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A01 + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + //2nd col + a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] + + //3rd col + a11 += cs_a; + ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] + ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] + + //4th col + a11 += cs_a; + ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] + + ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] + + ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] + + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row2)FMA operations + ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] + ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] + + ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] + ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] + + ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] + + ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] + + //extract a33 + ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + //(Row3)FMA operations + ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] + + ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] + + ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] + + ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) + } + if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + { + a01 = L + j*cs_a; //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) + + ///load 4x4 block of b11 + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + //subtract the calculated GEMM block from current TRSM block + //load 8x4 block of B11 + if(3 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ///GEMM code ends/// + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + //2nd col + a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] + + //3rd col + a11 += cs_a; + ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] + ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] + + //4th col + a11 += cs_a; + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] + + ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] + + ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] + + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row2)FMA operations + ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] + + ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] + + ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] + + ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) + } + else if(2 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ///GEMM code ends/// + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + //2nd col + a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + + ymm0 = _mm256_blend_pd(ymm0, ymm7, 0x0C); //A11[0][0] A11[1][1] 1 1 + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/1 1/1) + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + + ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] + + ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] + + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + } + else if(1 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + ///GEMM code ends/// + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + } + } + } + if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) + { + for(j = 0; (j+D_NR-1) a01 ----> + ***************** *********** + *b01*b11* * * * * * * +b11 * * * * * **a01 * * a11 + | ***************** ********* | + | * * * * * *a11* * | + | * * * * * * * * | + v ***************** ****** v + * * * * * * * + * * * * * * * + ***************** * * + * + +*/ +static err_t bli_dtrsm_small_XAuB_unitDiag( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + dim_t D_MR = 8; //block dimension along the rows + dim_t D_NR = 4; //block dimension along the columns + + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + + dim_t m_remainder = m & 7; //number of corner rows + dim_t n_remainder = n & 3; //number of corner columns + + dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + +#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME + if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME) + || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N) + ) + return BLIS_NOT_YET_IMPLEMENTED; +#else + if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) + { + return BLIS_NOT_YET_IMPLEMENTED; + } +#endif + + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done + dim_t cs_b_offset[2]; //pre-calculated strides + + double ones = 1.0; + + double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B + + double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + double *ptr_a01_dup; + + double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 + double* f_temp; + + cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; + cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; + + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16; + + for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction + { + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction + { + a01 = L + j*cs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A01 + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //2nd col + a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + + //3rd col + a11 += cs_a; + ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] + + //4th col + a11 += cs_a; + ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] + + //(Row2)FMA operations + ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] + ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] + + ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] + ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] + + //(Row3)FMA operations + ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] + + ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) + } + if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + { + a01 = L + j*cs_a; //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) + + ///load 4x4 block of b11 + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + //subtract the calculated GEMM block from current TRSM block + //load 8x4 block of B11 + if(3 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ///GEMM code ends/// + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //2nd col + a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + + //3rd col + a11 += cs_a; + ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] + + //(Row2)FMA operations + ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] + + ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) + } + else if(2 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //2nd col + a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + } + else if(1 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + } + } + } + if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) + { + for(j = 0; (j+D_NR-1) a01 ----> + ***************** *********** + *b01*b11* * * * * * * +b11 * * * * * **a01 * * a11 + | ***************** ********* | + | * * * * * *a11* * | + | * * * * * * * * | + v ***************** ****** v + * * * * * * * + * * * * * * * + ***************** * * + * + +*/ +static err_t bli_dtrsm_small_XAltB( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + dim_t D_MR = 8; //block dimension along the rows + dim_t D_NR = 4; //block dimension along the columns + + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + + dim_t m_remainder = m & 7; //number of corner rows + dim_t n_remainder = n & 3; //number of corner columns + + dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + +#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME + if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N) + || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N) + || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) + || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) + || (m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N) + ) + return BLIS_NOT_YET_IMPLEMENTED; +#else + if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) + { + return BLIS_NOT_YET_IMPLEMENTED; + } +#endif + + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done + dim_t cs_b_offset[2]; //pre-calculated strides + + double ones = 1.0; + + double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B + + double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + double *ptr_a01_dup; + + double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 + double* f_temp; + + cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; + cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; + + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16; + + for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction + { + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction + { + a01 = L + j; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + //2nd col + a11 += 1; + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] + + //3rd col + a11 += 1; + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] + + //4th col + a11 += 1; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] + + ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] + + ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] + + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row2)FMA operations + ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] + ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] + + ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] + ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] + + ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] + + ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] + + //extract a33 + ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + //(Row3)FMA operations + ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] + + ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] + + ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] + + ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) + } + if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + { + a01 = L + j; //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) + + ///load 4x4 block of b11 + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + //subtract the calculated GEMM block from current TRSM block + //load 8x4 block of B11 + if(3 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + + a01 += cs_a; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + + a01 += cs_a; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + //2nd col + a11 += 1; + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] + + //3rd col + a11 += 1; + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] + + //4th col + a11 += 1; + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] + + ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] + + ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] + + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row2)FMA operations + ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] + + ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] + + ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] + + ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) + } + else if(2 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + + a01 += cs_a; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + + a01 += cs_a; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm8 = _mm256_loadu_pd((double const *)b11); + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + //2nd col + a11 += 1; + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + + ymm0 = _mm256_blend_pd(ymm0, ymm7, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + + ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] + + ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + } + else if(1 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + + a01 += cs_a; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + + a01 += cs_a; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + + ///implement TRSM/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + ymm8 = _mm256_mul_pd(ymm8, ymm7); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm7); //B11[4-7][0] /= A11[0][0] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + } + } + } + if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) + { + for(j = 0; (j+D_NR-1) a01 ----> + ***************** *********** + *b01*b11* * * * * * * +b11 * * * * * **a01 * * a11 + | ***************** ********* | + | * * * * * *a11* * | + | * * * * * * * * | + v ***************** ****** v + * * * * * * * + * * * * * * * + ***************** * * + * + +*/ +static err_t bli_dtrsm_small_XAltB_unitDiag( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + dim_t D_MR = 8; //block dimension along the rows + dim_t D_NR = 4; //block dimension along the columns + + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + + dim_t m_remainder = m & 7; //number of corner rows + dim_t n_remainder = n & 3; //number of corner columns + + dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + +#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME + if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N) + || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N) + || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) + || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) + || (m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N) + ) + return BLIS_NOT_YET_IMPLEMENTED; +#else + if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) + { + return BLIS_NOT_YET_IMPLEMENTED; + } +#endif + + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done + dim_t cs_b_offset[2]; //pre-calculated strides + + double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B + + double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + double *ptr_a01_dup; + + double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 + double* f_temp; + + cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; + cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; + + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16; + + for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction + { + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction + { + a01 = L + j; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //2nd col + a11 += 1; + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] + + //3rd col + a11 += 1; + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] + + //4th col + a11 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] + + //(Row2)FMA operations + ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] + ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] + + ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] + ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] + + //(Row3)FMA operations + ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] + + ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) + } + if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + { + a01 = L + j; //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) + + ///load 4x4 block of b11 + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + //subtract the calculated GEMM block from current TRSM block + //load 8x4 block of B11 + if(3 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + + a01 += cs_a; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + + a01 += cs_a; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //2nd col + a11 += 1; + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] + + //3rd col + a11 += 1; + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] + + //4th col + a11 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] + + //(Row2)FMA operations + ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] + + ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) + } + else if(2 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + + a01 += cs_a; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + + a01 += cs_a; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm8 = _mm256_loadu_pd((double const *)b11); + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //2nd col + a11 += 1; + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + } + else if(1 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + + a01 += cs_a; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + + a01 += cs_a; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + } + } + } + if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) + { + for(j = 0; (j+D_NR-1) D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME) + ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) + ) + return BLIS_NOT_YET_IMPLEMENTED; +#else + if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) + { + return BLIS_NOT_YET_IMPLEMENTED; + } +#endif + + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done + dim_t cs_b_offset[2]; //pre-calculated strides + + double ones = 1.0; + + double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha + double* restrict L = a->buffer; //pointer to matrix A + double* restrict B = b->buffer; //pointer to matrix B + + double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + double *ptr_a01_dup; + + cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; + cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; + + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16; + + for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction + { + for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction + { + a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A01 + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + //2nd col + a11 += 1; + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] + + //3rd col + a11 += 1; + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] + + //4th col + a11 += 1; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //extract a33 + ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm11 = _mm256_mul_pd(ymm11, ymm0); + + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(row 3):FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); + + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); + + ymm10 = _mm256_mul_pd(ymm10, ymm0); + + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row 2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); + + ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); + + ymm9 = _mm256_mul_pd(ymm9, ymm0); + + ymm13 = _mm256_mul_pd(ymm13, ymm0); + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + //(Row 1): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); + + ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) + + + } + if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + { + a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) + + ///load 4x4 block of b11 + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + //subtract the calculated GEMM block from current TRSM block + //load 8x4 block of B11 + if(3 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] )); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] + + ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0] + + //2nd col + a11 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] + + //3rd col + a11 += 1; + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] + + //4th col + a11 += 1; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //extract a33 + ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm11 = _mm256_mul_pd(ymm11, ymm0); + + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(row 3):FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); + + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); + + ymm10 = _mm256_mul_pd(ymm10, ymm0); + + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row 2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); + + ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); + + ymm9 = _mm256_mul_pd(ymm9, ymm0); + + ymm13 = _mm256_mul_pd(ymm13, ymm0); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + else if(2 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] + + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //3rd col + a11 += 2; + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] + + //4th col + a11 += 1; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + + ymm0 = _mm256_blend_pd(ymm7, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //extract a33 + ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm11 = _mm256_mul_pd(ymm11, ymm0); + + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(row 3):FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + + ymm10 = _mm256_mul_pd(ymm10, ymm0); + + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + else if(1 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] + + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //4th col + a11 += 3; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm7 = _mm256_div_pd(ymm7, ymm6); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + ymm11 = _mm256_mul_pd(ymm11, ymm7); + + ymm15 = _mm256_mul_pd(ymm15, ymm7); + + _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + } + } + if(i<0) + i += D_NR; + if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) + { + for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction + { + a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + + ///GEMM for previous blocks /// + + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0] + ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0] + + //2nd col + a11 += cs_a; + ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] + + //3rd col + a11 += cs_a; + ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] + ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] + + //4th col + a11 += cs_a; + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm3 = _mm256_mul_pd(ymm3, ymm15); + + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); + ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); + + ymm2 = _mm256_mul_pd(ymm2, ymm15); + + //extract a11 + ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(ROW 2): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); + ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); + + ymm1 = _mm256_mul_pd(ymm1, ymm15); + + //extract A00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + //(Row 1):FMA operations + ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); + + ymm0 = _mm256_mul_pd(ymm0, ymm15); + + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) + + } + if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) + { + a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM for previous blocks /// + if(3 == n_remainder) + { + ///load 4x4 block of b11 + ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + //2nd col + a11 += cs_a; + ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] + + //3rd col + a11 += cs_a; + ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] + ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] + + //4th col + a11 += cs_a; + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm14, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm3 = _mm256_mul_pd(ymm3, ymm15); + + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); + + ymm2 = _mm256_mul_pd(ymm2, ymm15); + + //extract a11 + ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(ROW 2): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); + + ymm1 = _mm256_mul_pd(ymm1, ymm15); + + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) + } + else if(2 == n_remainder) + { + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //3rd col + a11 += 2 * cs_a; + ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] + ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] + + //4th col + a11 += cs_a; + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm15 = _mm256_blend_pd(ymm14, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm3 = _mm256_mul_pd(ymm3, ymm15); + + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + + ymm2 = _mm256_mul_pd(ymm2, ymm15); + + _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) + } + else if(1 == n_remainder) + { + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //4th col + a11 += 3 * cs_a; + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm14 = _mm256_div_pd(ymm14, ymm13); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a33 + ymm3 = _mm256_mul_pd(ymm3, ymm14); + + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) + } + } + m_remainder -= 4; + i -= 4; + } +// if(i < 0) i = 0; + if(m_remainder) ///implementation for remainder rows + { + dtrsm_small_XAlB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); + } + return BLIS_SUCCESS; +} + +/*implements TRSM for the case XA = alpha * B + *A is lower triangular, unit-diagonal, no transpose + *dimensions: X:mxn A:nxn B: mxn + */ + +/* <---b11 <---a11 + ***************** * + *b01*b11* * * * * + ^ * * * * * ^ * * + | ***************** | ******* + | * * * * * | * * * + | * * * * * a01* * * +b10 ***************** ************* + * * * * * * * * * + * * * * * * * * * + ***************** ******************* + +*/ +static err_t bli_dtrsm_small_XAlB_unitDiag( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + dim_t D_MR = 8; //block dimension along the rows + dim_t D_NR = 4; //block dimension along the columns + + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + + dim_t m_remainder = m & 7; //number of corner rows + dim_t n_remainder = n & 3; //number of corner columns + + dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + + +#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME + if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME) + ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) + ) + return BLIS_NOT_YET_IMPLEMENTED; +#else + if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) + { + return BLIS_NOT_YET_IMPLEMENTED; + } +#endif + + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done + dim_t cs_b_offset[2]; //pre-calculated strides + + double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha + double* restrict L = a->buffer; //pointer to matrix A + double* restrict B = b->buffer; //pointer to matrix B + + double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + double *ptr_a01_dup; + + cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; + cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; + + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16; + + for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction + { + for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction + { + a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A01 + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //2nd col + a11 += 1; + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] + + //3rd col + a11 += 1; + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] + + //4th col + a11 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //(row 3):FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); + + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); + + //(Row 2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); + + ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); + + //(Row 1): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); + + ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) + + + } + if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + { + a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) + + ///load 4x4 block of b11 + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + //subtract the calculated GEMM block from current TRSM block + //load 8x4 block of B11 + if(3 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] )); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] + + ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //3rd col + a11 += 2; + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] + + //4th col + a11 += 1; + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //(row 3):FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); + + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); + + //(Row 2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); + + ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + else if(2 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] + + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + //4th col + a11 += 3; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //(row 3):FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + else if(1 == n_remainder) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] + + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + + _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + } + } + if(i<0) + i += D_NR; + if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) + { + for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction + { + a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + + ///GEMM for previous blocks /// + + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st col + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0] + ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0] + + //2nd col + a11 += cs_a; + ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] + + //3rd col + a11 += cs_a; + ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); + ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); + + //(ROW 2): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); + ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); + + //(Row 1):FMA operations + ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); + + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) + + } + if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) + { + a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM for previous blocks /// + if(3 == n_remainder) + { + ///load 4x4 block of b11 + ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + //2nd col + a11 += cs_a; + ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] + + //3rd col + a11 += cs_a; + ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); + + //(ROW 2): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) + } + else if(2 == n_remainder) + { + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //3rd col + a11 += 2 * cs_a; + ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] + + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + + _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) + } + else if(1 == n_remainder) + { + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) + } + } + m_remainder -= 4; + i -= 4; + } +// if(i < 0) i = 0; + if(m_remainder) ///implementation for remainder rows + { + dtrsm_small_XAlB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); + } + return BLIS_SUCCESS; +} + + +/*implements TRSM for the case XA = alpha * B + *A is lower triangular, non-unit diagonal, no transpose + *dimensions: X:mxn A:nxn B: mxn + */ + +/* <---b11 <---a11 + ***************** * + *b01*b11* * * * * + ^ * * * * * ^ * * + | ***************** | ******* + | * * * * * | * * * + | * * * * * a01* * * +b10 ***************** ************* + * * * * * * * * * + * * * * * * * * * + ***************** ******************* + +*/ +static err_t bli_dtrsm_small_XAutB( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + dim_t D_MR = 8; //block dimension along the rows + dim_t D_NR = 4; //block dimension along the columns + + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + + dim_t m_remainder = m & 7; //number of corner rows + dim_t n_remainder = n & 3; //number of corner columns + + dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + +#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME + if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME) + ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) + ) + return BLIS_NOT_YET_IMPLEMENTED; +#else + if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) + { + return BLIS_NOT_YET_IMPLEMENTED; + } +#endif + + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done + dim_t cs_b_offset[2]; //pre-calculated strides + + double ones = 1.0; + + double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha + double* restrict L = a->buffer; //pointer to matrix A + double* restrict B = b->buffer; //pointer to matrix B + + double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + double *ptr_a01_dup; + + cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; + cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; + + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16; + + for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction + { + for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction + { + a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + a11 += cs_a; + + //2nd col + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + + a11 += cs_a; + + //3rd col + ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + + + ymm7 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + //extract a33 + ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm11 = _mm256_mul_pd(ymm11, ymm7); + + ymm15 = _mm256_mul_pd(ymm15, ymm7); + + //extract a22 + ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); + + //(Row 3): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); + + ymm10 = _mm256_mul_pd(ymm10, ymm7); + + ymm14 = _mm256_mul_pd(ymm14, ymm7); + + //extract a11 + ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(ROW 2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); + + ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); + + ymm9 = _mm256_mul_pd(ymm9, ymm7); + + ymm13 = _mm256_mul_pd(ymm13, ymm7); + + //extract A00 + ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + //(Row 1):FMA operations + ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); + + ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); + + ymm8 = _mm256_mul_pd(ymm8, ymm7); + + ymm12 = _mm256_mul_pd(ymm12, ymm7); + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3]) + + } + if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + { + + a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + //load 8x4 block of B11 + if(3 == n_remainder) + { + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] + + ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0] + + a11 += cs_a; + + //2nd col + ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + + a11 += cs_a; + + //3rd col + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + + ymm7 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + //extract a33 + ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm11 = _mm256_mul_pd(ymm11, ymm7); + + ymm15 = _mm256_mul_pd(ymm15, ymm7); + + //extract a22 + ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); + + //(Row 3): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); + + ymm10 = _mm256_mul_pd(ymm10, ymm7); + + ymm14 = _mm256_mul_pd(ymm14, ymm7); + + //extract a11 + ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(ROW 2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); + + ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); + + ymm9 = _mm256_mul_pd(ymm9, ymm7); + + ymm13 = _mm256_mul_pd(ymm13, ymm7); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + else if(2 == n_remainder) + { + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] + + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st col + a11 += 2 * cs_a; + + //3rd col + ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + + ymm7 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm0 = _mm256_blend_pd(ymm7, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + //extract a33 + ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm11 = _mm256_mul_pd(ymm11, ymm7); + + ymm15 = _mm256_mul_pd(ymm15, ymm7); + + //extract a22 + ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + + //(Row 3): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + + ymm10 = _mm256_mul_pd(ymm10, ymm7); + + ymm14 = _mm256_mul_pd(ymm14, ymm7); + + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + else if(1 == n_remainder) + { + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] + + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + a11 += 3 * cs_a; + + //4th col + ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + + ymm7 = _mm256_broadcast_sd((double const *)&ones); + + ymm0 = _mm256_div_pd(ymm7, ymm6); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ymm11 = _mm256_mul_pd(ymm11, ymm0); + + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + } + } + if(i<0) + i += D_NR; + if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) + { + for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction + { + a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + + ///GEMM for previous blocks /// + + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + a11 += cs_a; + + //2nd col + ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + + a11 += cs_a; + + //3rd col + ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm3 = _mm256_mul_pd(ymm3, ymm15); + + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); + ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); + + ymm2 = _mm256_mul_pd(ymm2, ymm15); + + //extract a11 + ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(ROW 2): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); + ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); + + ymm1 = _mm256_mul_pd(ymm1, ymm15); + + //extract A00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + //(Row 1):FMA operations + ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); + + ymm0 = _mm256_mul_pd(ymm0, ymm15); + + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) + + } + if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) + { + + a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + ///GEMM for previous blocks /// + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///load 4x4 block of b11 + if(3 == n_remainder) + { + ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0] + + a11 += cs_a; + + //2nd col + ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + + a11 += cs_a; + + //3rd col + ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm3 = _mm256_mul_pd(ymm3, ymm15); + + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); + + ymm2 = _mm256_mul_pd(ymm2, ymm15); + + //extract a11 + ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(ROW 2): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); + + ymm1 = _mm256_mul_pd(ymm1, ymm15); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) + } + else if(2 == n_remainder) + { + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st col + + a11 += 2 * cs_a; + + //3rd col + ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm15 = _mm256_blend_pd(ymm14, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm3 = _mm256_mul_pd(ymm3, ymm15); + + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + + ymm2 = _mm256_mul_pd(ymm2, ymm15); + + + _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) + } + else if(1 == n_remainder) + { + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + a11 += 3 * cs_a; + + //4th col + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm14 = _mm256_div_pd(ymm14, ymm13); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ymm3 = _mm256_mul_pd(ymm3, ymm14); + + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) + } + } + m_remainder -= 4; + i -= 4; + } + if(m_remainder) ///implementation for remainder rows + { + dtrsm_small_XAutB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); + } + return BLIS_SUCCESS; +} + +/*implements TRSM for the case XA = alpha * B + *A is lower triangular, unit-diagonal, no transpose + *dimensions: X:mxn A:nxn B: mxn + */ + +/* <---b11 <---a11 + ***************** * + *b01*b11* * * * * + ^ * * * * * ^ * * + | ***************** | ******* + | * * * * * | * * * + | * * * * * a01* * * +b10 ***************** ************* + * * * * * * * * * + * * * * * * * * * + ***************** ******************* + +*/ +static err_t bli_dtrsm_small_XAutB_unitDiag( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + dim_t D_MR = 8; //block dimension along the rows + dim_t D_NR = 4; //block dimension along the columns + + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + + dim_t m_remainder = m & 7; //number of corner rows + dim_t n_remainder = n & 3; //number of corner columns + + dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + +#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME + if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME) + ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) + ) + return BLIS_NOT_YET_IMPLEMENTED; +#else + if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) + { + return BLIS_NOT_YET_IMPLEMENTED; + } +#endif + + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done + dim_t cs_b_offset[2]; //pre-calculated strides + + double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha + double* restrict L = a->buffer; //pointer to matrix A + double* restrict B = b->buffer; //pointer to matrix B + + double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + double *ptr_a01_dup; + + cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; + cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; + + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16; + + for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction + { + for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction + { + a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st col + + a11 += cs_a; + + //2nd col + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + + a11 += cs_a; + + //3rd col + ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + //(Row 3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); + + //(Row 3): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); + + //(ROW 2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); + + ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); + + //(Row 1):FMA operations + ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); + + ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3]) + + } + if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + { + + a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + //load 8x4 block of B11 + if(3 == n_remainder) + { + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] + + ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + a11 += 2 * cs_a; + + //3rd col + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + //(Row 3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); + + //(Row 3): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); + + //(ROW 2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); + + ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + else if(2 == n_remainder) + { + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] + + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + + a11 += 3 * cs_a; + + //4th col + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + //(Row 3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + + //(Row 3): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + else if(1 == n_remainder) + { + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] + + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 + + _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + } + } + if(i<0) + i += D_NR; + if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) + { + for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction + { + a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + + ///GEMM for previous blocks /// + + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + a11 += cs_a; + + //2nd col + ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + + a11 += cs_a; + + //3rd col + ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); + ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); + + //(ROW 2): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); + ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); + + //(Row 1):FMA operations + ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); + + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) + + } + if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) + { + + a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + + ///GEMM for previous blocks /// + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///load 4x4 block of b11 + if(3 == n_remainder) + { + ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + a11 += 2 * cs_a; + + //3rd col + ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); + + //(ROW 2): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) + } + else if(2 == n_remainder) + { + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + + a11 += 3 * cs_a; + + //4th col + ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + + _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) + } + else if(1 == n_remainder) + { + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) + } + } + m_remainder -= 4; + i -= 4; + } + if(m_remainder) ///implementation for remainder rows + { + dtrsm_small_XAutB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); + } + return BLIS_SUCCESS; +} + + +/* + * AX = Alpha*B, Single precision, A: lower triangular + * This kernel implementation supports matrices A and B such that m is equal to BLI_AlXB_M_SP and n is mutiple of 8 + */ + +static err_t bli_strsm_small_AlXB ( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + obj_t alpha, beta; // gemm parameters + obj_t Ga, Gb, Gc; // for GEMM + int m = bli_obj_length(b); // number of rows of matrix B + int n = bli_obj_width(b); // number of columns of matrix B + + int lda = bli_obj_col_stride(a); // column stride of A + int ldb = bli_obj_col_stride(b); // column stride of B + + int rsa = bli_obj_row_stride(a); // row stride of A + int rsb = bli_obj_row_stride(b); // row stride of B + + int i = 0; + int j; + int blk_size = 8; + int isUnitDiag = bli_obj_has_unit_diag(a); + + float alphaVal; + float* restrict L = a->buffer; + float* restrict B = b->buffer; + + if (m != BLI_AlXB_M_SP || (n&7) != 0) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + if ( (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + + alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); + + /* Small _GEMM preparation code */ + bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &alpha ); + bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &beta ); + + /* B = B - A*B */ + bli_setsc( -(1.0), 0.0, &alpha ); + bli_setsc( (1.0), 0.0, &beta ); + + + bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, blk_size, a->buffer, rsa, lda, &Ga); + bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gb); + bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gc); + + bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Ga ); + bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gb ); + bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gc ); + + //first block of trsm + Gb.buffer = (void*)(B + i); + + //trsm of first 8xn block + if (alphaVal != 1) + { + if (isUnitDiag == 0) + { + blis_strsm_microkernel_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); + fp_blis_strsm_microkernel = blis_strsm_microkernel; + } + else + { + blis_strsm_microkernel_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); + fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag; + } + bli_setsc( alphaVal, 0.0, &beta ); + } + else + { + if (isUnitDiag == 0) + { + blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); + fp_blis_strsm_microkernel = blis_strsm_microkernel; + } + else + { + blis_strsm_microkernel_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); + fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag; + } + } + + //gemm update + for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT + { + Ga.buffer = (void*)(L + j + i*lda); + Gc.buffer = (void*)(B + j); + + bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb + } + + //trsm of remaining blocks + for (i = blk_size; i < m; i += blk_size) + { + Gb.buffer = (void*)(B + i); + + fp_blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); + + for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT + { + Ga.buffer = (void*)(L + j + i*lda); + Gc.buffer = (void*)(B + j); + + bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb + } + + } // End of for loop - i + + return BLIS_SUCCESS; +} + + + +/* + * XA' = Alpha*B, Single precision, A: lower triangular + * This kernel implementation supports matrices A and B such that + * m and n are multiples of 8 and n is less than or equal to BLI_XAltB_N_SP + */ +static err_t bli_strsm_small_XAltB( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + int m = bli_obj_length(a); // number of rows of matrix B + int n = bli_obj_length(b); // number of columns of matrix B + + int lda = bli_obj_col_stride(a); // column stride of A + int ldb = bli_obj_col_stride(b); // column stride of B + + int rsa = bli_obj_row_stride(a); // row stride of A + int rsb = bli_obj_row_stride(b); // row stride of B + + int i = 0; + int isUnitDiag = bli_obj_has_unit_diag(a); + + float alphaVal; + float *L = a->buffer; + float *B = b->buffer; + + if ((m&7) != 0 || (n&7) != 0) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + if ( n > BLI_XAltB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + + alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); + + if (alphaVal != 1) + { + if (isUnitDiag == 0) + { + trsm_XAtB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); + } + else + { + trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); + } + } + else + { + if (isUnitDiag == 0) + { + trsm_XAtB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); + } + else + { + trsm_XAtB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); + } + } + return BLIS_SUCCESS; +} + +/* + * A'X = Alpha*B, Single precision, A: upper triangular + * This kernel implementation supports matrices A and B such that + * m and n are multiples of 8, m is less than or equal to BLI_AutXB_M_SP and n is less than or equal to BLI_AutXB_N_SP + */ +static err_t bli_strsm_small_AutXB( + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + int m = bli_obj_width(a); // number of rows of matrix A (since At, so width is taken) + int n = bli_obj_width(b); // number of columns of matrix B + + int lda = bli_obj_col_stride(a); // column stride of A + int ldb = bli_obj_col_stride(b); // column stride of B + + int rsa = bli_obj_row_stride(a); // row stride of A + int rsb = bli_obj_row_stride(b); // row stride of B + + int i = 0; + int isUnitDiag = bli_obj_has_unit_diag(a); + + float alphaVal; + float *L = a->buffer; + float *B = b->buffer; + + if ((m&7) != 0 || (n&7) != 0) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + if ( m > BLI_AutXB_M_SP || n > BLI_AutXB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + + alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); + + if (alphaVal != 1) + { + if (isUnitDiag == 0) + { + trsm_AutXB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); + } + else + { + trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); + } + } + else + { + if (isUnitDiag == 0) + { + trsm_AutXB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); + } + else + { + trsm_AutXB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); + } + } + return BLIS_SUCCESS; +} + +///////////////////////////// AX=B /////////////////////////////// +static void blis_strsm_microkernel_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal) +{ + float ones = 1.0; + int j; + int cs_b_offset[6]; + //int row2, row4, row6; + float *ptr_b_dup; + + //70 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[8]; + __m256 mat_a_cols[8]; + __m256 mat_a_cols_rearr[36]; + __m256 mat_a_diag_inv[8]; + __m256 reciprocal_diags; + __m256 alphaReg; + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + + //reciprocal_diags = _mm256_loadu_ps((float const *)ones); + reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); + alphaReg = _mm256_broadcast_ss((float const *)&alphaVal); + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); + //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); + //row2 = (cs_l << 1); + //row4 = (cs_l << 2); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); + //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); + //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); + //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); + //row6 = row2 + row4; + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); + //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); + //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); + //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); + //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); + + //reciprocal_diags = _mm256_loadu_ps((float const *)ones); + + //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L + /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ + + //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers + //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. + //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); + //1st col + mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); + mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); + mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); + mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); + mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); + mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); + mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); + mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); + //2nd col + ptr_l += cs_l; + mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //3rd col + ptr_l += cs_l; + mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //4rth col + ptr_l += cs_l; + mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //5th col + ptr_l += cs_l; + mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //6th col + ptr_l += cs_l; + mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //7th col + ptr_l += cs_l; + mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //7th col + ptr_l += cs_l; + mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + numCols_b -= 8; // blk_width = 8 + + //compute reciprocals of L(i,i) and broadcast in registers + mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); + mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); + mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); + mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); + + //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); + //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); + + //reciprocal of diagnol elements + reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); + + //Start loop for cols of B to be processed in size of blk_width + for (j = 0; j < numCols_b; j += 8) + { + ptr_b_dup = ptr_b; + + /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ + + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); + + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); + + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); + + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); + + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); + + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); + + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); + mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); +#else + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); + mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); + mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); + mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); + mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); + mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); + mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); + mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + //Read next set of B columns + ptr_b += (cs_b + cs_b_offset[5]); + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); + + //end loop of cols + } + + //Last block trsm processing + ptr_b_dup = ptr_b; + + /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ + + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); + + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); + + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); + + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); + + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); + + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); + + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); + mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); +#else + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); + mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); + mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); + mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); + mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); + mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); + mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); + mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); + + //end loop of cols +} + +static void blis_strsm_microkernel_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal) +{ + //float ones = 1.0; + int j; + int cs_b_offset[6]; + //int row2, row4, row6; + float *ptr_b_dup; + + //70 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[8]; + __m256 mat_a_cols[8]; + __m256 mat_a_cols_rearr[36]; + //__m256 mat_a_diag_inv[8]; + //__m256 reciprocal_diags; + __m256 alphaReg; + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + + //reciprocal_diags = _mm256_loadu_ps((float const *)ones); + //reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); + alphaReg = _mm256_broadcast_ss((float const *)&alphaVal); + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); + //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); + //row2 = (cs_l << 1); + //row4 = (cs_l << 2); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); + //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); + //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); + //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); + //row6 = row2 + row4; + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); + //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); + //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); + //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); + //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); + + //reciprocal_diags = _mm256_loadu_ps((float const *)ones); + + //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L + /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ + + //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers + //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. + //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); + //1st col + mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); + mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); + mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); + mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); + mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); + mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); + mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); + mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); + //2nd col + ptr_l += cs_l; + mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //3rd col + ptr_l += cs_l; + mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //4rth col + ptr_l += cs_l; + mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //5th col + ptr_l += cs_l; + mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //6th col + ptr_l += cs_l; + mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //7th col + ptr_l += cs_l; + mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //8th col + //ptr_l += cs_l; + //mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + numCols_b -= 8; // blk_width = 8 + + //compute reciprocals of L(i,i) and broadcast in registers + //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); + + //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); + //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); + //mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); + //mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); + //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); + + //reciprocal of diagnol elements + //reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); + + //Start loop for cols of B to be processed in size of blk_width + for (j = 0; j < numCols_b; j += 8) + { + ptr_b_dup = ptr_b; + + /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ + + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //extract diag a00 from a + //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); + //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); + + //extract diag a11 from a + //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); + //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); + + //extract diag a22 from a + //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); + //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); + + //extract diag a33 from a + //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); + //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); + + //extract diag a44 from a + //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); + //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); + + //extract diag a55 from a + //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); + //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); + + //extract diag a66 from a + //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); + //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); + //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); + mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); +#else + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); + mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); + mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); + mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); + mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); + mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); + mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); + mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + //Read next set of B columns + ptr_b += (cs_b + cs_b_offset[5]); + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); + + //end loop of cols + } + + //Last block trsm processing + ptr_b_dup = ptr_b; + + /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ + + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //extract diag a00 from a + //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); + //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); + + //extract diag a11 from a + //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); + //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); + + //extract diag a22 from a + //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); + //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); + + //extract diag a33 from a + //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); + //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); + + //extract diag a44 from a + //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); + //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); + + //extract diag a55 from a + //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); + //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); + + //extract diag a66 from a + //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); + //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); + //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); + mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); +#else + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); + mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); + mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); + mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); + mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); + mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); + mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); + mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); + + //end loop of cols +} + +static void blis_strsm_microkernel_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) +{ + //float ones = 1.0; + int j; + int cs_b_offset[6]; + //int row2, row4, row6; + float *ptr_b_dup; + + //70 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[8]; + __m256 mat_a_cols[8]; + __m256 mat_a_cols_rearr[36]; + //__m256 mat_a_diag_inv[8]; + //__m256 reciprocal_diags; + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + + //reciprocal_diags = _mm256_loadu_ps((float const *)ones); + //reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); + //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); + //row2 = (cs_l << 1); + //row4 = (cs_l << 2); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); + //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); + //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); + //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); + //row6 = row2 + row4; + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); + //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); + //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); + //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); + //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); + + //reciprocal_diags = _mm256_loadu_ps((float const *)ones); + + //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L + /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ + + //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers + //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. + //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); + //1st col + mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); + mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); + mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); + mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); + mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); + mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); + mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); + mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); + //2nd col + ptr_l += cs_l; + mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //3rd col + ptr_l += cs_l; + mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //4rth col + ptr_l += cs_l; + mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //5th col + ptr_l += cs_l; + mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //6th col + ptr_l += cs_l; + mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //7th col + ptr_l += cs_l; + mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //8th col + //ptr_l += cs_l; + //mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + numCols_b -= 8; // blk_width = 8 + + //compute reciprocals of L(i,i) and broadcast in registers + //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); + + //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); + //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); + //mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); + //mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); + //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); + + //reciprocal of diagnol elements + //reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); + + //Start loop for cols of B to be processed in size of blk_width + for (j = 0; j < numCols_b; j += 8) + { + ptr_b_dup = ptr_b; + + /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ + + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //extract diag a00 from a + //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); + //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + //extract diag a11 from a + //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); + //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); + + //extract diag a22 from a + //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); + //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); + + //extract diag a33 from a + //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); + //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); + + //extract diag a44 from a + //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); + //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); + + //extract diag a55 from a + //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); + //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); + + //extract diag a66 from a + //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); + //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); + //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); + mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); +#else + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); + mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); + mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); + mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); + mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); + mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); + mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); + mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + //Read next set of B columns + ptr_b += (cs_b + cs_b_offset[5]); + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); + //end loop of cols + } + + //Last block trsm processing + ptr_b_dup = ptr_b; + + /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ + + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //extract diag a00 from a + //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); + //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + //extract diag a11 from a + //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); + //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); + + //extract diag a22 from a + //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); + //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); + + //extract diag a33 from a + //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); + //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); + + //extract diag a44 from a + //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); + //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); + + //extract diag a55 from a + //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); + //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); + + //extract diag a66 from a + //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); + //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); + //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); + mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); +#else + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); + mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); + mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); + mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); + mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); + mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); + mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); + mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); + //end loop of cols +} + +static void blis_strsm_microkernel(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) +{ + float ones = 1.0; + int j; + int cs_b_offset[6]; + //int row2, row4, row6; + float *ptr_b_dup; + + //70 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[8]; + __m256 mat_a_cols[8]; + __m256 mat_a_cols_rearr[36]; + __m256 mat_a_diag_inv[8]; + __m256 reciprocal_diags; + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + + //reciprocal_diags = _mm256_loadu_ps((float const *)ones); + reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); + //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); + //row2 = (cs_l << 1); + //row4 = (cs_l << 2); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); + //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); + //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); + //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); + //row6 = row2 + row4; + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); + //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); + //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); + //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); + //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); + + //reciprocal_diags = _mm256_loadu_ps((float const *)ones); + + //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L + /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); + ptr_l += cs_l; + mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ + + //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers + //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. + //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); + //1st col + mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); + mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); + mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); + mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); + mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); + mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); + mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); + mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); + //2nd col + ptr_l += cs_l; + mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //3rd col + ptr_l += cs_l; + mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //4rth col + ptr_l += cs_l; + mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //5th col + ptr_l += cs_l; + mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //6th col + ptr_l += cs_l; + mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //7th col + ptr_l += cs_l; + mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + //7th col + ptr_l += cs_l; + mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + numCols_b -= 8; // blk_width = 8 + + //compute reciprocals of L(i,i) and broadcast in registers + mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); + mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); + mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); + mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); + + //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); + //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); + + //reciprocal of diagnol elements + reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); + + //Start loop for cols of B to be processed in size of blk_width + for (j = 0; j < numCols_b; j += 8) + { + ptr_b_dup = ptr_b; + + /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ + + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); + + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); + + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); + + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); + + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); + + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); + mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); +#else + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); + mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); + mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); + mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); + mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); + mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); + mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); + mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + //Read next set of B columns + ptr_b += (cs_b + cs_b_offset[5]); + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); + //end loop of cols + } + + //Last block trsm processing + ptr_b_dup = ptr_b; + + /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ + + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); + + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); + + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); + + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); + + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); + + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); + mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); +#else + mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); + mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); + mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); + mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); + mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); + mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); + mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); + mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); + mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); + //end loop of cols +} + +#if OPT_CACHE_BLOCKING_L1 //new intrinsic kernels +static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) +{ + float ones = 1.0; + int i, i1, i2, i3, i4, j, k, l, r; + int cs_b_offset[7]; + int cs_l_offset[7]; + float *ptr_b_dup, *ptr_l_dup; + + //57 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[8]; + __m256 mat_a_blk_elems[8]; + __m256 mat_a_diag_inv[8]; + __m256 reciprocal_diags[2]; + + reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //L matrix offsets + cs_l_offset[0] = (cs_l << 1); + cs_l_offset[1] = cs_l + cs_l_offset[0]; + cs_l_offset[2] = (cs_l << 2); + cs_l_offset[3] = cs_l + cs_l_offset[2]; + cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; + cs_l_offset[5] = cs_l + cs_l_offset[4]; + cs_l_offset[6] = (cs_l_offset[5] + cs_l); + + //read diag elems of L 16x16 block + mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); + mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); + mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); + mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); + mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); + mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); + mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); + mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + cs_b_offset[6] = (cs_b_offset[5] + cs_b); + + reciprocal_diags[1] = reciprocal_diags[0]; + + //pack first 8 diags together + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 + mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 + + //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 + reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); + + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); + + + /***************** first set of 8 rows of B processing starts *****************/ + ptr_b_dup = ptr_b; + i = 0; + for (j = 0; j < numCols_b; j += 8) + { + /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A + //read 8x8 block of B into registers + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); + + //i += cs_b_offset[6]; + //ptr_b_dup += cs_b_offset[6]; + i += 8; + ptr_b_dup += 8; + } + + //c = 0; + /***************** first set of 8 cols of B processing done *****************/ + ptr_b_dup = ptr_b; + i3 = 0; + i1 = 0; + //Start loop for cols of B to be processed in size of blk_width + for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row + { + ptr_l += 8; + //ptr_b += j; + //ptr_b_dup += 8; + ptr_b_dup += cs_b_offset[6]; + i1 += cs_b_offset[6]; + + //Read next 8x8 block of A to get diag elements + i3 += cs_l_offset[6]; + mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); + mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); + mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); + mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); + mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); + mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); + mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); + mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); + + //pack 8 diags of A together + reciprocal_diags[0] = reciprocal_diags[1]; + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 + mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 + + //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 + reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); + + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); + + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); + + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); + + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); + + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); + + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); + + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); + + for (r = 0; r < numCols_b; r += GEMM_BLK_V1) + { +#if GEMM_ACCUM_A + i = i1 + r; + //Read 8 cols of B columns of Block-to-be-solved + mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); +#endif + i = 0; + i2 = 0; + for (l = 0; l < j; l += 8) // move across m + { + //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) + { + /////////////////// Partial Lower 8x8 block trsm of B + ptr_l_dup = ptr_l; + i4 = i2 + r; + //Read current 8 cols of B columns from specified 8x8 current-block of B + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); + + //Broadcast A8,0 to A15,0 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + i4 = k >> 3; + ptr_l_dup += cs_l; + +#if GEMM_ACCUM_A + //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); + mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); + mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); + mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); + mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); + mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); + mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); + mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,2 to A15,2 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,3 to A15,3 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,4 to A15,4 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,5 to A15,5 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,6 to A15,6 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,7 to A15,7 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#endif + //end loop of cols + } + i2 += cs_b_offset[6]; + i += cs_l_offset[6]; + } + //trsm solve + + k = 0; + //for (i2 = 0; i2 < numCols_b; i2 += 8) + { + i2 = i1 + r; + /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A +#if !GEMM_ACCUM_A + //Read 8 cols of B columns of Block-to-be-solved + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); +#endif + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + +#if GEMM_ACCUM_A + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); +#else + mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); +#endif + +#if GEMM_ACCUM_A + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); + mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); + mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); + mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); + mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); + mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); + mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A76 to register + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); + //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); + k++; + } + } + } //numRows of A + ///////////////////loop ends ///////////////////// +} + +static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) +{ + float ones = 1.0; + int i, i1, i2, i3, i4, j, k, l, r; + int cs_b_offset[7]; + int cs_l_offset[7]; + float *ptr_b_dup, *ptr_l_dup; + + //57 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[8]; + __m256 mat_a_blk_elems[8]; + __m256 mat_a_diag_inv[8]; + __m256 reciprocal_diags[2]; + __m256 alphaReg; + + reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); + alphaReg = _mm256_broadcast_ss((float const *)&alpha); + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //L matrix offsets + cs_l_offset[0] = (cs_l << 1); + cs_l_offset[1] = cs_l + cs_l_offset[0]; + cs_l_offset[2] = (cs_l << 2); + cs_l_offset[3] = cs_l + cs_l_offset[2]; + cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; + cs_l_offset[5] = cs_l + cs_l_offset[4]; + cs_l_offset[6] = (cs_l_offset[5] + cs_l); + + //read diag elems of L 16x16 block + mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); + mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); + mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); + mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); + mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); + mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); + mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); + mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + cs_b_offset[6] = (cs_b_offset[5] + cs_b); + + reciprocal_diags[1] = reciprocal_diags[0]; + + //pack first 8 diags together + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 + mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 + + //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 + reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); +#if 0 + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); +#endif + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); + + + /***************** first set of 8 rows of B processing starts *****************/ + ptr_b_dup = ptr_b; + i = 0; + for (j = 0; j < numCols_b; j += 8) + { + /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A + //read 8x8 block of B into registers + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); + mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); + mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); + mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); + mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); + mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); + mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); + mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); + + //i += cs_b_offset[6]; + //ptr_b_dup += cs_b_offset[6]; + i += 8; + ptr_b_dup += 8; + } + + //c = 0; + /***************** first set of 8 cols of B processing done *****************/ + ptr_b_dup = ptr_b; + i3 = 0; + i1 = 0; + //Start loop for cols of B to be processed in size of blk_width + for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row + { + ptr_l += 8; + //ptr_b += j; + //ptr_b_dup += 8; + ptr_b_dup += cs_b_offset[6]; + i1 += cs_b_offset[6]; + + //Read next 8x8 block of A to get diag elements + i3 += cs_l_offset[6]; + mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); + mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); + mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); + mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); + mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); + mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); + mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); + mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); + + //pack 8 diags of A together + reciprocal_diags[0] = reciprocal_diags[1]; + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 + mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 + + //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 + reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); + + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); + + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); + + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); + + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); + + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); + + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); + + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); + + for (r = 0; r < numCols_b; r += GEMM_BLK_V1) + { +#if GEMM_ACCUM_A + i = i1 + r; + //Read 8 cols of B columns of Block-to-be-solved + mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); +#endif + i = 0; + i2 = 0; + for (l = 0; l < j; l += 8) // move across m + { + //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) + { + /////////////////// Partial Lower 8x8 block trsm of B + ptr_l_dup = ptr_l; + i4 = i2 + r; + //Read current 8 cols of B columns from specified 8x8 current-block of B + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); + + //Broadcast A8,0 to A15,0 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + i4 = k >> 3; + ptr_l_dup += cs_l; + +#if GEMM_ACCUM_A + //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); + mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); + mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); + mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); + mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); + mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); + mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); + mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,2 to A15,2 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,3 to A15,3 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,4 to A15,4 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,5 to A15,5 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,6 to A15,6 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,7 to A15,7 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#endif + //end loop of cols + } + i2 += cs_b_offset[6]; + i += cs_l_offset[6]; + } + //trsm solve + + k = 0; + //for (i2 = 0; i2 < numCols_b; i2 += 8) + { + i2 = i1 + r; + /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A +#if !GEMM_ACCUM_A + //Read 8 cols of B columns of Block-to-be-solved + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); + + mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); + mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); + mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); + mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); + mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); + mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); + mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); + mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); +#endif + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + +#if GEMM_ACCUM_A + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); +#else + mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); +#endif + +#if GEMM_ACCUM_A + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); + mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); + mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); + mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); + mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); + mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); + mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A76 to register + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + + _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); + //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); + k++; + } + } + } //numRows of A + ///////////////////loop ends ///////////////////// +} + +static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) +{ + //float ones = 1.0; + int i, i1, i2, i3, i4, j, k, l, r; + int cs_b_offset[7]; + int cs_l_offset[7]; + float *ptr_b_dup, *ptr_l_dup; + + //57 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[8]; + __m256 mat_a_blk_elems[8]; + //__m256 mat_a_diag_inv[8]; + //__m256 reciprocal_diags[2]; + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //L matrix offsets + cs_l_offset[0] = (cs_l << 1); + cs_l_offset[1] = cs_l + cs_l_offset[0]; + cs_l_offset[2] = (cs_l << 2); + cs_l_offset[3] = cs_l + cs_l_offset[2]; + cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; + cs_l_offset[5] = cs_l + cs_l_offset[4]; + cs_l_offset[6] = (cs_l_offset[5] + cs_l); + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + cs_b_offset[6] = (cs_b_offset[5] + cs_b); + + /***************** first set of 8 rows of B processing starts *****************/ + ptr_b_dup = ptr_b; + i = 0; + for (j = 0; j < numCols_b; j += 8) + { + /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A + //read 8x8 block of B into registers + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + //(Row0) + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); + + //i += cs_b_offset[6]; + //ptr_b_dup += cs_b_offset[6]; + i += 8; + ptr_b_dup += 8; + } + + //c = 0; + /***************** first set of 8 cols of B processing done *****************/ + ptr_b_dup = ptr_b; + i3 = 0; + i1 = 0; + //Start loop for cols of B to be processed in size of blk_width + for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row + { + ptr_l += 8; + //ptr_b += j; + //ptr_b_dup += 8; + ptr_b_dup += cs_b_offset[6]; + i1 += cs_b_offset[6]; + i3 += cs_l_offset[6]; + + i = 0; + i2 = 0; + for (r = 0; r < numCols_b; r += GEMM_BLK_V1) + { +#if GEMM_ACCUM_A + i = i1 + r; + //Read 8 cols of B columns of Block-to-be-solved + mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); +#endif + i = 0; + i2 = 0; + for (l = 0; l < j; l += 8) // move across m + { + //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) + { + /////////////////// Partial Lower 8x8 block trsm of B + ptr_l_dup = ptr_l; + i4 = i2 + r; + //Read current 8 cols of B columns from specified 8x8 current-block of B + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); + + //Broadcast A8,0 to A15,0 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + i4 = k >> 3; + ptr_l_dup += cs_l; + +#if GEMM_ACCUM_A + //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); + mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); + mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); + mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); + mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); + mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); + mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); + mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,2 to A15,2 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,3 to A15,3 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,4 to A15,4 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,5 to A15,5 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,6 to A15,6 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,7 to A15,7 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#endif + //end loop of cols + } + i2 += cs_b_offset[6]; + i += cs_l_offset[6]; + } + //trsm solve + + k = 0; + //for (i2 = 0; i2 < numCols_b; i2 += 8) + { + i2 = i1 + r; + /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A +#if !GEMM_ACCUM_A + //Read 8 cols of B columns of Block-to-be-solved + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); +#endif + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + +#if GEMM_ACCUM_A + //(Row0): already done +#else + mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); +#endif + +#if GEMM_ACCUM_A + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); + mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); + mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); + mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); + mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); + mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); + mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A76 to register + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); + //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); + k++; + } + } + } //numRows of A + ///////////////////loop ends ///////////////////// +} + +static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) +{ + //float ones = 1.0; + int i, i1, i2, i3, i4, j, k, l, r; + int cs_b_offset[7]; + int cs_l_offset[7]; + float *ptr_b_dup, *ptr_l_dup; + + //57 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[8]; + __m256 mat_a_blk_elems[8]; + //__m256 mat_a_diag_inv[8]; + //__m256 reciprocal_diags[2]; + __m256 alphaReg; + alphaReg = _mm256_broadcast_ss((float const *)&alpha); + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //L matrix offsets + cs_l_offset[0] = (cs_l << 1); + cs_l_offset[1] = cs_l + cs_l_offset[0]; + cs_l_offset[2] = (cs_l << 2); + cs_l_offset[3] = cs_l + cs_l_offset[2]; + cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; + cs_l_offset[5] = cs_l + cs_l_offset[4]; + cs_l_offset[6] = (cs_l_offset[5] + cs_l); + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + cs_b_offset[6] = (cs_b_offset[5] + cs_b); + +#if 0 + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); +#endif + + + /***************** first set of 8 rows of B processing starts *****************/ + ptr_b_dup = ptr_b; + i = 0; + for (j = 0; j < numCols_b; j += 8) + { + /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A + //read 8x8 block of B into registers + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); + mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); + mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); + mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); + mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); + mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); + mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); + mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); + + //(Row0) + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); + + //i += cs_b_offset[6]; + //ptr_b_dup += cs_b_offset[6]; + i += 8; + ptr_b_dup += 8; + } + + //c = 0; + /***************** first set of 8 cols of B processing done *****************/ + ptr_b_dup = ptr_b; + i3 = 0; + i1 = 0; + //Start loop for cols of B to be processed in size of blk_width + for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row + { + ptr_l += 8; + //ptr_b += j; + //ptr_b_dup += 8; + ptr_b_dup += cs_b_offset[6]; + i1 += cs_b_offset[6]; + i3 += cs_l_offset[6]; + + i = 0; + i2 = 0; + for (r = 0; r < numCols_b; r += GEMM_BLK_V1) + { +#if GEMM_ACCUM_A + i = i1 + r; + //Read 8 cols of B columns of Block-to-be-solved + mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); +#endif + i = 0; + i2 = 0; + for (l = 0; l < j; l += 8) // move across m + { + //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) + { + /////////////////// Partial Lower 8x8 block trsm of B + ptr_l_dup = ptr_l; + i4 = i2 + r; + //Read current 8 cols of B columns from specified 8x8 current-block of B + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); + + //Broadcast A8,0 to A15,0 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + i4 = k >> 3; + ptr_l_dup += cs_l; + +#if GEMM_ACCUM_A + //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); + mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); + mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); + mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); + mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); + mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); + mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); + mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,2 to A15,2 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,3 to A15,3 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,4 to A15,4 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,5 to A15,5 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,6 to A15,6 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,7 to A15,7 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); + ptr_l_dup += cs_l; +#if GEMM_ACCUM_A + //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#endif + //end loop of cols + } + i2 += cs_b_offset[6]; + i += cs_l_offset[6]; + } + //trsm solve + + k = 0; + //for (i2 = 0; i2 < numCols_b; i2 += 8) + { + i2 = i1 + r; + /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A +#if !GEMM_ACCUM_A + //Read 8 cols of B columns of Block-to-be-solved + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); + + mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); + mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); + mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); + mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); + mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); + mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); + mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); + mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); +#endif + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + +#if GEMM_ACCUM_A + //(Row0): already done + +#else + mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); +#endif + +#if GEMM_ACCUM_A + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); + mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); + mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); + mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); + mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); + mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); + mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A76 to register + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); + //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); + k++; + } + } + } //numRows of A + ///////////////////loop ends ///////////////////// +} +#else //rel 1.0 intrisic kernels (NOT OPT_CACHE_BLOCKING_L1) +static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) +{ + float ones = 1.0; + int i, i1, i2, i3, i4, j, k, l; + int cs_b_offset[7]; + int cs_l_offset[7]; + float *ptr_b_dup; + + //57 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[16][8]; + __m256 mat_a_cols_rearr[8]; + __m256 mat_a_blk_elems[64]; + __m256 mat_a_diag_inv[8]; + __m256 reciprocal_diags[2]; + + reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //L matrix offsets + cs_l_offset[0] = (cs_l << 1); + cs_l_offset[1] = cs_l + cs_l_offset[0]; + cs_l_offset[2] = (cs_l << 2); + cs_l_offset[3] = cs_l + cs_l_offset[2]; + cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; + cs_l_offset[5] = cs_l + cs_l_offset[4]; + cs_l_offset[6] = (cs_l_offset[5] + cs_l); + + //read diag elems of L 16x16 block + mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l); + mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); + mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); + mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); + mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); + mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); + mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); + mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + cs_b_offset[6] = (cs_b_offset[5] + cs_b); + + reciprocal_diags[1] = reciprocal_diags[0]; + + //pack first 8 diags together + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 + mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 + + //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 + reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); + + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); + + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); + + + /***************** first set of 8 rows of B processing starts *****************/ + ptr_b_dup = ptr_b; + i = 0; + for (j = 0; j < numCols_b; j += 8) + { + /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A + //read 8x8 block of B into registers + mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) + mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) + mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) + mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]); + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); + + //i += cs_b_offset[6]; + //ptr_b_dup += cs_b_offset[6]; + i += 8; + ptr_b_dup += 8; + } + + //c = 0; + /***************** first set of 8 cols of B processing done *****************/ + ptr_b_dup = ptr_b; + i3 = 0; + i1 = 0; + //Start loop for cols of B to be processed in size of blk_width + for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row + { + ptr_l += 8; + //ptr_b += j; + //ptr_b_dup += 8; + ptr_b_dup += cs_b_offset[6]; + i1 += cs_b_offset[6]; + + //Read next 8x8 block of A to get diag elements + i3 += cs_l_offset[6]; + mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3); + mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); + mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); + mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); + mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); + mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); + mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); + mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); + + //pack 8 diags of A together + reciprocal_diags[0] = reciprocal_diags[1]; + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 + mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 + + //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 + reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); + + i = 0; + i2 = 0; + for (k = 0; k < numCols_b; k += 8) + { + i = i1 + k; + //Read 8 cols of B columns of Block-to-be-solved + mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + i2++; + } + + i = 0; + i2 = 0; + for (l = 0; l < j; l += 8) // move across m + { + //Broadcast A8,0 to A15,0 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); + + //Broadcast A8,2 to A15,2 to registers + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); + + //Broadcast A8,3 to A15,3 to registers + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); + mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); + mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); + mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); + mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); + + // _mm256_permute2f128_ps() + + //Broadcast A8,4 to A15,4 to registers + mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); + mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); + mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); + mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); + mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); + mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); + mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); + mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); + + //Broadcast A8,5 to A15,5 to registers + mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); + mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); + mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); + mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); + mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); + mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); + mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); + mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); + + //Broadcast A8,6 to A15,6 to registers + mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); + mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); + mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); + mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); + mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); + mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); + mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); + mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); + + //Broadcast A8,7 to A15,7 to registers + mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); + mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); + mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); + mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); + mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); + mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); + mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); + mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); + + i += cs_l_offset[6]; + + + for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) + { + /////////////////// Partial Lower 8x8 block trsm of B + + i4 = i2 + k; + //Read current 8 cols of B columns from specified 8x8 current-block of B + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); + + i4 = k >> 3; + + //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) + + //end loop of cols + } + i2 += cs_b_offset[6]; + } + + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); + + k = 0; + for (i = 0; i < numCols_b; i+=8) + { + /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) + mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) + mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) + mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]); + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + + _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); + //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); + k++; + } + + + } + ///////////////////loop ends ///////////////////// +} + +static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) +{ + float ones = 1.0; + int i, i1, i2, i3, i4, j, k, l; + int cs_b_offset[7]; + int cs_l_offset[7]; + float *ptr_b_dup; + + //57 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[16][8]; + __m256 mat_a_cols_rearr[8]; + __m256 mat_a_blk_elems[64]; + __m256 mat_a_diag_inv[8]; + __m256 reciprocal_diags[2]; + __m256 alphaReg; + + reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); + alphaReg = _mm256_broadcast_ss((float const *)&alpha); + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //L matrix offsets + cs_l_offset[0] = (cs_l << 1); + cs_l_offset[1] = cs_l + cs_l_offset[0]; + cs_l_offset[2] = (cs_l << 2); + cs_l_offset[3] = cs_l + cs_l_offset[2]; + cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; + cs_l_offset[5] = cs_l + cs_l_offset[4]; + cs_l_offset[6] = (cs_l_offset[5] + cs_l); + + //read diag elems of L 16x16 block + mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l); + mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); + mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); + mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); + mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); + mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); + mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); + mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + cs_b_offset[6] = (cs_b_offset[5] + cs_b); + + reciprocal_diags[1] = reciprocal_diags[0]; + + //pack first 8 diags together + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 + mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 + + //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 + reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); + + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); + + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); + + + /***************** first set of 8 rows of B processing starts *****************/ + ptr_b_dup = ptr_b; + i = 0; + for (j = 0; j < numCols_b; j += 8) + { + /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A + //read 8x8 block of B into registers + mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg); + mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg); + mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg); + mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg); + mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg); + mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg); + mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg); + mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg); + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) + mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) + mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) + mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]); + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); + + //i += cs_b_offset[6]; + //ptr_b_dup += cs_b_offset[6]; + i += 8; + ptr_b_dup += 8; + } + + //c = 0; + /***************** first set of 8 cols of B processing done *****************/ + ptr_b_dup = ptr_b; + i3 = 0; + i1 = 0; + //Start loop for cols of B to be processed in size of blk_width + for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row + { + ptr_l += 8; + //ptr_b += j; + //ptr_b_dup += 8; + ptr_b_dup += cs_b_offset[6]; + i1 += cs_b_offset[6]; + + //Read next 8x8 block of A to get diag elements + i3 += cs_l_offset[6]; + mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3); + mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); + mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); + mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); + mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); + mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); + mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); + mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); + + //pack 8 diags of A together + reciprocal_diags[0] = reciprocal_diags[1]; + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 + mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 + + //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 + reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); + + i = 0; + i2 = 0; + for (k = 0; k < numCols_b; k += 8) + { + i = i1 + k; + //Read 8 cols of B columns of Block-to-be-solved + mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg); + mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg); + mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg); + mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg); + mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg); + mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg); + mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg); + mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg); + + i2++; + } + + i = 0; + i2 = 0; + for (l = 0; l < j; l += 8) // move across m + { + //Broadcast A8,0 to A15,0 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); + + //Broadcast A8,2 to A15,2 to registers + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); + + //Broadcast A8,3 to A15,3 to registers + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); + mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); + mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); + mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); + mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); + + // _mm256_permute2f128_ps() + + //Broadcast A8,4 to A15,4 to registers + mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); + mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); + mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); + mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); + mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); + mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); + mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); + mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); + + //Broadcast A8,5 to A15,5 to registers + mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); + mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); + mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); + mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); + mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); + mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); + mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); + mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); + + //Broadcast A8,6 to A15,6 to registers + mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); + mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); + mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); + mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); + mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); + mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); + mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); + mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); + + //Broadcast A8,7 to A15,7 to registers + mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); + mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); + mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); + mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); + mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); + mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); + mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); + mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); + + i += cs_l_offset[6]; + + + for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) + { + /////////////////// Partial Lower 8x8 block trsm of B + + i4 = i2 + k; + //Read current 8 cols of B columns from specified 8x8 current-block of B + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); + + i4 = k >> 3; + + //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) + + //end loop of cols + } + i2 += cs_b_offset[6]; + } + + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); + + k = 0; + for (i = 0; i < numCols_b; i+=8) + { + /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) + mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) + mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) + mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]); + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + + _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); + k++; + } + + + } + ///////////////////loop ends ///////////////////// +} + +static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) +{ + //float ones = 1.0; + int i, i1, i2, i3, i4, j, k, l; + int cs_b_offset[7]; + int cs_l_offset[7]; + float *ptr_b_dup; + + //57 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[16][8]; + //__m256 mat_a_cols_rearr[8]; + __m256 mat_a_blk_elems[64]; + //__m256 mat_a_diag_inv[8]; + //__m256 reciprocal_diags[2]; + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //L matrix offsets + cs_l_offset[0] = (cs_l << 1); + cs_l_offset[1] = cs_l + cs_l_offset[0]; + cs_l_offset[2] = (cs_l << 2); + cs_l_offset[3] = cs_l + cs_l_offset[2]; + cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; + cs_l_offset[5] = cs_l + cs_l_offset[4]; + cs_l_offset[6] = (cs_l_offset[5] + cs_l); + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + cs_b_offset[6] = (cs_b_offset[5] + cs_b); + + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); + + + /***************** first set of 8 rows of B processing starts *****************/ + ptr_b_dup = ptr_b; + i = 0; + for (j = 0; j < numCols_b; j += 8) + { + /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A + //read 8x8 block of B into registers + mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + //(Row0) + mat_b_col[0] = mat_b_rearr[0][0]; + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) + mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) + mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) + mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); + + //i += cs_b_offset[6]; + //ptr_b_dup += cs_b_offset[6]; + i += 8; + ptr_b_dup += 8; + } + + //c = 0; + /***************** first set of 8 cols of B processing done *****************/ + ptr_b_dup = ptr_b; + i3 = 0; + i1 = 0; + //Start loop for cols of B to be processed in size of blk_width + for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row + { + ptr_l += 8; + //ptr_b += j; + //ptr_b_dup += 8; + ptr_b_dup += cs_b_offset[6]; + i1 += cs_b_offset[6]; + i3 += cs_l_offset[6]; + + i = 0; + i2 = 0; + for (k = 0; k < numCols_b; k += 8) + { + i = i1 + k; + //Read 8 cols of B columns of Block-to-be-solved + mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + i2++; + } + + i = 0; + i2 = 0; + for (l = 0; l < j; l += 8) // move across m + { + //Broadcast A8,0 to A15,0 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); + + //Broadcast A8,2 to A15,2 to registers + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); + + //Broadcast A8,3 to A15,3 to registers + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); + mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); + mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); + mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); + mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); + + // _mm256_permute2f128_ps() + + //Broadcast A8,4 to A15,4 to registers + mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); + mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); + mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); + mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); + mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); + mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); + mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); + mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); + + //Broadcast A8,5 to A15,5 to registers + mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); + mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); + mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); + mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); + mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); + mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); + mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); + mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); + + //Broadcast A8,6 to A15,6 to registers + mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); + mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); + mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); + mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); + mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); + mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); + mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); + mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); + + //Broadcast A8,7 to A15,7 to registers + mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); + mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); + mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); + mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); + mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); + mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); + mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); + mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); + + i += cs_l_offset[6]; + + for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) + { + /////////////////// Partial Lower 8x8 block trsm of B + + i4 = i2 + k; + //Read current 8 cols of B columns from specified 8x8 current-block of B + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); + + i4 = k >> 3; + + //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) + + //end loop of cols + } + i2 += cs_b_offset[6]; + } + + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + + k = 0; + for (i = 0; i < numCols_b; i+=8) + { + /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A + + //(Row0): already done + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) + mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) + mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) + mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + + _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); + //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); + k++; + } + + + } + ///////////////////loop ends ///////////////////// +} + +static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) +{ + //float ones = 1.0; + int i, i1, i2, i3, i4, j, k, l; + int cs_b_offset[7]; + int cs_l_offset[7]; + float *ptr_b_dup; + + //57 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[16][8]; + //__m256 mat_a_cols_rearr[8]; + __m256 mat_a_blk_elems[64]; + //__m256 mat_a_diag_inv[8]; + //__m256 reciprocal_diags[2]; + __m256 alphaReg; + alphaReg = _mm256_broadcast_ss((float const *)&alpha); + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //L matrix offsets + cs_l_offset[0] = (cs_l << 1); + cs_l_offset[1] = cs_l + cs_l_offset[0]; + cs_l_offset[2] = (cs_l << 2); + cs_l_offset[3] = cs_l + cs_l_offset[2]; + cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; + cs_l_offset[5] = cs_l + cs_l_offset[4]; + cs_l_offset[6] = (cs_l_offset[5] + cs_l); + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + cs_b_offset[6] = (cs_b_offset[5] + cs_b); + + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); + + + /***************** first set of 8 rows of B processing starts *****************/ + ptr_b_dup = ptr_b; + i = 0; + for (j = 0; j < numCols_b; j += 8) + { + /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A + //read 8x8 block of B into registers + mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg); + mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg); + mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg); + mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg); + mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg); + mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg); + mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg); + mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg); + + //(Row0) + mat_b_col[0] = mat_b_rearr[0][0]; + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) + mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) + mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) + mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) + mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) + mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) + mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) + mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); + + //i += cs_b_offset[6]; + //ptr_b_dup += cs_b_offset[6]; + i += 8; + ptr_b_dup += 8; + } + + //c = 0; + /***************** first set of 8 cols of B processing done *****************/ + ptr_b_dup = ptr_b; + i3 = 0; + i1 = 0; + //Start loop for cols of B to be processed in size of blk_width + for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row + { + ptr_l += 8; + //ptr_b += j; + //ptr_b_dup += 8; + ptr_b_dup += cs_b_offset[6]; + i1 += cs_b_offset[6]; + i3 += cs_l_offset[6]; + + i = 0; + i2 = 0; + for (k = 0; k < numCols_b; k += 8) + { + i = i1 + k; + //Read 8 cols of B columns of Block-to-be-solved + mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg); + mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg); + mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg); + mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg); + mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg); + mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg); + mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg); + mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg); + + i2++; + } + + i = 0; + i2 = 0; + for (l = 0; l < j; l += 8) // move across m + { + //Broadcast A8,0 to A15,0 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); + + //Broadcast A8,2 to A15,2 to registers + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); + + //Broadcast A8,3 to A15,3 to registers + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); + mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); + mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); + mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); + mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); + + // _mm256_permute2f128_ps() + + //Broadcast A8,4 to A15,4 to registers + mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); + mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); + mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); + mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); + mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); + mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); + mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); + mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); + + //Broadcast A8,5 to A15,5 to registers + mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); + mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); + mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); + mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); + mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); + mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); + mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); + mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); + + //Broadcast A8,6 to A15,6 to registers + mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); + mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); + mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); + mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); + mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); + mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); + mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); + mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); + + //Broadcast A8,7 to A15,7 to registers + mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); + mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); + mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); + mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); + mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); + mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); + mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); + mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); + + i += cs_l_offset[6]; + + for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) + { + /////////////////// Partial Lower 8x8 block trsm of B + + i4 = i2 + k; + //Read current 8 cols of B columns from specified 8x8 current-block of B + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); + + i4 = k >> 3; + + //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) + + //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) + mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) + mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) + mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) + mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) + mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) + mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) + mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) + + //end loop of cols + } + i2 += cs_b_offset[6]; + } + + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + i += cs_l; + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); + + k = 0; + for (i = 0; i < numCols_b; i+=8) + { + /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A + + //(Row0): already done + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) + mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) + mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) + mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) + + //////////////////////////////////////////////////////////////////////////////// + + //Store the computed B columns + + _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); + //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); + k++; + } + + + } + ///////////////////loop ends ///////////////////// +} +#endif //OPT_CACHE_BLOCKING_L1 + +//////////////////////////// AutX=B /////////////////////// +static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) +{ + float ones = 1.0; + int i, i1, i2, i3, i4, j, k, l, r; + int cs_b_offset[7]; + int cs_l_offset[7]; + float *ptr_b_dup, *ptr_l_dup; + + //57 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[8]; + __m256 mat_a_blk_elems[8]; + __m256 mat_a_diag_inv[8]; + __m256 reciprocal_diags[2]; + + reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //L matrix offsets + cs_l_offset[0] = (cs_l << 1); + cs_l_offset[1] = cs_l + cs_l_offset[0]; + cs_l_offset[2] = (cs_l << 2); + cs_l_offset[3] = cs_l + cs_l_offset[2]; + cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; + cs_l_offset[5] = cs_l + cs_l_offset[4]; + cs_l_offset[6] = (cs_l_offset[5] + cs_l); + + //read diag elems of L 16x16 block + mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); + mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); + mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); + mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); + mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); + mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); + mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); + mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + cs_b_offset[6] = (cs_b_offset[5] + cs_b); + + reciprocal_diags[1] = reciprocal_diags[0]; + + //pack first 8 diags together + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 + mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 + + //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 + reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); +#if 0 + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); +#endif + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); + + + /***************** first set of 8 rows of B processing starts *****************/ + ptr_b_dup = ptr_b; + i = 0; + for (j = 0; j < numCols_b; j += 8) + { + /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A + //read 8x8 block of B into registers + mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + /* transpose steps end */ + + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); + + //////////////////////////////////////////////////////////////////////////////// + + /* transpose steps start */ + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + /* transpose steps end */ + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); + + i += cs_b_offset[6]; + ptr_b_dup += cs_b_offset[6]; + //i += 8; + //ptr_b_dup += 8; + } + + //c = 0; + /***************** first set of 8 cols of B processing done *****************/ + ptr_b_dup = ptr_b; + i3 = 0; + i1 = 0; + //Start loop for cols of B to be processed in size of blk_width + for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row + { + ptr_l += cs_l_offset[6]; + + //Read next 8x8 block of A to get diag elements + i3 += 8; + mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); + mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); + mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); + mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); + mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); + mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); + mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); + mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); + + //pack 8 diags of A together + reciprocal_diags[0] = reciprocal_diags[1]; + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 + mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 + + //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 + reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); + + //ptr_b += j; + //ptr_b_dup += 8; + ptr_b_dup += 8; + i1 += 8; + i = i1; + i2 = 0; + + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); + + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); + + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); + + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); + + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); + + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); + + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); + + for (r = 0; r < numCols_b; r += GEMM_BLK_V1) + { +#if GEMM_ACCUM_A + //Read 8 cols of B columns of Block-to-be-solved + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + /* transpose steps start */ + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + /* transpose steps end */ +#endif + //i = 0; + ptr_l_dup = ptr_l; + i4 = i2; + for (l = 0; l < j; l += 8) // move across m + { + //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) + //{ + /////////////////// Partial Lower 8x8 block trsm of B + //Read current 8 cols of B columns from specified 8x8 current-block of B + mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); + mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); + mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); + mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); + mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); + mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); + mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); + mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); + mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); + mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); + mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); + mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); + mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); + mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); +#else + mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); + mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); + mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); + mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); + mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); + mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); + /* transpose steps end */ + + //Broadcast A8,0 to A15,0 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + //i4 = k >> 3; + ptr_l_dup++; + +#if GEMM_ACCUM_A + //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); + mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); + mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); + mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); + mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); + mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); + mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); + mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,2 to A15,2 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,3 to A15,3 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,4 to A15,4 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,5 to A15,5 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,6 to A15,6 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,7 to A15,7 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#endif + //end loop of cols + //} + //i2 += cs_b_offset[6]; + i4 += 8; + } + //trsm solve + + k = 0; + //for (i2 = 0; i2 < numCols_b; i2 += 8) + //{ + //i2 = i1 + r; + /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A +#if !GEMM_ACCUM_A + //Read 8 cols of B columns of Block-to-be-solved + mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + /* transpose steps end */ +#endif + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + //i += cs_l; + +#if GEMM_ACCUM_A + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); +#else + mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); +#endif + +#if GEMM_ACCUM_A + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); + mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); + mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); + mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); + mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); + mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); + mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); + //i += cs_l; + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); + //i += cs_l; + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); + //i += cs_l; + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); + //i += cs_l; + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); + //i += cs_l; + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A76 to register + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); + + //////////////////////////////////////////////////////////////////////////////// + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + /* transpose steps end */ + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); + //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); + k++; + //} + i += cs_b_offset[6]; + i2 += cs_b_offset[6]; + } + } //numRows of A + ///////////////////loop ends ///////////////////// +} + +static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) +{ + float ones = 1.0; + int i, i1, i2, i3, i4, j, k, l, r; + int cs_b_offset[7]; + int cs_l_offset[7]; + float *ptr_b_dup, *ptr_l_dup; + + //57 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[8]; + __m256 mat_a_blk_elems[8]; + __m256 mat_a_diag_inv[8]; + __m256 reciprocal_diags[2]; + __m256 alphaReg; + + reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); + alphaReg = _mm256_broadcast_ss((float const *)&alpha); + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //L matrix offsets + cs_l_offset[0] = (cs_l << 1); + cs_l_offset[1] = cs_l + cs_l_offset[0]; + cs_l_offset[2] = (cs_l << 2); + cs_l_offset[3] = cs_l + cs_l_offset[2]; + cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; + cs_l_offset[5] = cs_l + cs_l_offset[4]; + cs_l_offset[6] = (cs_l_offset[5] + cs_l); + + //read diag elems of L 16x16 block + mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); + mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); + mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); + mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); + mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); + mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); + mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); + mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + cs_b_offset[6] = (cs_b_offset[5] + cs_b); + + reciprocal_diags[1] = reciprocal_diags[0]; + + //pack first 8 diags together + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 + mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 + + //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 + reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); +#if 0 + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); +#endif + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); + + + /***************** first set of 8 rows of B processing starts *****************/ + ptr_b_dup = ptr_b; + i = 0; + for (j = 0; j < numCols_b; j += 8) + { + /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A + //read 8x8 block of B into registers + mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + /* transpose steps end */ + + mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); + mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); + mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); + mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); + mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); + mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); + mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); + mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); + + //////////////////////////////////////////////////////////////////////////////// + + /* transpose steps start */ + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + /* transpose steps end */ + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); + + i += cs_b_offset[6]; + ptr_b_dup += cs_b_offset[6]; + //i += 8; + //ptr_b_dup += 8; + } + + //c = 0; + /***************** first set of 8 cols of B processing done *****************/ + ptr_b_dup = ptr_b; + i3 = 0; + i1 = 0; + //Start loop for cols of B to be processed in size of blk_width + for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row + { + ptr_l += cs_l_offset[6]; + + //Read next 8x8 block of A to get diag elements + i3 += 8; + mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); + mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); + mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); + mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); + mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); + mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); + mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); + mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); + + //pack 8 diags of A together + reciprocal_diags[0] = reciprocal_diags[1]; + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 + mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 + mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 + mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 + mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 + + //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 + reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); + + //ptr_b += j; + //ptr_b_dup += 8; + ptr_b_dup += 8; + i1 += 8; + i = i1; + i2 = 0; + + //extract diag a00 from a + mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); + //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); + + //extract diag a11 from a + mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); + //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); + + //extract diag a22 from a + mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); + //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); + + //extract diag a33 from a + mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); + //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); + + //extract diag a44 from a + mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); + mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); + //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); + + //extract diag a55 from a + mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); + mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); + //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); + + //extract diag a66 from a + mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); + mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); + //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); + + //extract diag a77 from a + mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); + mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); + //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); + + for (r = 0; r < numCols_b; r += GEMM_BLK_V1) + { +#if GEMM_ACCUM_A + //Read 8 cols of B columns of Block-to-be-solved + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + /* transpose steps start */ + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + /* transpose steps end */ + + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); +#endif + + //i = 0; + ptr_l_dup = ptr_l; + i4 = i2; + for (l = 0; l < j; l += 8) // move across m + { + //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) + //{ + /////////////////// Partial Lower 8x8 block trsm of B + //Read current 8 cols of B columns from specified 8x8 current-block of B + mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); + mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); + mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); + mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); + mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); + mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); + mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); + mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); + mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); + mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); + mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); + mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); + mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); + mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); +#else + mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); + mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); + mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); + mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); + mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); + mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); + /* transpose steps end */ + + //Broadcast A8,0 to A15,0 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + //i4 = k >> 3; + ptr_l_dup++; + +#if GEMM_ACCUM_A + //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); + mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); + mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); + mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); + mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); + mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); + mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); + mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,2 to A15,2 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,3 to A15,3 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,4 to A15,4 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,5 to A15,5 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,6 to A15,6 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,7 to A15,7 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#endif + //end loop of cols + //} + //i2 += cs_b_offset[6]; + i4 += 8; + } + //trsm solve + + k = 0; + //for (i2 = 0; i2 < numCols_b; i2 += 8) + //{ + //i2 = i1 + r; + /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A +#if !GEMM_ACCUM_A + //Read 8 cols of B columns of Block-to-be-solved + mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + /* transpose steps end */ + + mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); + mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); + mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); + mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); + mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); + mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); + mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); + mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); +#endif + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + //i += cs_l; + +#if GEMM_ACCUM_A + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); +#else + mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); +#endif + +#if GEMM_ACCUM_A + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); + mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); + mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); + mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); + mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); + mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); + mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); + //i += cs_l; + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); + //i += cs_l; + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); + //i += cs_l; + + //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); + //i += cs_l; + + //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); + //i += cs_l; + + //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A76 to register + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); + + //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); + + //////////////////////////////////////////////////////////////////////////////// + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + /* transpose steps end */ + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); + //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); + k++; + //} + i += cs_b_offset[6]; + i2 += cs_b_offset[6]; + } + } //numRows of A + ///////////////////loop ends ///////////////////// +} + +static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) +{ + //float ones = 1.0; + int i, i1, i2, i4, j, k, l, r; + int cs_b_offset[7]; + int cs_l_offset[7]; + float *ptr_b_dup, *ptr_l_dup; + + //57 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[8]; + __m256 mat_a_blk_elems[8]; + //__m256 mat_a_diag_inv[8]; + //__m256 reciprocal_diags[2]; + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //L matrix offsets + cs_l_offset[0] = (cs_l << 1); + cs_l_offset[1] = cs_l + cs_l_offset[0]; + cs_l_offset[2] = (cs_l << 2); + cs_l_offset[3] = cs_l + cs_l_offset[2]; + cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; + cs_l_offset[5] = cs_l + cs_l_offset[4]; + cs_l_offset[6] = (cs_l_offset[5] + cs_l); + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + cs_b_offset[6] = (cs_b_offset[5] + cs_b); + +#if 0 + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); +#endif + + + /***************** first set of 8 rows of B processing starts *****************/ + ptr_b_dup = ptr_b; + i = 0; + for (j = 0; j < numCols_b; j += 8) + { + /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A + //read 8x8 block of B into registers + mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + /* transpose steps end */ + + + //(Row0) + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) + + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) + + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) + + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) + + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) + + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) + + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) + + + + //////////////////////////////////////////////////////////////////////////////// + + /* transpose steps start */ + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + /* transpose steps end */ + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); + + i += cs_b_offset[6]; + ptr_b_dup += cs_b_offset[6]; + //i += 8; + //ptr_b_dup += 8; + } + + //c = 0; + /***************** first set of 8 cols of B processing done *****************/ + ptr_b_dup = ptr_b; + i1 = 0; + //Start loop for cols of B to be processed in size of blk_width + for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row + { + ptr_l += cs_l_offset[6]; + + + //ptr_b += j; + //ptr_b_dup += 8; + ptr_b_dup += 8; + i1 += 8; + i = i1; + i2 = 0; + + for (r = 0; r < numCols_b; r += GEMM_BLK_V1) + { +#if GEMM_ACCUM_A + //Read 8 cols of B columns of Block-to-be-solved + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + /* transpose steps start */ + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + /* transpose steps end */ +#endif + + //i = 0; + ptr_l_dup = ptr_l; + i4 = i2; + for (l = 0; l < j; l += 8) // move across m + { + //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) + //{ + /////////////////// Partial Lower 8x8 block trsm of B + //Read current 8 cols of B columns from specified 8x8 current-block of B + mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); + mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); + mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); + mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); + mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); + mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); + mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); + mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); + mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); + mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); + mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); + mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); + mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); + mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); +#else + mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); + mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); + mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); + mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); + mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); + mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); + /* transpose steps end */ + + //Broadcast A8,0 to A15,0 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + //i4 = k >> 3; + ptr_l_dup++; + +#if GEMM_ACCUM_A + //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); + mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); + mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); + mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); + mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); + mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); + mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); + mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,2 to A15,2 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,3 to A15,3 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,4 to A15,4 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,5 to A15,5 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,6 to A15,6 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,7 to A15,7 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#endif + //end loop of cols + //} + //i2 += cs_b_offset[6]; + i4 += 8; + } + //trsm solve + + k = 0; + //for (i2 = 0; i2 < numCols_b; i2 += 8) + //{ + //i2 = i1 + r; + /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A +#if !GEMM_ACCUM_A + //Read 8 cols of B columns of Block-to-be-solved + mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + /* transpose steps end */ +#endif + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + //i += cs_l; + +#if GEMM_ACCUM_A + //(Row0): already done + +#else + mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); +#endif + +#if GEMM_ACCUM_A + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); + mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); + mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); + mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); + mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); + mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); + mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); + //i += cs_l; + + + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); + //i += cs_l; + + + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); + //i += cs_l; + + + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); + //i += cs_l; + + + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); + //i += cs_l; + + + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A76 to register + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); + + + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + + + //////////////////////////////////////////////////////////////////////////////// + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + /* transpose steps end */ + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); + //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); + k++; + //} + i += cs_b_offset[6]; + i2 += cs_b_offset[6]; + } + } //numRows of A + ///////////////////loop ends ///////////////////// +} + +static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) +{ + //float ones = 1.0; + int i, i1, i2, i4, j, k, l, r; + int cs_b_offset[7]; + int cs_l_offset[7]; + float *ptr_b_dup, *ptr_l_dup; + + //57 number of ymm(256 bits) registers used + __m256 mat_b_col[8]; + __m256 mat_b_rearr[8]; + __m256 mat_a_blk_elems[8]; + //__m256 mat_a_diag_inv[8]; + //__m256 reciprocal_diags[2]; + __m256 alphaReg; + alphaReg = _mm256_broadcast_ss((float const *)&alpha); + + // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // + + //L matrix offsets + cs_l_offset[0] = (cs_l << 1); + cs_l_offset[1] = cs_l + cs_l_offset[0]; + cs_l_offset[2] = (cs_l << 2); + cs_l_offset[3] = cs_l + cs_l_offset[2]; + cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; + cs_l_offset[5] = cs_l + cs_l_offset[4]; + cs_l_offset[6] = (cs_l_offset[5] + cs_l); + + cs_b_offset[0] = (cs_b << 1); + cs_b_offset[1] = cs_b + cs_b_offset[0]; + cs_b_offset[2] = (cs_b << 2); + cs_b_offset[3] = cs_b + cs_b_offset[2]; + cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; + cs_b_offset[5] = cs_b + cs_b_offset[4]; + cs_b_offset[6] = (cs_b_offset[5] + cs_b); + +#if 0 + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); + + //Broadcast A21 to A71 to registers + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); + mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); + mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); + mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); + mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); + mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); + mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); + mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); + mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); + mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); + mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); + mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); + mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); + mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); + mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); + mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); + + //Broadcast A76 to register + mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); +#endif + + + /***************** first set of 8 rows of B processing starts *****************/ + ptr_b_dup = ptr_b; + i = 0; + for (j = 0; j < numCols_b; j += 8) + { + /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A + //read 8x8 block of B into registers + mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + /* transpose steps end */ + + mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); + mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); + mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); + mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); + mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); + mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); + mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); + mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); + + //(Row0) + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) + + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) + + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) + + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) + + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) + + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) + + + + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) + + + + //////////////////////////////////////////////////////////////////////////////// + + /* transpose steps start */ + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + /* transpose steps end */ + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); + + i += cs_b_offset[6]; + ptr_b_dup += cs_b_offset[6]; + //i += 8; + //ptr_b_dup += 8; + } + + //c = 0; + /***************** first set of 8 cols of B processing done *****************/ + ptr_b_dup = ptr_b; + i1 = 0; + //Start loop for cols of B to be processed in size of blk_width + for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row + { + ptr_l += cs_l_offset[6]; + + + //ptr_b += j; + //ptr_b_dup += 8; + ptr_b_dup += 8; + i1 += 8; + i = i1; + i2 = 0; + + for (r = 0; r < numCols_b; r += GEMM_BLK_V1) + { +#if GEMM_ACCUM_A + //Read 8 cols of B columns of Block-to-be-solved + mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + /* transpose steps start */ + ////unpacklow//// + mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); + mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); + mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); + mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + + ////unpackhigh//// + mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); + mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); + mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); + mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + /* transpose steps end */ + + mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); + mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); + mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); + mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); + mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); + mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); + mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); + mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); +#endif + + //i = 0; + ptr_l_dup = ptr_l; + i4 = i2; + for (l = 0; l < j; l += 8) // move across m + { + //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) + //{ + /////////////////// Partial Lower 8x8 block trsm of B + //Read current 8 cols of B columns from specified 8x8 current-block of B + mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); + mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); + mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); + mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); + mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); + mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); + mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); + mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); + mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); + mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); + mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); + mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); + mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); + mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); +#else + mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); + mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); + mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); + mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); + mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); + mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); + /* transpose steps end */ + + //Broadcast A8,0 to A15,0 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + //i4 = k >> 3; + ptr_l_dup++; + +#if GEMM_ACCUM_A + //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); + mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); + mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); + mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); + mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); + mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); + mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); + mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,2 to A15,2 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,3 to A15,3 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,4 to A15,4 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,5 to A15,5 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,6 to A15,6 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A8,7 to A15,7 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + ptr_l_dup++; +#if GEMM_ACCUM_A + //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) + mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) +#endif + //end loop of cols + //} + //i2 += cs_b_offset[6]; + i4 += 8; + } + //trsm solve + + k = 0; + //for (i2 = 0; i2 < numCols_b; i2 += 8) + //{ + //i2 = i1 + r; + /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A +#if !GEMM_ACCUM_A + //Read 8 cols of B columns of Block-to-be-solved + mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); + mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); + mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); + mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); + mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); + mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); + mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); + mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + /* transpose steps end */ + + mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); + mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); + mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); + mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); + mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); + mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); + mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); + mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); +#endif + //Broadcast A10 to A70 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); + mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); + //i += cs_l; + +#if GEMM_ACCUM_A + //(Row0): already done + +#else + mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); +#endif + +#if GEMM_ACCUM_A + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#else + mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); + mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); + mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); + mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); + mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); + mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); + mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) + mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) +#endif + //Broadcast A21 to A71 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); + mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); + //i += cs_l; + + + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A32 to A72 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); + mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); + //i += cs_l; + + + + //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) + mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A43 to A73 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); + mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); + //i += cs_l; + + + + //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) + mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A54 to A74 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); + mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); + //i += cs_l; + + + + //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) + mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A65 to A75 to registers + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); + mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); + //i += cs_l; + + + + //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) + mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) + + //Broadcast A76 to register + mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); + + + + //(Row7): FMA operations of b7 with elements of index (7, 0) + mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) + + + + //////////////////////////////////////////////////////////////////////////////// + + /* transpose steps start */ + ////unpacklow//// + mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange low elements +#if REARRANGE_SHFL == 1 + mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); + mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); +#else + mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); + mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); + mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); + mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); + mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); + mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); +#endif + //Merge rearranged low elements into complete rows + mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); + mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); + mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); + mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); + + ////unpackhigh//// + mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); + mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); + mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); + mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); + + //Rearrange high elements +#if REARRANGE_SHFL == 1 + mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); + mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); +#else + mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); + mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); + mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); + mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); + mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); + mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); +#endif + + //Merge rearranged high elements into complete rows + mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); + mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); + mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); + mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); + /* transpose steps end */ + + //Store the computed B columns + _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); + _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); + _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); + //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); + k++; + //} + i += cs_b_offset[6]; + i2 += cs_b_offset[6]; + } + } //numRows of A + ///////////////////loop ends ///////////////////// +} +#endif diff --git a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8.c b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8.c new file mode 100644 index 0000000000..03c1627f15 --- /dev/null +++ b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8.c @@ -0,0 +1,2303 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. +// outputs to ymm0 +#define CGEMM_INPUT_SCALE_CS_BETA_NZ \ + vmovlpd(mem(rcx), xmm0, xmm0) \ + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) \ + vmovlpd(mem(rcx, rsi, 2), xmm3, xmm3) \ + vmovhpd(mem(rcx, r13, 1), xmm3, xmm3) \ + vinsertf128(imm(1), xmm3, ymm0, ymm0) \ + vpermilps(imm(0xb1), ymm0, ymm3) \ + vmulps(ymm1, ymm0, ymm0) \ + vmulps(ymm2, ymm3, ymm3) \ + vaddsubps(ymm3, ymm0, ymm0) + +#define CGEMM_INPUT_SCALE_CS_BETA_NZ_128 \ + vmovlpd(mem(rcx), xmm0, xmm0) \ + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) \ + vpermilps(imm(0xb1), xmm0, xmm3) \ + vmulps(xmm1, xmm0, xmm0) \ + vmulps(xmm2, xmm3, xmm3) \ + vaddsubps(xmm3, xmm0, xmm0) + +// assumes values to output are in ymm0 +#define CGEMM_OUTPUT_GS \ + vextractf128(imm(1), ymm0, xmm3) \ + vmovlpd(xmm0, mem(rcx)) \ + vmovhpd(xmm0, mem(rcx, rsi, 1)) \ + vmovlpd(xmm3, mem(rcx, rsi, 2)) \ + vmovhpd(xmm3, mem(rcx, r13, 1)) + +#define CGEMM_INPUT_SCALE_RS_BETA_NZ \ + vmovups(mem(rcx), ymm0) \ + vpermilps(imm(0xb1), ymm0, ymm3) \ + vmulps(ymm1, ymm0, ymm0) \ + vmulps(ymm2, ymm3, ymm3) \ + vaddsubps(ymm3, ymm0, ymm0) + +#define CGEMM_OUTPUT_RS \ + vmovups(ymm0, mem(rcx)) \ + +#define CGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT \ + vmovups(mem(rcx, rsi, 8), ymm0) \ + vpermilps(imm(0xb1), ymm0, ymm3) \ + vmulps(ymm1, ymm0, ymm0) \ + vmulps(ymm2, ymm3, ymm3) \ + vaddsubps(ymm3, ymm0, ymm0) + +#define CGEMM_OUTPUT_RS_NEXT \ + vmovups(ymm0, mem(rcx, rsi, 8)) \ + + +void bli_cgemmsup_rv_zen_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* data, + cntx_t* cntx + ) +{ + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + vzeroall() // zero all xmm/ymm registers. + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of ymm6/7, ymm10/11 + vpermilps(imm(0xb1), ymm6, ymm6) + vpermilps(imm(0xb1), ymm7, ymm7) + vpermilps(imm(0xb1), ymm10, ymm10) + vpermilps(imm(0xb1), ymm11, ymm11) + + // subtract/add even/odd elements + vaddsubps(ymm6, ymm4, ymm4) + vaddsubps(ymm7, ymm5, ymm5) + + vaddsubps(ymm10, ymm8, ymm8) + vaddsubps(ymm11, ymm9, ymm9) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastss(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastss(mem(rax, 4), ymm1) // load alpha_i and duplicate + + vpermilps(imm(0xb1), ymm4, ymm3) + vmulps(ymm0, ymm4, ymm4) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm4, ymm4) + + vpermilps(imm(0xb1), ymm5, ymm3) + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm5, ymm5) + + vpermilps(imm(0xb1), ymm8, ymm3) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm8, ymm8) + + vpermilps(imm(0xb1), ymm9, ymm3) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm9, ymm9) + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomiss(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + CGEMM_INPUT_SCALE_RS_BETA_NZ + vaddps(ymm4, ymm0, ymm0) + CGEMM_OUTPUT_RS + + CGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddps(ymm5, ymm0, ymm0) + CGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 1*rs_c + + CGEMM_INPUT_SCALE_RS_BETA_NZ + vaddps(ymm8, ymm0, ymm0) + CGEMM_OUTPUT_RS + + CGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddps(ymm9, ymm0, ymm0) + CGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 2*rs_c + + jmp(.SDONE) // jump to end. + + + label(.SCOLSTORED) + /*|----------------| |-------| + | | | | | + | 2x4 | 2x4 | | 4x2 | + | | | |-------| + |----------------| | | + | 4x2 | + |-------| + */ + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm4, ymm0, ymm4) + add(rdi, rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm8, ymm0, ymm8) + add(rdi, rcx) + + lea(mem(r12, rsi, 4), rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm5, ymm0, ymm5) + add(rdi, rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm9, ymm0, ymm9) + add(rdi, rcx) + + mov(r12, rcx) // reset rcx to current utile of c. + vunpcklpd(ymm8, ymm4, ymm0) //a0a1b0b1 a4a4b4b5 //gamma00-10 gamma02-12 + vunpckhpd(ymm8, ymm4, ymm2) //a2a3b2b3 a6a7b6b7 //gamma01-11 gamma03-13 + + /******************Transpose top tile 4x3***************************/ + vmovups(xmm0, mem(rcx)) // store (gamma00-10) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma01-11) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm0, xmm0) + vextractf128(imm(0x1), ymm2, xmm2) + vmovups(xmm0, mem(rcx)) // store (gamma02-12) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma03-13) + lea(mem(rcx, rsi, 1), rcx) + + /******************Transpose bottom tile 4x3***************************/ + vunpcklpd(ymm9, ymm5, ymm0) //a8a9b8b9 a12a13b12b13 //gamma04-14 gamma06-16 + vunpckhpd(ymm9, ymm5, ymm2) //a10a11b10b11 a14a15b14b15 //gamma05-15 gamma07-17 + + vmovups(xmm0, mem(rcx)) // store (gamma04-14) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma05-15) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm0, xmm0) + vextractf128(imm(0x1), ymm2, xmm2) + vmovups(xmm0, mem(rcx)) // store (gamma06-16) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma07-17) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx, rsi, 8)) + add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + /****2x8 tile going to save into 8x2 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + + vunpcklpd(ymm8, ymm4, ymm0) //a0a1b0b1 a4a4b4b5 + vunpckhpd(ymm8, ymm4, ymm2) //a2a3b2b3 a6a7b6b7 + + /******************Transpose top tile 4x2***************************/ + vmovups(xmm0, mem(rcx)) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm0, xmm0) + vextractf128(imm(0x1), ymm2, xmm2) + vmovups(xmm0, mem(rcx)) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) + lea(mem(rcx, rsi, 1), rcx) + + /******************Transpose bottom tile 4x2***************************/ + vunpcklpd(ymm9, ymm5, ymm0) //a8a9b8b9 a12a13b12b13 + vunpckhpd(ymm9, ymm5, ymm2) //a10a11b10b11 a14a15b14b15 + + vmovups(xmm0, mem(rcx)) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm0, xmm0) + vextractf128(imm(0x1), ymm2, xmm2) + vmovups(xmm0, mem(rcx)) + lea(mem(rcx, rsi, 1), rcx) + + vmovups(xmm2, mem(rcx)) + + label(.SDONE) + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_cgemmsup_rv_zen_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* data, + cntx_t* cntx + ) +{ + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + vzeroall() // zero all xmm/ymm registers. + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of ymm6/7 + vpermilps(imm(0xb1), ymm6, ymm6) + vpermilps(imm(0xb1), ymm7, ymm7) + + // subtract/add even/odd elements + vaddsubps(ymm6, ymm4, ymm4) + vaddsubps(ymm7, ymm5, ymm5) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastss(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastss(mem(rax, 4), ymm1) // load alpha_i and duplicate + + vpermilps(imm(0xb1), ymm4, ymm3) + vmulps(ymm0, ymm4, ymm4) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm4, ymm4) + + vpermilps(imm(0xb1), ymm5, ymm3) + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm5, ymm5) + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomiss(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + CGEMM_INPUT_SCALE_RS_BETA_NZ + vaddps(ymm4, ymm0, ymm0) + CGEMM_OUTPUT_RS + + CGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddps(ymm5, ymm0, ymm0) + CGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 1*rs_c + + jmp(.SDONE) // jump to end. + + + label(.SCOLSTORED) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm4, ymm0, ymm4) + + lea(mem(r12, rsi, 4), rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm5, ymm0, ymm5) + + mov(r12, rcx) // reset rcx to current utile of c. + /******************Transpose top tile 4x1***************************/ + vmovlpd(xmm4, mem(rcx)) // store (gamma40) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm4, mem(rcx)) // store (gamma41) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vmovlpd(xmm4, mem(rcx)) // store (gamma42) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm4, mem(rcx)) // store (gamma43) + lea(mem(rcx, rsi, 1), rcx) + + /******************Transpose bottom tile 4x1***************************/ + + vmovlpd(xmm5, mem(rcx)) // store (gamma44) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm5, mem(rcx)) // store (gamma45) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm5, xmm5) + vmovlpd(xmm5, mem(rcx)) // store (gamma46) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm5, mem(rcx)) // store (gamma47) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi, 8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + /****1x8 tile going to save into 8x1 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + + /******************Transpose top tile 4x1***************************/ + vmovlpd(xmm4, mem(rcx)) // store (gamma40) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm4, mem(rcx)) // store (gamma41) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vmovlpd(xmm4, mem(rcx)) // store (gamma42) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm4, mem(rcx)) // store (gamma43) + lea(mem(rcx, rsi, 1), rcx) + + /******************Transpose bottom tile 4x1***************************/ + vmovlpd(xmm5, mem(rcx)) // store (gamma44) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm5, mem(rcx)) // store (gamma45) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm5, xmm5) + vmovlpd(xmm5, mem(rcx)) // store (gamma46) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm5, mem(rcx)) // store (gamma47) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_cgemmsup_rv_zen_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* data, + cntx_t* cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vzeroall() // zero all xmm/ymm registers. + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of ymm6/7 + vpermilps(imm(0xb1), ymm6, ymm6) + vpermilps(imm(0xb1), ymm10, ymm10) + + // subtract/add even/odd elements + vaddsubps(ymm6, ymm4, ymm4) + vaddsubps(ymm10, ymm8, ymm8) + + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastss(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastss(mem(rax, 4), ymm1) // load alpha_i and duplicate + + vpermilps(imm(0xb1), ymm4, ymm3) + vmulps(ymm0, ymm4, ymm4) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm4, ymm4) + + vpermilps(imm(0xb1), ymm8, ymm3) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm8, ymm8) + + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomiss(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + CGEMM_INPUT_SCALE_RS_BETA_NZ + vaddps(ymm4, ymm0, ymm0) + CGEMM_OUTPUT_RS + + add(rdi, rcx) // rcx = c + 1*rs_c + + CGEMM_INPUT_SCALE_RS_BETA_NZ + vaddps(ymm8, ymm0, ymm0) + CGEMM_OUTPUT_RS + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm4, ymm0, ymm4) + add(rdi, rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm8, ymm0, ymm8) + + mov(r12, rcx) // reset rcx to current utile of c. + vunpcklpd(ymm8, ymm4, ymm0) //a0a1b0b1 a4a4b4b5 //gamma00-10 gamma02-12 + vunpckhpd(ymm8, ymm4, ymm2) //a2a3b2b3 a6a7b6b7 //gamma01-11 gamma03-13 + + /******************Transpose top tile 4x2***************************/ + vmovups(xmm0, mem(rcx)) // store (gamma00-10) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma01-11) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm0, xmm0) + vextractf128(imm(0x1), ymm2, xmm2) + vmovups(xmm0, mem(rcx)) // store (gamma02-12) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma03-13) + lea(mem(rcx, rsi, 1), rcx) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + /****2x4 tile going to save into 4x2 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + + vunpcklpd(ymm8, ymm4, ymm0) //a0a1b0b1 a4a4b4b5 //gamma00-10 gamma02-12 + vunpckhpd(ymm8, ymm4, ymm2) //a2a3b2b3 a6a7b6b7 //gamma01-11 gamma03-13 + + /******************Transpose top tile 4x3***************************/ + vmovups(xmm0, mem(rcx)) // store (gamma00-10) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma01-11) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm0, xmm0) + vextractf128(imm(0x1), ymm2, xmm2) + vmovups(xmm0, mem(rcx)) // store (gamma02-12) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma03-13) + + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_cgemmsup_rv_zen_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* data, + cntx_t* cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 + vpermilps(imm(0xb1), ymm6, ymm6) + + // subtract/add even/odd elements + vaddsubps(ymm6, ymm4, ymm4) + + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastss(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastss(mem(rax, 4), ymm1) // load alpha_i and duplicate + + vpermilps(imm(0xb1), ymm4, ymm3) + vmulps(ymm0, ymm4, ymm4) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm4, ymm4) + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomiss(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + CGEMM_INPUT_SCALE_RS_BETA_NZ + vaddps(ymm4, ymm0, ymm0) + CGEMM_OUTPUT_RS + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm4, ymm0, ymm4) + + mov(r12, rcx) // reset rcx to current utile of c. + + vmovlpd(xmm4, mem(rcx)) // store (gamma00-10) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm4, mem(rcx)) // store (gamma01-11) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vmovlpd(xmm4, mem(rcx)) // store (gamma02-12) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm4, mem(rcx)) // store (gamma03-13) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + + jmp(.SDONE) // jump to end. + + + label(.SCOLSTORBZ) + /****1x4 tile going to save into 4x1 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + + vmovlpd(xmm4, mem(rcx)) // store (gamma40) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm4, mem(rcx)) // store (gamma41) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + + vmovlpd(xmm4, mem(rcx)) // store (gamma42) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm4, mem(rcx)) // store (gamma43) + lea(mem(rcx, rsi, 1), rcx) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_cgemmsup_rv_zen_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* data, + cntx_t* cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorps(xmm4, xmm4, xmm4) + vxorps(xmm6, xmm6, xmm6) + vxorps(xmm8, xmm8, xmm8) + vxorps(xmm10, xmm10, xmm10) + vxorps(xmm12, xmm12, xmm12) + vxorps(xmm14, xmm14, xmm14) + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, r8, 1), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + vbroadcastss(mem(rax, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 1, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, r8, 1), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + vbroadcastss(mem(rax, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 1, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, r8, 1), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + vbroadcastss(mem(rax, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 1, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, r8, 1), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + vbroadcastss(mem(rax, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 1, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, r8, 1), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + vbroadcastss(mem(rax, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 1, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of xmm6/7 + vpermilps(imm(0xb1), xmm6, xmm6) + vpermilps(imm(0xb1), xmm10, xmm10) + + // subtract/add even/odd elements + vaddsubps(xmm6, xmm4, xmm4) + vaddsubps(xmm10, xmm8, xmm8) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastss(mem(rax), xmm0) // load alpha_r and duplicate + vbroadcastss(mem(rax, 4), xmm1) // load alpha_i and duplicate + + vpermilps(imm(0xb1), xmm4, xmm3) + vmulps(xmm0, xmm4, xmm4) + vmulps(xmm1, xmm3, xmm3) + vaddsubps(xmm3, xmm4, xmm4) + + vpermilps(imm(0xb1), xmm8, xmm3) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm1, xmm3, xmm3) + vaddsubps(xmm3, xmm8, xmm8) + + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rbx), xmm1) // load beta_r and duplicate + vbroadcastss(mem(rbx, 4), xmm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 2), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + // now avoid loading C if beta == 0 + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomiss(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vpermilps(imm(0xb1), xmm0, xmm3) + vmulps(xmm1, xmm0, xmm0) + vmulps(xmm2, xmm3, xmm3) + vaddsubps(xmm3, xmm0, xmm0) + + vaddps(xmm4, xmm0, xmm0) + + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + add(rdi, rcx) // rcx = c + 1*rs_c + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vpermilps(imm(0xb1), xmm0, xmm3) + vmulps(xmm1, xmm0, xmm0) + vmulps(xmm2, xmm3, xmm3) + vaddsubps(xmm3, xmm0, xmm0) + + vaddps(xmm8, xmm0, xmm0) + + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + jmp(.SDONE) // jump to end. + + + label(.SCOLSTORED) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + CGEMM_INPUT_SCALE_CS_BETA_NZ_128 + vaddps(xmm4, xmm0, xmm4) + add(rdi, rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ_128 + vaddps(xmm8, xmm0, xmm8) + + mov(r12, rcx) // reset rcx to current utile of c. + + vunpcklpd(xmm8, xmm4, xmm0) //a0a1b0b1 //gamma00-10 + vunpckhpd(xmm8, xmm4, xmm2) //a2a3b2b3 //gamma01-11 + + /******************Transpose top tile 4x3***************************/ + vmovups(xmm0, mem(rcx)) // store (gamma00-10) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma01-11) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm8, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + /****2x2 tile going to save into 4x2 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + + vunpcklpd(xmm8, xmm4, xmm0) //a0a1b0b1 //gamma00-10 + vunpckhpd(xmm8, xmm4, xmm2) //a2a3b2b3 //gamma01-11 + + /******************Transpose top tile 2x2***************************/ + vmovups(xmm0, mem(rcx)) // store (gamma00-10) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma01-11) + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_cgemmsup_rv_zen_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* data, + cntx_t* cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorps(xmm4, xmm4, xmm4) + vxorps(xmm6, xmm6, xmm6) + vxorps(xmm8, xmm8, xmm8) + vxorps(xmm10, xmm10, xmm10) + vxorps(xmm12, xmm12, xmm12) + vxorps(xmm14, xmm14, xmm14) + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, 4 ), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, 4 ), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, 4 ), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, 4 ), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, 4 ), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of xmm6 + vpermilps(imm(0xb1), xmm6, xmm6) + + // subtract/add even/odd elements + vaddsubps(xmm6, xmm4, xmm4) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastss(mem(rax), xmm0) // load alpha_r and duplicate + vbroadcastss(mem(rax, 4), xmm1) // load alpha_i and duplicate + + vpermilps(imm(0xb1), xmm4, xmm3) + vmulps(xmm0, xmm4, xmm4) + vmulps(xmm1, xmm3, xmm3) + vaddsubps(xmm3, xmm4, xmm4) + + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rbx), xmm1) // load beta_r and duplicate + vbroadcastss(mem(rbx, 4), xmm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 2), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + // now avoid loading C if beta == 0 + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomiss(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vpermilps(imm(0xb1), xmm0, xmm3) + vmulps(xmm1, xmm0, xmm0) + vmulps(xmm2, xmm3, xmm3) + vaddsubps(xmm3, xmm0, xmm0) + + vaddps(xmm4, xmm0, xmm0) + + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + jmp(.SDONE) // jump to end. + + + label(.SCOLSTORED) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + CGEMM_INPUT_SCALE_CS_BETA_NZ_128 + vaddps(xmm4, xmm0, xmm4) + + mov(r12, rcx) // reset rcx to current utile of c. + + vmovlpd(xmm4, mem(rcx)) // store (gamma40-50) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm4, mem(rcx)) // store (gamma41-51) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + /****1x2 tile going to save into 2x1 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + + vmovlpd(xmm4, mem(rcx)) // store (gamma40) + lea(mem(rcx, rsi, 1), rcx) + vmovhpd(xmm4, mem(rcx)) // store (gamma41) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} diff --git a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8m.c b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8m.c new file mode 100644 index 0000000000..8d10406a05 --- /dev/null +++ b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8m.c @@ -0,0 +1,1756 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. +// outputs to ymm0 +#define CGEMM_INPUT_SCALE_CS_BETA_NZ \ + vmovlpd(mem(rcx), xmm0, xmm0) \ + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) \ + vmovlpd(mem(rcx, rsi, 2), xmm3, xmm3) \ + vmovhpd(mem(rcx, r13, 1), xmm3, xmm3) \ + vinsertf128(imm(1), xmm3, ymm0, ymm0) \ + vpermilps(imm(0xb1), ymm0, ymm3) \ + vmulps(ymm1, ymm0, ymm0) \ + vmulps(ymm2, ymm3, ymm3) \ + vaddsubps(ymm3, ymm0, ymm0) + +#define CGEMM_INPUT_SCALE_CS_BETA_NZ_128 \ + vmovlpd(mem(rcx), xmm0, xmm0) \ + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) \ + vpermilps(imm(0xb1), xmm0, xmm3) \ + vmulps(xmm1, xmm0, xmm0) \ + vmulps(xmm2, xmm3, xmm3) \ + vaddsubps(xmm3, xmm0, xmm0) + +#define CGEMM_INPUT_SCALE_RS_BETA_NZ \ + vmovups(mem(rcx), ymm0) \ + vpermilps(imm(0xb1), ymm0, ymm3) \ + vmulps(ymm1, ymm0, ymm0) \ + vmulps(ymm2, ymm3, ymm3) \ + vaddsubps(ymm3, ymm0, ymm0) + +#define CGEMM_OUTPUT_RS \ + vmovups(ymm0, mem(rcx)) \ + +#define CGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT \ + vmovups(mem(rcx, rsi, 8), ymm0) \ + vpermilps(imm(0xb1), ymm0, ymm3) \ + vmulps(ymm1, ymm0, ymm0) \ + vmulps(ymm2, ymm3, ymm3) \ + vaddsubps(ymm3, ymm0, ymm0) + +#define CGEMM_OUTPUT_RS_NEXT \ + vmovups(ymm0, mem(rcx, rsi, 8)) \ + +/* + rrr: + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + + rcr: + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : +*/ +void bli_cgemmsup_rv_zen_asm_3x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t n_left = n0 % 8; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 3x?m kernels, as needed. + if (n_left ) + { + scomplex* cij = c; + scomplex* bj = b; + scomplex* ai = a; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_cgemmsup_rv_zen_asm_3x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_cgemmsup_rv_zen_asm_3x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_cgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vzeroall() // zero all xmm/ymm registers. + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 2, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 2, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 2, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 2, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + vbroadcastss(mem(rax, 4 ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 2, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 + vpermilps(imm(0xb1), ymm6, ymm6) + vpermilps(imm(0xb1), ymm7, ymm7) + vpermilps(imm(0xb1), ymm10, ymm10) + vpermilps(imm(0xb1), ymm11, ymm11) + vpermilps(imm(0xb1), ymm14, ymm14) + vpermilps(imm(0xb1), ymm15, ymm15) + + // subtract/add even/odd elements + vaddsubps(ymm6, ymm4, ymm4) + vaddsubps(ymm7, ymm5, ymm5) + + vaddsubps(ymm10, ymm8, ymm8) + vaddsubps(ymm11, ymm9, ymm9) + + vaddsubps(ymm14, ymm12, ymm12) + vaddsubps(ymm15, ymm13, ymm13) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastss(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastss(mem(rax, 4), ymm1) // load alpha_i and duplicate + + vpermilps(imm(0xb1), ymm4, ymm3) + vmulps(ymm0, ymm4, ymm4) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm4, ymm4) + + vpermilps(imm(0xb1), ymm5, ymm3) + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm5, ymm5) + + vpermilps(imm(0xb1), ymm8, ymm3) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm8, ymm8) + + vpermilps(imm(0xb1), ymm9, ymm3) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm9, ymm9) + + vpermilps(imm(0xb1), ymm12, ymm3) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm12, ymm12) + + vpermilps(imm(0xb1), ymm13, ymm3) + vmulps(ymm0, ymm13, ymm13) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm13, ymm13) + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomiss(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + CGEMM_INPUT_SCALE_RS_BETA_NZ + vaddps(ymm4, ymm0, ymm0) + CGEMM_OUTPUT_RS + + CGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddps(ymm5, ymm0, ymm0) + CGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 1*rs_c + + CGEMM_INPUT_SCALE_RS_BETA_NZ + vaddps(ymm8, ymm0, ymm0) + CGEMM_OUTPUT_RS + + CGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddps(ymm9, ymm0, ymm0) + CGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 2*rs_c + + CGEMM_INPUT_SCALE_RS_BETA_NZ + vaddps(ymm12, ymm0, ymm0) + CGEMM_OUTPUT_RS + + CGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddps(ymm13, ymm0, ymm0) + CGEMM_OUTPUT_RS_NEXT + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + /*|----------------| |-------| + | | | | | + | 3x4 | 3x4 | | 4x3 | + | | | |-------| + |----------------| | | + | 4x3 | + |-------| + */ + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm4, ymm0, ymm4) + + add(rdi, rcx) + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm8, ymm0, ymm8) + add(rdi, rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm12, ymm0, ymm12) + + lea(mem(r12, rsi, 4), rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm5, ymm0, ymm5) + add(rdi, rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm9, ymm0, ymm9) + add(rdi, rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm13, ymm0, ymm13) + + mov(r12, rcx) // reset rcx to current utile of c. + vunpcklpd(ymm8, ymm4, ymm0) //a0a1b0b1 a4a4b4b5 //gamma00-10 gamma02-12 + vunpckhpd(ymm8, ymm4, ymm2) //a2a3b2b3 a6a7b6b7 //gamma01-11 gamma03-13 + + /******************Transpose top tile 4x3***************************/ + vmovups(xmm0, mem(rcx)) // store (gamma00-10) + vmovlpd(xmm12, mem(rcx, 16)) // store (gamma20) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma01-11) + vmovhpd(xmm12, mem(rcx, 16)) // store (gamma21) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm0, xmm0) + vextractf128(imm(0x1), ymm2, xmm2) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm0, mem(rcx)) // store (gamma02-12) + vmovlpd(xmm12, mem(rcx, 16)) // store (gamma22) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma03-13) + vmovhpd(xmm12, mem(rcx, 16)) // store (gamma33) + lea(mem(rcx, rsi, 1), rcx) + + /******************Transpose bottom tile 4x3***************************/ + vunpcklpd(ymm9, ymm5, ymm0) //a8a9b8b9 a12a13b12b13 //gamma04-14 gamma06-16 + vunpckhpd(ymm9, ymm5, ymm2) //a10a11b10b11 a14a15b14b15 //gamma05-15 gamma07-17 + + vmovups(xmm0, mem(rcx)) // store (gamma04-14) + vmovlpd(xmm13, mem(rcx, 16)) // store (gamma24) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma05-15) + vmovhpd(xmm13, mem(rcx, 16)) // store (gamma25) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm0, xmm0) + vextractf128(imm(0x1), ymm2, xmm2) + vextractf128(imm(0x1), ymm13, xmm13) + vmovups(xmm0, mem(rcx)) // store (gamma06-16) + vmovlpd(xmm13, mem(rcx, 16)) // store (gamma26) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma07-17) + vmovhpd(xmm13, mem(rcx, 16)) // store (gamma27) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx, rsi, 8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + /****3x8 tile going to save into 8x3 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + + vunpcklpd(ymm8, ymm4, ymm0) //a0a1b0b1 a4a4b4b5 + vunpckhpd(ymm8, ymm4, ymm2) //a2a3b2b3 a6a7b6b7 + + /******************Transpose top tile 4x3***************************/ + vmovups(xmm0, mem(rcx)) + vmovlpd(xmm12, mem(rcx,16)) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) + vmovhpd(xmm12,mem(rcx,16)) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm0, xmm0) + vextractf128(imm(0x1), ymm2, xmm2) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm0, mem(rcx)) + vmovlpd(xmm12, mem(rcx, 16)) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) + vmovhpd(xmm12, mem(rcx, 16)) + lea(mem(rcx, rsi, 1), rcx) + + /******************Transpose bottom tile 4x3***************************/ + vunpcklpd(ymm9, ymm5, ymm0) //a8a9b8b9 a12a13b12b13 + vunpckhpd(ymm9, ymm5, ymm2) //a10a11b10b11 a14a15b14b15 + + vmovups(xmm0, mem(rcx)) + vmovlpd(xmm13, mem(rcx, 16)) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) + vmovhpd(xmm13, mem(rcx, 16)) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm0, xmm0) + vextractf128(imm(0x1), ymm2, xmm2) + vextractf128(imm(0x1), ymm13, xmm13) + vmovups(xmm0, mem(rcx)) + vmovlpd(xmm13, mem(rcx, 16)) + lea(mem(rcx, rsi, 1), rcx) + + vmovups(xmm2, mem(rcx)) + vmovhpd(xmm13, mem(rcx, 16)) + + label(.SDONE) + + lea(mem(r12, rdi, 2), r12) + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) + lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a + + dec(r11) // ii -= 1; + jne(.SLOOP3X8I) // iterate again if ii != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + scomplex* cij = c + i_edge*rs_c; + scomplex* ai = a + i_edge*rs_a; + scomplex* bj = b; + + cgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_cgemmsup_rv_zen_asm_1x8, + bli_cgemmsup_rv_zen_asm_2x8, + }; + + cgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + + } + +} + +void bli_cgemmsup_rv_zen_asm_3x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vzeroall() // zero all xmm/ymm registers. + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 2, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm14) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 2, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm14) + + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 2, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm14) + + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 2, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm14) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + + vbroadcastss(mem(rax, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 1, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 2, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm14) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 + vpermilps(imm(0xb1), ymm6, ymm6) + vpermilps(imm(0xb1), ymm10, ymm10) + vpermilps(imm(0xb1), ymm14, ymm14) + + // subtract/add even/odd elements + vaddsubps(ymm6, ymm4, ymm4) + + vaddsubps(ymm10, ymm8, ymm8) + + vaddsubps(ymm14, ymm12, ymm12) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastss(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastss(mem(rax, 4), ymm1) // load alpha_i and duplicate + + vpermilps(imm(0xb1), ymm4, ymm3) + vmulps(ymm0, ymm4, ymm4) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm4, ymm4) + + vpermilps(imm(0xb1), ymm8, ymm3) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm8, ymm8) + + vpermilps(imm(0xb1), ymm12, ymm3) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm1, ymm3, ymm3) + vaddsubps(ymm3, ymm12, ymm12) + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomiss(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + CGEMM_INPUT_SCALE_RS_BETA_NZ + vaddps(ymm4, ymm0, ymm0) + CGEMM_OUTPUT_RS + + add(rdi, rcx) // rcx = c + 1*rs_c + + CGEMM_INPUT_SCALE_RS_BETA_NZ + vaddps(ymm8, ymm0, ymm0) + CGEMM_OUTPUT_RS + + add(rdi, rcx) // rcx = c + 2*rs_c + + CGEMM_INPUT_SCALE_RS_BETA_NZ + vaddps(ymm12, ymm0, ymm0) + CGEMM_OUTPUT_RS + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + /*|--------| |-------| + | | | | + | 3x4 | | 4x3 | + |--------| |-------| + */ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm4, ymm0, ymm4) + add(rdi, rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm8, ymm0, ymm8) + add(rdi, rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ + vaddps(ymm12, ymm0, ymm12) + + mov(r12, rcx) // reset rcx to current utile of c. + vunpcklpd(ymm8, ymm4, ymm0) //a0a1b0b1 a4a4b4b5 //gamma00-10 gamma02-12 + vunpckhpd(ymm8, ymm4, ymm2) //a2a3b2b3 a6a7b6b7 //gamma01-11 gamma03-13 + + /******************Transpose tile 4x3***************************/ + vmovups(xmm0, mem(rcx)) // store (gamma00-10) + vmovlpd(xmm12, mem(rcx, 16)) // store (gamma20) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma01-11) + vmovhpd(xmm12, mem(rcx, 16)) // store (gamma21) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm0, xmm0) + vextractf128(imm(0x1), ymm2, xmm2) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm0, mem(rcx)) // store (gamma02-12) + vmovlpd(xmm12, mem(rcx, 16)) // store (gamma22) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma03-13) + vmovhpd(xmm12, mem(rcx, 16)) // store (gamma33) + lea(mem(rcx, rsi, 1), rcx) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm12, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + /****3x4 tile going to save into 4x3 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + + vunpcklpd(ymm8, ymm4, ymm0) //a0a1b0b1 a4a4b4b5 + vunpckhpd(ymm8, ymm4, ymm2) //a2a3b2b3 a6a7b6b7 + + vmovups(xmm0, mem(rcx)) + vmovlpd(xmm12, mem(rcx, 16)) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) + vmovhpd(xmm12, mem(rcx, 16)) + lea(mem(rcx, rsi, 1), rcx) + + vextractf128(imm(0x1), ymm0, xmm0) + vextractf128(imm(0x1), ymm2, xmm2) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm0, mem(rcx)) + vmovlpd(xmm12, mem(rcx, 16)) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) + vmovhpd(xmm12, mem(rcx, 16)) + + label(.SDONE) + + lea(mem(r12, rdi, 2), r12) + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) + lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a + + dec(r11) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + scomplex* cij = c + i_edge*rs_c; + scomplex* ai = a + i_edge*rs_a; + scomplex* bj = b; + + cgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_cgemmsup_rv_zen_asm_1x4, + bli_cgemmsup_rv_zen_asm_2x4, + }; + + cgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + +void bli_cgemmsup_rv_zen_asm_3x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP3X2I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vzeroall() // zero all xmm/ymm registers. + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, r8, 1), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vfmadd231ps(xmm0, xmm2, xmm12) + + vbroadcastss(mem(rax, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 1, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 2, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm14) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, r8, 1), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vfmadd231ps(xmm0, xmm2, xmm12) + + vbroadcastss(mem(rax, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 1, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 2, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm14) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, r8, 1), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vfmadd231ps(xmm0, xmm2, xmm12) + + vbroadcastss(mem(rax, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 1, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 2, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm14) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, r8, 1), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vfmadd231ps(xmm0, xmm2, xmm12) + + vbroadcastss(mem(rax, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 1, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 2, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm14) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + vbroadcastss(mem(rax, r8, 1), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vfmadd231ps(xmm0, xmm2, xmm12) + + vbroadcastss(mem(rax, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 1, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 2, 4), xmm3) + vfmadd231ps(xmm0, xmm3, xmm14) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of xmm6/7, xmm10/11, xmm/14/15 + vpermilps(imm(0xb1), xmm6, xmm6) + vpermilps(imm(0xb1), xmm10, xmm10) + vpermilps(imm(0xb1), xmm14, xmm14) + + // subtract/add even/odd elements + vaddsubps(xmm6, xmm4, xmm4) + vaddsubps(xmm10, xmm8, xmm8) + vaddsubps(xmm14, xmm12, xmm12) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastss(mem(rax), xmm0) // load alpha_r and duplicate + vbroadcastss(mem(rax, 4), xmm1) // load alpha_i and duplicate + + vpermilps(imm(0xb1), xmm4, xmm3) + vmulps(xmm0, xmm4, xmm4) + vmulps(xmm1, xmm3, xmm3) + vaddsubps(xmm3, xmm4, xmm4) + + vpermilps(imm(0xb1), xmm8, xmm3) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm1, xmm3, xmm3) + vaddsubps(xmm3, xmm8, xmm8) + + vpermilps(imm(0xb1), xmm12, xmm3) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm1, xmm3, xmm3) + vaddsubps(xmm3, xmm12, xmm12) + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rbx), xmm1) // load beta_r and duplicate + vbroadcastss(mem(rbx, 4), xmm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 2), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomiss(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vpermilps(imm(0xb1), xmm0, xmm3) + vmulps(xmm1, xmm0, xmm0) + vmulps(xmm2, xmm3, xmm3) + vaddsubps(xmm3, xmm0, xmm0) + + vaddps(xmm4, xmm0, xmm0) + + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + add(rdi, rcx) // rcx = c + 1*rs_c + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vpermilps(imm(0xb1), xmm0, xmm3) + vmulps(xmm1, xmm0, xmm0) + vmulps(xmm2, xmm3, xmm3) + vaddsubps(xmm3, xmm0, xmm0) + + vaddps(xmm8, xmm0, xmm0) + + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + add(rdi, rcx) // rcx = c + 2*rs_c + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vpermilps(imm(0xb1), xmm0, xmm3) + vmulps(xmm1, xmm0, xmm0) + vmulps(xmm2, xmm3, xmm3) + vaddsubps(xmm3, xmm0, xmm0) + + vaddps(xmm12, xmm0, xmm0) + + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /*|--------| |-------| + | | | | + | 3x2 | | 2x3 | + | | |-------| + |--------| + */ + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + CGEMM_INPUT_SCALE_CS_BETA_NZ_128 + vaddps(xmm4, xmm0, xmm4) + add(rdi, rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ_128 + vaddps(xmm8, xmm0, xmm8) + add(rdi, rcx) + + CGEMM_INPUT_SCALE_CS_BETA_NZ_128 + vaddps(xmm12, xmm0, xmm12) + + mov(r12, rcx) // reset rcx to current utile of c. + vunpcklpd(xmm8, xmm4, xmm0) //a0a1b0b1 a4a4b4b5 //gamma00-10 gamma02-02 + vunpckhpd(xmm8, xmm4, xmm2) //a2a3b2b3 a6a7b6b7 //gamma01-11 gamma03-13 + + vmovups(xmm0, mem(rcx)) // store (gamma00-10) + vmovlpd(xmm12, mem(rcx, 16)) // store (gamma20) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma01-11) + vmovhpd(xmm12, mem(rcx, 16)) // store (gamma21) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm12, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + /****3x2 tile going to save into 2x3 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + + vunpcklpd(xmm8, xmm4, xmm0) //a0a1b0b1 a4a4b4b5 //gamma00-10 gamma02-02 + vunpckhpd(xmm8, xmm4, xmm2) //a2a3b2b3 a6a7b6b7 //gamma01-11 gamma03-13 + + vmovups(xmm0, mem(rcx)) // store (gamma00-10) + vmovlpd(xmm12, mem(rcx, 16)) // store (gamma20) + lea(mem(rcx, rsi, 1), rcx) + vmovups(xmm2, mem(rcx)) // store (gamma01-11) + vmovhpd(xmm12, mem(rcx, 16)) // store (gamma21) + + label(.SDONE) + + lea(mem(r12, rdi, 2), r12) + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) + lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a + + dec(r11) // ii -= 1; + jne(.SLOOP3X2I) // iterate again if ii != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + scomplex* cij = c + i_edge*rs_c; + scomplex* ai = a + i_edge*rs_a; + scomplex* bj = b; + + cgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_cgemmsup_rv_zen_asm_1x2, + bli_cgemmsup_rv_zen_asm_2x2, + }; + + cgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + + \ No newline at end of file diff --git a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8n.c b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8n.c new file mode 100644 index 0000000000..6c68707e18 --- /dev/null +++ b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8n.c @@ -0,0 +1,1581 @@ + +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include "immintrin.h" + +/* + rrr: + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + + rcr: + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : +*/ +void bli_cgemmsup_rv_zen_asm_3x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t m_left = m0 % 3; + if ( m_left ) + { + cgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_cgemmsup_rv_zen_asm_1x8n, + bli_cgemmsup_rv_zen_asm_2x8n, + }; + cgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + ker_fp + ( + conja, conjb, m_left, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; + } + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m256 ymm12, ymm13, ymm14, ymm15; + __m128 xmm0, xmm3; + + scomplex *tA = a; + float *tAimag = &a->imag; + scomplex *tB = b; + scomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 8; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); + ymm11 = _mm256_setzero_ps(); + ymm12 = _mm256_setzero_ps(); + ymm13 = _mm256_setzero_ps(); + ymm14 = _mm256_setzero_ps(); + ymm15 = _mm256_setzero_ps(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*8; + tC = c + n_iter*tc_inc_col*8; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_ps(ymm4, 0xb1); + ymm4 = _mm256_mul_ps(ymm0, ymm4); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_addsub_ps(ymm4, ymm3); + + ymm3 = _mm256_permute_ps(ymm5, 0xb1); + ymm5 = _mm256_mul_ps(ymm0, ymm5); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm5 = _mm256_addsub_ps(ymm5, ymm3); + + ymm3 = _mm256_permute_ps(ymm8, 0xb1); + ymm8 = _mm256_mul_ps(ymm0, ymm8); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm8 = _mm256_addsub_ps(ymm8, ymm3); + + ymm3 = _mm256_permute_ps(ymm9, 0xb1); + ymm9 = _mm256_mul_ps(ymm0, ymm9); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm9 = _mm256_addsub_ps(ymm9, ymm3); + + ymm3 = _mm256_permute_ps(ymm12, 0xb1); + ymm12 = _mm256_mul_ps(ymm0, ymm12); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm12 = _mm256_addsub_ps(ymm12, ymm3); + + ymm3 = _mm256_permute_ps(ymm13, 0xb1); + ymm13 = _mm256_mul_ps(ymm0, ymm13); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm13 = _mm256_addsub_ps(ymm13, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 3x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), _mm256_castps_pd (ymm8))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + ymm1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ) ,_mm256_extractf128_ps (ymm0,1)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12, 1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm1,1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12,1)); + + //transpose right 3x4 + tC += tc_inc_col; + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm5), _mm256_castps_pd(ymm9))); + _mm_storeu_ps((float *)(tC ),_mm256_castps256_ps128(ymm0)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm13)); + + ymm1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(ymm5), _mm256_castps_pd(ymm9))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm13)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ),_mm256_extractf128_ps (ymm0,1)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm13,1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ),_mm256_extractf128_ps (ymm1,1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm13,1)); + + } + else{ + ymm1 = _mm256_broadcast_ss((float const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load alpha_i and duplicate + + //Multiply ymm4 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC) ); + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm4 = _mm256_add_ps(ymm4, ymm0); + + //Multiply ymm8 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 1)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 1 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC + 1 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + 1 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm8 = _mm256_add_ps(ymm8, ymm0); + + //Multiply ymm12 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 2)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 2 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm12 = _mm256_add_ps(ymm12, ymm0); + + //transpose left 3x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), _mm256_castps_pd (ymm8))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + ymm3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm3)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm0,1)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12, 1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ),_mm256_extractf128_ps (ymm3,1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12,1)); + + //Multiply ymm5 with beta + tC += tc_inc_col; + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm5 = _mm256_add_ps(ymm5, ymm0); + + //Multiply ymm9 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC+ 1)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC+ 1 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC+ 1 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC+ 1 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm9 = _mm256_add_ps(ymm9, ymm0); + + //Multiply ymm13 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 2)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 2 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm13 = _mm256_add_ps(ymm13, ymm0); + + //transpose right 3x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm5), _mm256_castps_pd(ymm9))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm13)); + + ymm3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(ymm5), _mm256_castps_pd(ymm9))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ), _mm256_castps256_ps128(ymm3)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm13)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ),_mm256_extractf128_ps (ymm0,1)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm13,1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ),_mm256_extractf128_ps (ymm3,1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm13,1)); + } + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + 4), ymm5); + _mm256_storeu_ps((float*)(tC + tc_inc_row ), ymm8); + _mm256_storeu_ps((float*)(tC + tc_inc_row + 4), ymm9); + _mm256_storeu_ps((float*)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_ps((float*)(tC + tc_inc_row *2+ 4), ymm13); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_ss((float const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_ps((float const *)(tC)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_add_ps(ymm4, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+4)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm5 = _mm256_add_ps(ymm5, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm8 = _mm256_add_ps(ymm8, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row + 4)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm9 = _mm256_add_ps(ymm9, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row*2)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm12 = _mm256_add_ps(ymm12, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row*2 +4)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm13 = _mm256_add_ps(ymm13, _mm256_addsub_ps(ymm2, ymm3)); + + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + 4), ymm5); + _mm256_storeu_ps((float*)(tC + tc_inc_row) , ymm8); + _mm256_storeu_ps((float*)(tC + tc_inc_row + 4), ymm9); + _mm256_storeu_ps((float*)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_ps((float*)(tC + tc_inc_row *2+ 4), ymm13); + } + } + } + + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + scomplex* restrict cij = c + j_edge*cs_c; + scomplex* restrict ai = a; + scomplex* restrict bj = b + n_iter*8; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_cgemmsup_rv_zen_asm_3x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_cgemmsup_rv_zen_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 1 == n_left ) + { + bli_cgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } + +} + +void bli_cgemmsup_rv_zen_asm_2x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = 0; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m128 xmm0, xmm3; + + scomplex *tA = a; + float *tAimag = &a->imag; + scomplex *tB = b; + scomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 8; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); + ymm11 = _mm256_setzero_ps(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*8; + tC = c + n_iter*tc_inc_col*8; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_ps(ymm4, 0xb1); + ymm4 = _mm256_mul_ps(ymm0, ymm4); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_addsub_ps(ymm4, ymm3); + + ymm3 = _mm256_permute_ps(ymm5, 0xb1); + ymm5 = _mm256_mul_ps(ymm0, ymm5); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm5 = _mm256_addsub_ps(ymm5, ymm3); + + ymm3 = _mm256_permute_ps(ymm8, 0xb1); + ymm8 = _mm256_mul_ps(ymm0, ymm8); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm8 = _mm256_addsub_ps(ymm8, ymm3); + + ymm3 = _mm256_permute_ps(ymm9, 0xb1); + ymm9 = _mm256_mul_ps(ymm0, ymm9); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm9 = _mm256_addsub_ps(ymm9, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 2x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), _mm256_castps_pd (ymm8))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + + ymm1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm0,1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm1,1)); + + //transpose right 2x4 + tC += tc_inc_col; + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm5), _mm256_castps_pd(ymm9))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + + ymm1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(ymm5), _mm256_castps_pd(ymm9))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm0,1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm1,1)); + + } + else{ + ymm1 = _mm256_broadcast_ss((float const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load alpha_i and duplicate + + //Multiply ymm4 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm4 = _mm256_add_ps(ymm4, ymm0); + + //Multiply ymm8 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 1)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 1 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC + 1 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + 1 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1); + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm8 = _mm256_add_ps(ymm8, ymm0); + + //transpose left 2x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), _mm256_castps_pd (ymm8))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + + ymm3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm3)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm0,1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm3,1)); + + //Multiply ymm5 with beta + tC += tc_inc_col; + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm5 = _mm256_add_ps(ymm5, ymm0); + + //Multiply ymm9 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC+ 1)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC+ 1 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC+ 1 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC+ 1 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm9 = _mm256_add_ps(ymm9, ymm0); + + //transpose right 2x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm5), _mm256_castps_pd(ymm9))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + + ymm3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(ymm5), _mm256_castps_pd(ymm9))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm3)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm0,1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm3,1)); + + } + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + 4), ymm5); + _mm256_storeu_ps((float*)(tC + tc_inc_row) , ymm8); + _mm256_storeu_ps((float*)(tC + tc_inc_row + 4), ymm9); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_ss((float const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_ps((float const *)(tC)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_add_ps(ymm4, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+4)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm5 = _mm256_add_ps(ymm5, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm8 = _mm256_add_ps(ymm8, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row + 4)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm9 = _mm256_add_ps(ymm9, _mm256_addsub_ps(ymm2, ymm3)); + + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + 4), ymm5); + _mm256_storeu_ps((float*)(tC + tc_inc_row) , ymm8); + _mm256_storeu_ps((float*)(tC + tc_inc_row + 4), ymm9); + } + } + } + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + scomplex* restrict cij = c + j_edge*cs_c; + scomplex* restrict ai = a; + scomplex* restrict bj = b + n_iter * 8 ; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_cgemmsup_rv_zen_asm_2x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_cgemmsup_rv_zen_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_cgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } +} + +void bli_cgemmsup_rv_zen_asm_1x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = 0; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m128 xmm0, xmm3; + + scomplex *tA = a; + float *tAimag = &a->imag; + scomplex *tB = b; + scomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 8; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*8; + tC = c + n_iter*tc_inc_col*8; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_ps(ymm4, 0xb1); + ymm4 = _mm256_mul_ps(ymm0, ymm4); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_addsub_ps(ymm4, ymm3); + + ymm3 = _mm256_permute_ps(ymm5, 0xb1); + ymm5 = _mm256_mul_ps(ymm0, ymm5); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm5 = _mm256_addsub_ps(ymm5, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 1x4 + _mm_storel_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm4)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm4)); + + tC += tc_inc_col; + _mm_storel_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm4,1)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm4,1)); + + //transpose right 1x4 + tC += tc_inc_col; + _mm_storel_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm5)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm5)); + + tC += tc_inc_col; + _mm_storel_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm5,1)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm5,1)); + + } + else{ + ymm1 = _mm256_broadcast_ss((float const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load alpha_i and duplicate + + //Multiply ymm4 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm4 = _mm256_add_ps(ymm4, ymm0); + + _mm_storel_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm4)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm4)); + + tC += tc_inc_col; + _mm_storel_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm4,1)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm4,1)); + + //Multiply ymm5 with beta + tC += tc_inc_col; + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm5 = _mm256_add_ps(ymm5, ymm0); + + _mm_storel_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm5)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm5)); + + tC += tc_inc_col; + _mm_storel_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm5,1)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm5,1)); + + } + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + 4), ymm5); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_ss((float const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_ps((float const *)(tC)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_add_ps(ymm4, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+4)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm5 = _mm256_add_ps(ymm5, _mm256_addsub_ps(ymm2, ymm3)); + + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + 4), ymm5); + } + } + } + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + scomplex* restrict cij = c + j_edge*cs_c; + scomplex* restrict ai = a; + scomplex* restrict bj = b + n_iter * 8; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_cgemmsup_rv_zen_asm_1x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_cgemmsup_rv_zen_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ){ + bli_cgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } +} + + +void bli_cgemmsup_rv_zen_asm_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + + uint64_t k_iter = 0; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + // ------------------------------------------------------------------------- + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm6; + __m256 ymm8, ymm10; + __m256 ymm12, ymm14; + __m128 xmm0, xmm3; + + scomplex *tA = a; + float *tAimag = &a->imag; + scomplex *tB = b; + scomplex *tC = c; + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm8 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); + ymm12 = _mm256_setzero_ps(); + ymm14 = _mm256_setzero_ps(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tc_inc_col = cs_c; + + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_ps(ymm4, 0xb1); + ymm4 = _mm256_mul_ps(ymm0, ymm4); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_addsub_ps(ymm4, ymm3); + + ymm3 = _mm256_permute_ps(ymm8, 0xb1); + ymm8 = _mm256_mul_ps(ymm0, ymm8); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm8 = _mm256_addsub_ps(ymm8, ymm3); + + ymm3 = _mm256_permute_ps(ymm12, 0xb1); + ymm12 = _mm256_mul_ps(ymm0, ymm12); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm12 = _mm256_addsub_ps(ymm12, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose 3x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), _mm256_castps_pd (ymm8))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + ymm1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC),_mm256_extractf128_ps (ymm0,1)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12, 1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ) ,_mm256_extractf128_ps (ymm1,1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12,1)); + + } + else{ + ymm1 = _mm256_broadcast_ss((float const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load alpha_i and duplicate + + //Multiply ymm4 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm4 = _mm256_add_ps(ymm4, ymm0); + + //Multiply ymm8 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 1)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 1 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC + 1 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + 1 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm8 = _mm256_add_ps(ymm8, ymm0); + + //Multiply ymm12 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 2)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 2 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm12 = _mm256_add_ps(ymm12, ymm0); + + //transpose 3x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), _mm256_castps_pd (ymm8))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + ymm3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm3)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm0,1)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12, 1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ),_mm256_extractf128_ps (ymm3,1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12,1)); + } + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + tc_inc_row) , ymm8); + _mm256_storeu_ps((float*)(tC + tc_inc_row *2), ymm12); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_ss((float const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_ps((float const *)(tC)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_add_ps(ymm4, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm8 = _mm256_add_ps(ymm8, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row*2)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm12 = _mm256_add_ps(ymm12, _mm256_addsub_ps(ymm2, ymm3)); + + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + tc_inc_row) , ymm8); + _mm256_storeu_ps((float*)(tC + tc_inc_row *2), ymm12);; + } + } +} + +void bli_cgemmsup_rv_zen_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = 0; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + scomplex *tA = a; + float *tAimag = &a->imag; + scomplex *tB = b; + scomplex *tC = c; + // clear scratch registers. + __m128 xmm0, xmm1, xmm2, xmm3; + __m128 xmm4 = _mm_setzero_ps(); + __m128 xmm6 = _mm_setzero_ps(); + __m128 xmm8 = _mm_setzero_ps(); + __m128 xmm10 = _mm_setzero_ps(); + __m128 xmm12 = _mm_setzero_ps(); + __m128 xmm14 = _mm_setzero_ps(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tc_inc_col = cs_c; + + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + xmm3 = _mm_permute_ps(xmm4, 0xb1); + xmm4 = _mm_mul_ps(xmm0, xmm4); + xmm3 =_mm_mul_ps(xmm1, xmm3); + xmm4 = _mm_addsub_ps(xmm4, xmm3); + + xmm3 = _mm_permute_ps(xmm8, 0xb1); + xmm8 = _mm_mul_ps(xmm0, xmm8); + xmm3 = _mm_mul_ps(xmm1, xmm3); + xmm8 = _mm_addsub_ps(xmm8, xmm3); + + xmm3 = _mm_permute_ps(xmm12, 0xb1); + xmm12 = _mm_mul_ps(xmm0, xmm12); + xmm3 = _mm_mul_ps(xmm1, xmm3); + xmm12 = _mm_addsub_ps(xmm12, xmm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose 3x2 + xmm0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd (xmm4), _mm_castps_pd (xmm8))); + _mm_storeu_ps((float *)(tC ), xmm0); + _mm_storel_pi((__m64 *)(tC+2), xmm12); + + xmm1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd (xmm4) , _mm_castps_pd(xmm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ), xmm1); + _mm_storeh_pi((__m64 *)(tC+2), xmm12); + } + else{ + xmm1 = _mm_broadcast_ss((float const *)(beta)); // load alpha_r and duplicate + xmm2 = _mm_broadcast_ss((float const *)(&beta->imag)); // load alpha_i and duplicate + + //Multiply xmm4 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_permute_ps(xmm0, 0xb1); + xmm0 = _mm_mul_ps(xmm1, xmm0); + xmm3 = _mm_mul_ps(xmm2, xmm3); + xmm0 = _mm_addsub_ps(xmm0, xmm3); + xmm4 = _mm_add_ps(xmm4, xmm0); + + //Multiply xmm8 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 1)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 1 + tc_inc_col)) ; + xmm3 = _mm_permute_ps(xmm0, 0xb1); + xmm0 = _mm_mul_ps(xmm1, xmm0); + xmm3 = _mm_mul_ps(xmm2, xmm3); + xmm0 = _mm_addsub_ps(xmm0, xmm3); + xmm8 = _mm_add_ps(xmm8, xmm0); + + //Multiply xmm12 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 2)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 2 + tc_inc_col)) ; + xmm3 = _mm_permute_ps(xmm0, 0xb1); + xmm0 = _mm_mul_ps(xmm1, xmm0); + xmm3 = _mm_mul_ps(xmm2, xmm3); + xmm0 = _mm_addsub_ps(xmm0, xmm3); + xmm12 = _mm_add_ps(xmm12, xmm0); + + //transpose 3x2 + xmm0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd (xmm4), _mm_castps_pd (xmm8))); + _mm_storeu_ps((float *)(tC ), xmm0); + _mm_storel_pi((__m64 *)(tC+2), xmm12); + + xmm3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd (xmm4) , _mm_castps_pd(xmm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ), xmm3); + _mm_storeh_pi((__m64 *)(tC+2), xmm12); + + } + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm_storeu_ps((float *)(tC), xmm4); + _mm_storeu_ps((float *)(tC + tc_inc_row) , xmm8); + _mm_storeu_ps((float *)(tC + tc_inc_row *2), xmm12); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + xmm0 = _mm_broadcast_ss((float const *)(beta)); // load beta_r and duplicate + xmm1 = _mm_broadcast_ss((float const *)(&beta->imag)); // load beta_i and duplicate + + xmm2 = _mm_loadu_ps((float const *)(tC)); + xmm3 = _mm_permute_ps(xmm2, 0xb1); + xmm2 = _mm_mul_ps(xmm0, xmm2); + xmm3 = _mm_mul_ps(xmm1, xmm3); + xmm4 = _mm_add_ps(xmm4, _mm_addsub_ps(xmm2, xmm3)); + + xmm2 = _mm_loadu_ps((float const *)(tC+tc_inc_row)); + xmm3 = _mm_permute_ps(xmm2, 0xb1); + xmm2 = _mm_mul_ps(xmm0, xmm2); + xmm3 = _mm_mul_ps(xmm1, xmm3); + xmm8 = _mm_add_ps(xmm8, _mm_addsub_ps(xmm2, xmm3)); + + xmm2 = _mm_loadu_ps((float const *)(tC+tc_inc_row*2)); + xmm3 = _mm_permute_ps(xmm2, 0xb1); + xmm2 = _mm_mul_ps(xmm0, xmm2); + xmm3 = _mm_mul_ps(xmm1, xmm3); + xmm12 = _mm_add_ps(xmm12, _mm_addsub_ps(xmm2, xmm3)); + + _mm_storeu_ps((float *)(tC), xmm4); + _mm_storeu_ps((float *)(tC + tc_inc_row) , xmm8); + _mm_storeu_ps((float *)(tC + tc_inc_row *2), xmm12);; + } + } +} diff --git a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4.c b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4.c new file mode 100644 index 0000000000..1638eaba0b --- /dev/null +++ b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4.c @@ -0,0 +1,1658 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. +// outputs to ymm0 +#define ZGEMM_INPUT_SCALE_CS_BETA_NZ \ + vmovupd(mem(rcx), xmm0) \ + vmovupd(mem(rcx, rsi, 1), xmm3) \ + vinsertf128(imm(1), xmm3, ymm0, ymm0) \ + vpermilpd(imm(0x5), ymm0, ymm3) \ + vmulpd(ymm1, ymm0, ymm0) \ + vmulpd(ymm2, ymm3, ymm3) \ + vaddsubpd(ymm3, ymm0, ymm0) + +#define ZGEMM_INPUT_SCALE_RS_BETA_NZ \ + vmovupd(mem(rcx), ymm0) \ + vpermilpd(imm(0x5), ymm0, ymm3) \ + vmulpd(ymm1, ymm0, ymm0) \ + vmulpd(ymm2, ymm3, ymm3) \ + vaddsubpd(ymm3, ymm0, ymm0) + +#define ZGEMM_OUTPUT_RS \ + vmovupd(ymm0, mem(rcx)) \ + +#define ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT \ + vmovupd(mem(rcx, rsi, 8), ymm0) \ + vpermilpd(imm(0x5), ymm0, ymm3) \ + vmulpd(ymm1, ymm0, ymm0) \ + vmulpd(ymm2, ymm3, ymm3) \ + vaddsubpd(ymm3, ymm0, ymm0) + +#define ZGEMM_OUTPUT_RS_NEXT \ + vmovupd(ymm0, mem(rcx, rsi, 8)) \ + + +void bli_zgemmsup_rv_zen_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r8, 2), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vzeroall() // zero all xmm/ymm registers. + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + vbroadcastsd(mem(rax , 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 + vpermilpd(imm(0x5), ymm6, ymm6) + vpermilpd(imm(0x5), ymm7, ymm7) + vpermilpd(imm(0x5), ymm10, ymm10) + vpermilpd(imm(0x5), ymm11, ymm11) + + // subtract/add even/odd elements + vaddsubpd(ymm6, ymm4, ymm4) + vaddsubpd(ymm7, ymm5, ymm5) + + vaddsubpd(ymm10, ymm8, ymm8) + vaddsubpd(ymm11, ymm9, ymm9) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate + + vpermilpd(imm(0x5), ymm4, ymm3) + vmulpd(ymm0, ymm4, ymm4) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm4, ymm4) + + vpermilpd(imm(0x5), ymm5, ymm3) + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm5, ymm5) + + vpermilpd(imm(0x5), ymm8, ymm3) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm8, ymm8) + + vpermilpd(imm(0x5), ymm9, ymm3) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm9, ymm9) + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddpd(ymm5, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm8, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddpd(ymm9, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + /*|--------| |-------| + | | | | + | 2x4 | | 4x2 | + |--------| |-------| + */ + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) + + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm4, ymm0, ymm4) + + add(rdi, rcx) + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm8, ymm0, ymm8) + add(rdi, rcx) + + lea(mem(r12, rsi, 2), rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm5, ymm0, ymm5) + add(rdi, rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm9, ymm0, ymm9) + add(rdi, rcx) + + mov(r12, rcx) // reset rcx to current utile of c. + + + /****3x4 tile going to save into 4x2 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) + + /******************Transpose top tile 4x3***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + + add(rsi, rcx) + + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm5, xmm5) + vextractf128(imm(0x1), ymm9, xmm9) + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + /****2x4 tile going to save into 4x2 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) + + /******************Transpose tile 2x4***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + + add(rsi, rcx) + + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm5, xmm5) + vextractf128(imm(0x1), ymm9, xmm9) + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + +} + + +void bli_zgemmsup_rv_zen_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r8, 2), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vzeroall() // zero all xmm/ymm registers. + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + vbroadcastsd(mem(rax , 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 + vpermilpd(imm(0x5), ymm6, ymm6) + vpermilpd(imm(0x5), ymm7, ymm7) + + // subtract/add even/odd elements + vaddsubpd(ymm6, ymm4, ymm4) + vaddsubpd(ymm7, ymm5, ymm5) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate + + vpermilpd(imm(0x5), ymm4, ymm3) + vmulpd(ymm0, ymm4, ymm4) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm4, ymm4) + + vpermilpd(imm(0x5), ymm5, ymm3) + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm5, ymm5) + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddpd(ymm5, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + /*|--------| |-------| + | | | | + | 1x4 | | 4x1 | + |--------| |-------| + */ + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) + + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm4, ymm0, ymm4) + + lea(mem(r12, rsi, 2), rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm5, ymm0, ymm5) + + mov(r12, rcx) // reset rcx to current utile of c. + + /****1x4 tile going to save into 4x1 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) + + vmovups(xmm4, mem(rcx)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vmovups(xmm4, mem(rcx)) + + add(rsi, rcx) + + vmovups(xmm5, mem(rcx)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm5, xmm5) + vmovups(xmm5, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + /****1x4 tile going to save into 4x1 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) + + vmovups(xmm4, mem(rcx)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vmovups(xmm4, mem(rcx)) + + add(rsi, rcx) + + vmovups(xmm5, mem(rcx)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm5, xmm5) + vmovups(xmm5, mem(rcx)) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + +} + +void bli_zgemmsup_rv_zen_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r8, 2), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vzeroall() // zero all xmm/ymm registers. + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 + vpermilpd(imm(0x5), ymm6, ymm6) + vpermilpd(imm(0x5), ymm10, ymm10) + + // subtract/add even/odd elements + vaddsubpd(ymm6, ymm4, ymm4) + + vaddsubpd(ymm10, ymm8, ymm8) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate + + vpermilpd(imm(0x5), ymm4, ymm3) + vmulpd(ymm0, ymm4, ymm4) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm4, ymm4) + + vpermilpd(imm(0x5), ymm8, ymm3) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm8, ymm8) + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + add(rdi, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm8, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + /*|--------| |-------| + | | | | + | 2x2 | | 2x2 | + |--------| |-------| + */ + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) + + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm4, ymm0, ymm4) + + add(rdi, rcx) + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm8, ymm0, ymm8) + + mov(r12, rcx) // reset rcx to current utile of c. + + /****2x2 tile going to save into 2x2 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) + + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + /****2x2 tile going to save into 2x2 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) + + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_zgemmsup_rv_zen_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + + ) +{ + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r8, 2), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt) + +// lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vzeroall() // zero all xmm/ymm registers. + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 + vpermilpd(imm(0x5), ymm6, ymm6) + + // subtract/add even/odd elements + vaddsubpd(ymm6, ymm4, ymm4) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate + + vpermilpd(imm(0x5), ymm4, ymm3) + vmulpd(ymm0, ymm4, ymm4) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm4, ymm4) + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + /*|--------| |-------| + | | | | + | 1x2 | | 2x1 | + |--------| |-------| + */ + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) + + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm4, ymm0, ymm4) + + + /****3x4 tile going to save into 4x3 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) + + /******************Transpose tile 1x2***************************/ + vmovups(xmm4, mem(rcx)) + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vmovups(xmm4, mem(rcx)) + + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + /****1x2 tile going to save into 2x1 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) + + /******************Transpose top tile 4x3***************************/ + vmovups(xmm4, mem(rcx)) + + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vmovups(xmm4, mem(rcx)) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + +} + + diff --git a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4m.c b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4m.c new file mode 100644 index 0000000000..05e05dfece --- /dev/null +++ b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4m.c @@ -0,0 +1,1229 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. +// outputs to ymm0 +#define ZGEMM_INPUT_SCALE_CS_BETA_NZ \ + vmovupd(mem(rcx), xmm0) \ + vmovupd(mem(rcx, rsi, 1), xmm3) \ + vinsertf128(imm(1), xmm3, ymm0, ymm0) \ + vpermilpd(imm(0x5), ymm0, ymm3) \ + vmulpd(ymm1, ymm0, ymm0) \ + vmulpd(ymm2, ymm3, ymm3) \ + vaddsubpd(ymm3, ymm0, ymm0) + +#define ZGEMM_INPUT_SCALE_RS_BETA_NZ \ + vmovupd(mem(rcx), ymm0) \ + vpermilpd(imm(0x5), ymm0, ymm3) \ + vmulpd(ymm1, ymm0, ymm0) \ + vmulpd(ymm2, ymm3, ymm3) \ + vaddsubpd(ymm3, ymm0, ymm0) + +#define ZGEMM_OUTPUT_RS \ + vmovupd(ymm0, mem(rcx)) \ + +#define ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT \ + vmovupd(mem(rcx, rsi, 8), ymm0) \ + vpermilpd(imm(0x5), ymm0, ymm3) \ + vmulpd(ymm1, ymm0, ymm0) \ + vmulpd(ymm2, ymm3, ymm3) \ + vaddsubpd(ymm3, ymm0, ymm0) + +#define ZGEMM_OUTPUT_RS_NEXT \ + vmovupd(ymm0, mem(rcx, rsi, 8)) \ + +/* + rrr: + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + + rcr: + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | += ------ + -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : +*/ +void bli_zgemmsup_rv_zen_asm_3x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t n_left = n0 % 4; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 3x?m kernels, as needed. + if (n_left ) + { + dcomplex* cij = c; + dcomplex* bj = b; + dcomplex* ai = a; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_zgemmsup_rv_zen_asm_3x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(real dt) + lea(mem(, r8, 2), r8) // rs_a *= sizeof((real + imag) dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof( real dt) + lea(mem(, r9, 2), r9) // cs_a *= sizeof((real + imag) dt) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(real dt) + lea(mem(, r10, 2), r10) // rs_b *= sizeof((real +imag) dt) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vzeroall() // zero all xmm/ymm registers. + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + vbroadcastsd(mem(rax, 8 ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + vbroadcastsd(mem(rax, 8 ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 + vpermilpd(imm(0x5), ymm6, ymm6) + vpermilpd(imm(0x5), ymm7, ymm7) + vpermilpd(imm(0x5), ymm10, ymm10) + vpermilpd(imm(0x5), ymm11, ymm11) + vpermilpd(imm(0x5), ymm14, ymm14) + vpermilpd(imm(0x5), ymm15, ymm15) + + // subtract/add even/odd elements + vaddsubpd(ymm6, ymm4, ymm4) + vaddsubpd(ymm7, ymm5, ymm5) + + vaddsubpd(ymm10, ymm8, ymm8) + vaddsubpd(ymm11, ymm9, ymm9) + + vaddsubpd(ymm14, ymm12, ymm12) + vaddsubpd(ymm15, ymm13, ymm13) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate + + vpermilpd(imm(0x5), ymm4, ymm3) + vmulpd(ymm0, ymm4, ymm4) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm4, ymm4) + + vpermilpd(imm(0x5), ymm5, ymm3) + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm5, ymm5) + + vpermilpd(imm(0x5), ymm8, ymm3) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm8, ymm8) + + vpermilpd(imm(0x5), ymm9, ymm3) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm9, ymm9) + + vpermilpd(imm(0x5), ymm12, ymm3) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm12, ymm12) + + vpermilpd(imm(0x5), ymm13, ymm3) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm13, ymm13) + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddpd(ymm5, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm8, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddpd(ymm9, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 2*rs_c + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm12, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddpd(ymm13, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + /*|--------| |-------| + | | | | + | 3x4 | | 4x3 | + |--------| |-------| + */ + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm4, ymm0, ymm4) + + add(rdi, rcx) + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm8, ymm0, ymm8) + add(rdi, rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm12, ymm0, ymm12) + + lea(mem(r12, rsi, 2), rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm5, ymm0, ymm5) + add(rdi, rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm9, ymm0, ymm9) + add(rdi, rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm13, ymm0, ymm13) + + mov(r12, rcx) // reset rcx to current utile of c. + + + /****3x4 tile going to save into 4x3 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) + + /******************Transpose top tile 4x3***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) + + add(rsi, rcx) + + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + vmovups(xmm13,mem(rcx,32)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm5, xmm5) + vextractf128(imm(0x1), ymm9, xmm9) + vextractf128(imm(0x1), ymm13, xmm13) + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + vmovups(xmm13,mem(rcx,32)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx, rsi, 8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + /****3x4 tile going to save into 4x3 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) + + /******************Transpose top tile 4x3***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) + + add(rsi, rcx) + + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + vmovups(xmm13,mem(rcx,32)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm5, xmm5) + vextractf128(imm(0x1), ymm9, xmm9) + vextractf128(imm(0x1), ymm13, xmm13) + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + vmovups(xmm13,mem(rcx,32)) + + label(.SDONE) + + lea(mem(r12, rdi, 2), r12) + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) + lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a + + dec(r11) // ii -= 1; + jne(.SLOOP3X8I) // iterate again if ii != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + dcomplex* cij = c + i_edge*rs_c; + dcomplex* ai = a + i_edge*rs_a; + dcomplex* bj = b; + + zgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_zgemmsup_rv_zen_asm_1x4, + bli_zgemmsup_rv_zen_asm_2x4, + }; + + zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + + } + +} + +void bli_zgemmsup_rv_zen_asm_3x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r8, 2), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt) + +// lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vzeroall() // zero all xmm/ymm registers. + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 + vpermilpd(imm(0x5), ymm6, ymm6) + vpermilpd(imm(0x5), ymm10, ymm10) + vpermilpd(imm(0x5), ymm14, ymm14) + + // subtract/add even/odd elements + vaddsubpd(ymm6, ymm4, ymm4) + + vaddsubpd(ymm10, ymm8, ymm8) + + vaddsubpd(ymm14, ymm12, ymm12) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate + + vpermilpd(imm(0x5), ymm4, ymm3) + vmulpd(ymm0, ymm4, ymm4) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm4, ymm4) + + vpermilpd(imm(0x5), ymm8, ymm3) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm8, ymm8) + + vpermilpd(imm(0x5), ymm12, ymm3) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm12, ymm12) + + /* (ßr + ßi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + add(rdi, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm8, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + add(rdi, rcx) // rcx = c + 2*rs_c + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm12, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + /*|--------| |-------| + | | | | + | 3x2 | | 2x3 | + |--------| |-------| + */ + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) + + lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm4, ymm0, ymm4) + + add(rdi, rcx) + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm8, ymm0, ymm8) + add(rdi, rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm12, ymm0, ymm12) + + mov(r12, rcx) // reset rcx to current utile of c. + + /****3x2 tile going to save into 2x3 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) + + /******************Transpose top tile 2x3***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) + + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm12, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + /****3x2 tile going to save into 2x3 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) + + /******************Transpose tile 3x2***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) + + label(.SDONE) + + lea(mem(r12, rdi, 2), r12) + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) + lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a + + dec(r11) // ii -= 1; + jne(.SLOOP3X8I) // iterate again if ii != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + dcomplex* cij = c + i_edge*rs_c; + dcomplex* ai = a + i_edge*rs_a; + dcomplex* bj = b; + + zgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_zgemmsup_rv_zen_asm_1x2, + bli_zgemmsup_rv_zen_asm_2x2, + }; + + zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + } +} diff --git a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4n.c b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4n.c new file mode 100644 index 0000000000..872d048685 --- /dev/null +++ b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4n.c @@ -0,0 +1,1196 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include "immintrin.h" + +/* + rrr: + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + + rcr: + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : +*/ +void bli_zgemmsup_rv_zen_asm_3x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t m_left = m0 % 3; + if ( m_left ) + { + zgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_zgemmsup_rv_zen_asm_1x4n, + bli_zgemmsup_rv_zen_asm_2x4n, + }; + zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + ker_fp + ( + conja, conjb, m_left, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; + } + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = 0; + + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m128d xmm0, xmm3; + + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 4; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*4; + tC = c + n_iter*tc_inc_col*4; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); + + ymm3 = _mm256_permute_pd(ymm5, 5); + ymm5 = _mm256_mul_pd(ymm0, ymm5); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_addsub_pd(ymm5, ymm3); + + ymm3 = _mm256_permute_pd(ymm8, 5); + ymm8 = _mm256_mul_pd(ymm0, ymm8); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_addsub_pd(ymm8, ymm3); + + ymm3 = _mm256_permute_pd(ymm9, 5); + ymm9 = _mm256_mul_pd(ymm0, ymm9); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_addsub_pd(ymm9, ymm3); + + ymm3 = _mm256_permute_pd(ymm12, 5); + ymm12 = _mm256_mul_pd(ymm0, ymm12); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_addsub_pd(ymm12, ymm3); + + ymm3 = _mm256_permute_pd(ymm13, 5); + ymm13 = _mm256_mul_pd(ymm0, ymm13); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm13 = _mm256_addsub_pd(ymm13, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + tC += tc_inc_col; + + //transpose right 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm13)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm13, 1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + //Multiply ymm8 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm8 = _mm256_add_pd(ymm8, ymm0); + + //Multiply ymm12 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm12 = _mm256_add_pd(ymm12, ymm0); + + //transpose left 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + tC += tc_inc_col; + + //Multiply ymm5 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm5 = _mm256_add_pd(ymm5, ymm0); + //Multiply ymm9 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm9 = _mm256_add_pd(ymm9, ymm0); + + //Multiply ymm13 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm13 = _mm256_add_pd(ymm13, ymm0); + + //transpose right 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm13)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm13, 1)); + } + + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_add_pd(ymm9, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_add_pd(ymm12, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2 +2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm13 = _mm256_add_pd(ymm13, _mm256_addsub_pd(ymm2, ymm3)); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); + } + } + } + + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + dcomplex* restrict cij = c + j_edge*cs_c; + dcomplex* restrict ai = a; + dcomplex* restrict bj = b + n_iter * 4; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_zgemmsup_rv_zen_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } + +} + +void bli_zgemmsup_rv_zen_asm_2x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + + uint64_t k_iter = 0; + + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m128d xmm0, xmm3; + + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 4; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*4; + tC = c + n_iter*tc_inc_col*4; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); + + ymm3 = _mm256_permute_pd(ymm5, 5); + ymm5 = _mm256_mul_pd(ymm0, ymm5); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_addsub_pd(ymm5, ymm3); + + ymm3 = _mm256_permute_pd(ymm8, 5); + ymm8 = _mm256_mul_pd(ymm0, ymm8); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_addsub_pd(ymm8, ymm3); + + ymm3 = _mm256_permute_pd(ymm9, 5); + ymm9 = _mm256_mul_pd(ymm0, ymm9); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_addsub_pd(ymm9, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 2x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + tC += tc_inc_col; + + //transpose right 2x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + //Multiply ymm8 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm8 = _mm256_add_pd(ymm8, ymm0); + + //transpose left 2x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + tC += tc_inc_col; + + + //Multiply ymm5 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm5 = _mm256_add_pd(ymm5, ymm0); + //Multiply ymm9 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm9 = _mm256_add_pd(ymm9, ymm0); + + //transpose right 2x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + } + + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_add_pd(ymm9, _mm256_addsub_pd(ymm2, ymm3)); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + } + } + } + + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + dcomplex* restrict cij = c + j_edge*cs_c; + dcomplex* restrict ai = a; + dcomplex* restrict bj = b + n_iter * 4; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_zgemmsup_rv_zen_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } + +} + +void bli_zgemmsup_rv_zen_asm_1x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = 0; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m128d xmm0, xmm3; + + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 4; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*4; + tC = c + n_iter*tc_inc_col*4; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); + + ymm3 = _mm256_permute_pd(ymm5, 5); + ymm5 = _mm256_mul_pd(ymm0, ymm5); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_addsub_pd(ymm5, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 1x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm4,1)); + tC += tc_inc_col; + + //transpose right 1x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm5,1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm4,1)); + tC += tc_inc_col; + + //Multiply ymm5 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm5 = _mm256_add_pd(ymm5, ymm0); + + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm5,1)); + } + + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + } + } + } + + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + dcomplex* restrict cij = c + j_edge*cs_c; + dcomplex* restrict ai = a; + dcomplex* restrict bj = b + n_iter * 4; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + bli_zgemmsup_rv_zen_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } +} + +void bli_zgemmsup_rv_zen_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = 0; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm6; + __m256d ymm8, ymm10; + __m256d ymm12, ymm14; + __m128d xmm0, xmm3; + + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tc_inc_col = cs_c; + + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); + + ymm3 = _mm256_permute_pd(ymm8, 5); + ymm8 = _mm256_mul_pd(ymm0, ymm8); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_addsub_pd(ymm8, ymm3); + + ymm3 = _mm256_permute_pd(ymm12, 5); + ymm12 = _mm256_mul_pd(ymm0, ymm12); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_addsub_pd(ymm12, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 3x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + //Multiply ymm8 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm8 = _mm256_add_pd(ymm8, ymm0); + + //Multiply ymm12 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm12 = _mm256_add_pd(ymm12, ymm0); + + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + } + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + tc_inc_row ), ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_add_pd(ymm12, _mm256_addsub_pd(ymm2, ymm3)); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + } + } +} diff --git a/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16.c b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16.c new file mode 100644 index 0000000000..96bc927499 --- /dev/null +++ b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16.c @@ -0,0 +1,2668 @@ +/* + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +#include "blis.h" +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ +void bli_sgemmsup_rd_zen_asm_2x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // ------------------------------------------------------------------------- + begin_asm() + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + mov(var(b), rdx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + lea(mem(r14), rax) // rax = a; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + // ---------------------------------- iteration 0 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + // ---------------------------------- iteration 1 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + // ---------------------------------- iteration 2 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + // ---------------------------------- iteration 3 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + add(imm(1*4), rax) // a += 1*cs_b = 1*8; + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + vhaddps(xmm2,xmm0,xmm4) + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + vhaddps(xmm2,xmm0,xmm5) + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + vfmadd231ps(mem(rcx), xmm3,xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + label(.SROWSTORBZ) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + vmovups(xmm5, mem(rcx)) + + label(.SDONE) + add(imm(4), r15) // jj += 4; + cmp(imm(16), r15) // compare jj to 4 + jl(.SLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_zen_asm_1x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // ------------------------------------------------------------------------- + begin_asm() + mov(var(a), r14) // load address of a. + mov(var(b), rdx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + mov(var(c), r12) // load address of c + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm13, ymm13, ymm13) + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + lea(mem(r14), rax) // rax = a; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + // ---------------------------------- iteration 0 + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + // ---------------------------------- iteration 1 + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + // ---------------------------------- iteration 2 + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + // ---------------------------------- iteration 3 + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + label(.SCONSIDKITER8) + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovss(mem(rax ), xmm0) + add(imm(1*4), rax) // a += 1*cs_b = 1*8; + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vfmadd231ps(ymm0, ymm3, ymm13) + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 ymm10 ymm13 + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + vhaddps(xmm2,xmm0,xmm4) + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + vmulps(xmm0, xmm4, xmm4) // scale by alpha + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(xmm4, mem(rcx)) + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + label(.SROWSTORBZ) + vmovups(xmm4, mem(rcx)) + + label(.SDONE) + add(imm(4), r15) // jj += 4; + cmp(imm(16), r15) // compare jj to 4 + jl(.SLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_zen_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // ------------------------------------------------------------------------- + begin_asm() + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + mov(var(b), rdx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + lea(mem(r14), rax) // rax = a; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + // ---------------------------------- iteration 1 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + // ---------------------------------- iteration 2 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + // ---------------------------------- iteration 3 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + label(.SCONSIDKITER8) + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + add(imm(1*4), rax) // a += 1*cs_b = 1*8; + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + vhaddps(xmm2,xmm0,xmm4) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + vhaddps(xmm2,xmm0,xmm5) + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + vfmadd231ps(mem(rcx), xmm3,xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + //add(rdi, rcx) + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + label(.SROWSTORBZ) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + vmovups(xmm5, mem(rcx)) + + label(.SDONE) + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 4 + jl(.SLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_zen_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // ------------------------------------------------------------------------- + begin_asm() + mov(var(a), r14) // load address of a. + mov(var(b), rdx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + mov(var(c), r12) // load address of c + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm13, ymm13, ymm13) + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + lea(mem(r14), rax) // rax = a; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + // ---------------------------------- iteration 0 + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + // ---------------------------------- iteration 1 + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + // ---------------------------------- iteration 2 + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + // ---------------------------------- iteration 3 + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovss(mem(rax ), xmm0) + add(imm(1*4), rax) // a += 1*cs_b = 1*8; + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vfmadd231ps(ymm0, ymm3, ymm13) + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 ymm10 ymm13 + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + vhaddps(xmm2,xmm0,xmm4) + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + vmulps(xmm0, xmm4, xmm4) // scale by alpha + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + vfmadd231ps(mem(rcx), xmm3,xmm4) + vmovups(xmm4, mem(rcx)) + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + label(.SROWSTORBZ) + vmovups(xmm4, mem(rcx)) + + label(.SDONE) + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.SLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_zen_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // ------------------------------------------------------------------------- + begin_asm() + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + mov(var(b), rbx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + // ---------------------------------- iteration 0 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + // ---------------------------------- iteration 1 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + // ---------------------------------- iteration 2 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + // ---------------------------------- iteration 3 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + label(.SCONSIDKITER8) + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + add(imm(1*4), rax) // a += 1*cs_b = 1*8; + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + vhaddps(xmm2,xmm0,xmm4) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + vhaddps(xmm2,xmm0,xmm5) + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + vfmadd231ps(mem(rcx), xmm3,xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + label(.SROWSTORBZ) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + vmovups(xmm5, mem(rcx)) + label(.SDONE) + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} +void bli_sgemmsup_rd_zen_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + begin_asm() + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm13, ymm13, ymm13) + mov(var(a), rax) // load address of a. + + mov(var(b), rbx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + mov(var(c), rcx) // load address of c + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + // ---------------------------------- iteration 0 + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + // ---------------------------------- iteration 1 + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + + // ---------------------------------- iteration 2 + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + + // ---------------------------------- iteration 3 + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovss(mem(rax ), xmm0) + add(imm(1*4), rax) // a += 1*cs_b = 1*8; + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vfmadd231ps(ymm0, ymm3, ymm13) + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 ymm10 ymm13 + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + vhaddps(xmm2,xmm0,xmm4) + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + vmulps(xmm0, xmm4, xmm4) // scale by alpha + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + vfmadd231ps(mem(rcx), xmm3,xmm4) + vmovups(xmm4, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + label(.SROWSTORBZ) + vmovups(xmm4, mem(rcx)) + + label(.SDONE) + + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_zen_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + begin_asm() + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + mov(var(b), rbx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + // initialize loop by pre-loading + // a column of a. + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + // ---------------------------------- iteration 0 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + // ---------------------------------- iteration 1 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovss(mem(rbx ), xmm0) + vmovss(mem(rbx, r11, 1), xmm1) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vmovss(mem(rax ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovss(mem(rax, r8, 1), xmm3) + add(imm(1*4), rax) // a += 1*cs_a = 1*8; + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm5 + // ymm6 ymm7 + vhaddps( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + vhaddps(xmm0,xmm0,xmm4) + vhaddps( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps(xmm0,xmm0,xmm6) // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm4)//c*beta+(a0a1) + vmovsd(xmm4, mem(rcx))//a0a1 + add(rdi, rcx) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + label(.SROWSTORBZ) + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + vmovsd(xmm6, mem(rcx)) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_zen_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + begin_asm() + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + + mov(var(a), rax) // load address of a. + + mov(var(b), rbx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + // initialize loop by pre-loading + // a column of a. + mov(var(c), rcx) // load address of c + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + // ---------------------------------- iteration 0 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + // ---------------------------------- iteration 1 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + label(.SLOOPKITER8) // EDGE LOOP (ymm) + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovss(mem(rbx ), xmm0) + vmovss(mem(rbx, r11, 1), xmm1) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vmovss(mem(rax ), xmm3) + add(imm(1*4), rax) // a += 1*cs_a = 1*8; + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm5 + vhaddps( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + vhaddps(xmm0,xmm0,xmm4) // xmm4 = sum(ymm4) sum(ymm5) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + vmulps(xmm0, xmm4, xmm4) // scale by alpha + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm4)//c*beta+(a0a1) + vmovsd(xmm4, mem(rcx))//a0a1 + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + label(.SROWSTORBZ) + vmovsd(xmm4, mem(rcx)) + + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rd_zen_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + if ( m_iter == 0 ) goto consider_edge_cases; + // ------------------------------------------------------------------------- + begin_asm() + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + mov(var(b), rdx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle, + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + lea(mem(r12), rcx) // rcx = c + 6*ii*rs_c; + lea(mem(r14), rax) // rax = a + 6*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + // ---------------------------------- iteration 0 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + // ---------------------------------- iteration 1 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + vmovups(mem(rax, r8, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + vmovups(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovups(mem(rax, r8, 4), ymm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + vmovups(mem(rax, r15, 1), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovss(mem(rbx ), xmm0) + vmovss(mem(rbx, r11, 1), xmm1) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vmovss(mem(rax ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovss(mem(rax, r8, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + vmovss(mem(rax, r8, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + vmovss(mem(rax, r13, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vmovss(mem(rax, r8, 4), xmm3) + vfmadd231ps(ymm0, ymm3, ymm12) + vfmadd231ps(ymm1, ymm3, ymm13) + vmovss(mem(rax, r15, 1), xmm3) + add(imm(1*4), rax) // a += 1*cs_a = 1*8; + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + vhaddps( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + vhaddps(xmm0,xmm0,xmm4) + vhaddps( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + vhaddps(xmm0,xmm0,xmm6) + vhaddps( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + vhaddps(xmm0,xmm0,xmm8) + vhaddps( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + vhaddps(xmm0,xmm0,xmm10) + vhaddps( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + vhaddps(xmm0,xmm0,xmm12) + vhaddps( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + vhaddps(xmm0,xmm0,xmm14) + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + // xmm10 = sum(ymm10) sum(ymm11) + // xmm12 = sum(ymm12) sum(ymm13) + // xmm14 = sum(ymm14) sum(ymm15) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm4)//c*beta+(a0a1) + vmovsd(xmm4, mem(rcx))//a0a1 + add(rdi, rcx) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm6)//c*beta+(a0a1) + vmovsd(xmm6, mem(rcx))//a0a1 + add(rdi, rcx) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm8)//c*beta+(a0a1) + vmovsd(xmm8, mem(rcx))//a0a1 + add(rdi, rcx) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm10)//c*beta+(a0a1) + vmovsd(xmm10, mem(rcx))//a0a1 + add(rdi, rcx) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm12)//c*beta+(a0a1) + vmovsd(xmm12, mem(rcx))//a0a1 + add(rdi, rcx) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm14)//c*beta+(a0a1) + vmovsd(xmm14, mem(rcx))//a0a1 + //add(rdi, rcx) + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + label(.SROWSTORBZ) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + vmovsd(xmm10, mem(rcx)) + add(rdi, rcx) + vmovsd(xmm12, mem(rcx)) + add(rdi, rcx) + vmovsd(xmm14, mem(rcx)) + + label(.SDONE) + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + float* restrict cij = c + i_edge*rs_c; + float* restrict bj = b; + float* restrict ai = a + i_edge*rs_a; + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + bli_sgemmsup_rd_zen_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + bli_sgemmsup_rd_zen_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + bli_sgemmsup_rd_zen_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_sgemmsup_rd_zen_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + begin_asm() + vzeroall() // zero all xmm/ymm registers. + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + mov(var(b), rbx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + // initialize loop by pre-loading + // a column of a. + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + // ---------------------------------- iteration 0 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + // ---------------------------------- iteration 1 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + vmovups(mem(rbx ), ymm0) + vmovups(mem(rbx, r11, 1), ymm1) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vmovups(mem(rax ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovups(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + vmovups(mem(rax, r8, 2), ymm3) + add(imm(8*4), rax) // a += 4*cs_a = 4*8; + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovss(mem(rbx ), xmm0) + vmovss(mem(rbx, r11, 1), xmm1) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vmovss(mem(rax ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vmovss(mem(rax, r8, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + vmovss(mem(rax, r8, 2), xmm3) + add(imm(1*4), rax) // a += 1*cs_a = 1*8; + vfmadd231ps(ymm0, ymm3, ymm8) + vfmadd231ps(ymm1, ymm3, ymm9) + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + vhaddps( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + vhaddps(xmm0,xmm0,xmm4) + vhaddps( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps(xmm0,xmm0,xmm6) + vhaddps( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps(xmm0,xmm0,xmm8) // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmovsd(xmm8, mem(rcx)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + label(.SBETAZERO) + + label(.SROWSTORBZ) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + vmovsd(xmm8, mem(rcx)) + + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} diff --git a/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16m.c b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16m.c new file mode 100644 index 0000000000..4eebb2b0a5 --- /dev/null +++ b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16m.c @@ -0,0 +1,1965 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. + + NOTE: These kernels implicitly support column-oriented IO, implemented + via an a high-level transposition of the entire operation. A and B will + effectively remain row- and column-stored, respectively, but C will then + effectively appear column-stored. Thus, this kernel may be used for both + rrc and crc cases. +*/ + +// Prototype reference microkernels. + +void bli_sgemmsup_rd_zen_asm_6x16m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + + uint64_t n_left = n0 % 16; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 6x?m kernels, as needed. + if ( n_left ) + { + float* restrict cij = c; + float* restrict bj = b; + float* restrict ai = a; + + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rd_zen_asm_6x8m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rd_zen_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rd_zen_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + return; + } + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + mov(var(b), rdx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(c), r12) // load address of c + mov(var(b), rdx) + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_b = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vhaddps(xmm2,xmm0,xmm4) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + + vhaddps(xmm2,xmm0,xmm5) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + + vhaddps(xmm2,xmm0,xmm6) + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + + label(.SDONE) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + add(imm(4), r15) // jj += 4; + cmp(imm(16), r15) // compare jj to 4 + jl(.SLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 16; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict bj = b; + float* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_sgemmsup_rd_zen_asm_2x16 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_sgemmsup_rd_zen_asm_1x16 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_sgemmsup_rd_zen_asm_6x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + mov(var(b), rdx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(c), r12) // load address of c + mov(var(b), rdx) + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_b = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vhaddps(xmm2,xmm0,xmm4) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + + vhaddps(xmm2,xmm0,xmm5) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + + vhaddps(xmm2,xmm0,xmm6) + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + + label(.SDONE) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 4 + jl(.SLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict bj = b; + float* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_sgemmsup_rd_zen_asm_2x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_sgemmsup_rd_zen_asm_1x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + + + +void bli_sgemmsup_rd_zen_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + mov(var(b), rdx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(c), r12) // load address of c + mov(var(b), rdx) + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_b = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vhaddps(xmm2,xmm0,xmm4) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + + vhaddps(xmm2,xmm0,xmm5) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + + vhaddps(xmm2,xmm0,xmm6) + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + + label(.SDONE) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jl(.SLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + label(.SRETURN) + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict bj = b; + float* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_sgemmsup_rd_zen_asm_2x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_sgemmsup_rd_zen_asm_1x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_sgemmsup_rd_zen_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + mov(var(b), rdx) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.SLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(c), r12) // load address of c + mov(var(b), rdx) + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*4), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.SLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + + // ---------------------------------- iteration 1 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + // ---------------------------------- iteration 3 + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_b = 1*4; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + add(imm(1*4), rbx) // b += 1*rs_b = 1*4; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 + // ymm5 ymm8 + // ymm6 ymm9 + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vhaddps(xmm0,xmm0,xmm4) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps(xmm0,xmm0,xmm5) + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + vhaddps(xmm0,xmm0,xmm6) + // ymm4 = sum(ymm4) sum(ymm7) + // ymm5 = sum(ymm5) sum(ymm8) + // ymm6 = sum(ymm6) sum(ymm9) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm4)//c*beta+(a0a1) + vmovsd(xmm4, mem(rcx))//a0a1 + add(rdi, rcx) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm5) + vmovsd(xmm5, mem(rcx)) + add(rdi, rcx) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + label(.SROWSTORBZ) + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + vmovsd(xmm5, mem(rcx)) + add(rdi, rcx) + vmovsd(xmm6, mem(rcx)) + + label(.SDONE) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.SLOOP3X4I) // iterate again if ii != 0. + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jl(.SLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict bj = b; + float* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_sgemmsup_rd_zen_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_sgemmsup_rd_zen_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + diff --git a/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16n.c b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16n.c new file mode 100644 index 0000000000..7f0c856130 --- /dev/null +++ b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16n.c @@ -0,0 +1,1869 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. + + NOTE: These kernels implicitly support column-oriented IO, implemented + via an a high-level transposition of the entire operation. A and B will + effectively remain row- and column-stored, respectively, but C will then + effectively appear column-stored. Thus, this kernel may be used for both + rrc and crc cases. +*/ + +void bli_sgemmsup_rd_zen_asm_6x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t m_left = m0 % 6; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other ?x8m kernels, as needed. + if ( m_left ) + { + float* restrict cij = c; + float* restrict bj = b; + float* restrict ai = a; + + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m0 ) + { + sgemmsup_ker_ft ker_fp1 = NULL; + sgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m0 == 7 ) + { + mr1 = 6; mr2 = 1; + ker_fp1 = bli_sgemmsup_rd_zen_asm_6x16n; + ker_fp2 = bli_sgemmsup_rd_zen_asm_1x16n; + } + else if ( m0 == 8 ) + { + mr1 = 6; mr2 = 2; + ker_fp1 = bli_sgemmsup_rd_zen_asm_6x16n; + ker_fp2 = bli_sgemmsup_rd_zen_asm_2x16n; + } + else // if ( m0 == 9 ) + { + mr1 = 6; mr2 = 3; + ker_fp1 = bli_sgemmsup_rd_zen_asm_6x16n; + ker_fp2 = bli_sgemmsup_rd_zen_asm_3x16n; + } + + ker_fp1 + ( + conja, conjb, mr1, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_sgemmsup_rd_zen_asm_3x16n + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_sgemmsup_rd_zen_asm_2x16n + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + bli_sgemv_ex + ( + BLIS_TRANSPOSE, conja, k0, n0, + alpha, bj, rs_b0, cs_b0, ai, cs_a0, + beta, cij, cs_c0, cntx, NULL + ); + } + return; + } + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r9) // ii = 0; + + label(.SLOOP3X4I) // LOOP OVER ii = [ 0 1 ... ] + + mov(var(b), r14) // load address of b + mov(var(c), r12) // load address of c + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(rdi, rsi) // rsi *= rs_c + lea(mem(r12, rsi, 1), r12) // r12 = c + 3*ii*rs_c; + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(r8, rsi) // rsi *= rs_a; + lea(mem(rdx, rsi, 1), rdx) // rax = a + 3*ii*rs_a; + + mov(var(n_iter), r15) // jj = n_iter; + + label(.SLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_b = 1*8; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vhaddps(xmm2,xmm0,xmm4) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + + vhaddps(xmm2,xmm0,xmm5) + + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + + vhaddps(xmm2,xmm0,xmm6) + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + + label(.SDONE) + + add(imm(4*4), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.SLOOP3X4J) // iterate again if jj != 0. + + add(imm(3), r9) // ii += 3; + cmp(imm(3), r9) // compare ii to 3 + jle(.SLOOP3X4I) // if ii <= 3, jump to beginning + // of ii loop; otherwise, loop ends. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 6; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rd_zen_asm_6x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } +} + +void bli_sgemmsup_rd_zen_asm_3x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + + mov(var(b), r14) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + mov(var(c), r12) // load address of c + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.SLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + label(.SLOOPKITER32) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + vmovups(mem(rax, r8, 2), ymm2) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + vmovss(mem(rax, r8, 2), xmm2) + add(imm(1*4), rax) // a += 1*cs_b = 1*8; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + vfmadd231ps(ymm2, ymm3, ymm6) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + vfmadd231ps(ymm2, ymm3, ymm9) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + vfmadd231ps(ymm2, ymm3, ymm12) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + vfmadd231ps(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vhaddps(xmm2,xmm0,xmm4) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + + vhaddps(xmm2,xmm0,xmm5) + + + vhaddps( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + + vhaddps( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + + vhaddps(xmm2,xmm0,xmm6) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + + label(.SDONE) + + add(imm(4*4), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.SLOOP3X4J) // iterate again if jj != 0. + + label(.SRETURN) + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rd_zen_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } +} + +void bli_sgemmsup_rd_zen_asm_2x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + mov(var(b), r14) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + mov(var(c), r12) // load address of c + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.SLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + // ---------------------------------- iteration 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + // ---------------------------------- iteration 2 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + // ---------------------------------- iteration 3 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + + vmovups(mem(rax ), ymm0) + vmovups(mem(rax, r8, 1), ymm1) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + vmovss(mem(rax, r8, 1), xmm1) + add(imm(1*4), rax) // a += 1*cs_b = 1*8; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + vfmadd231ps(ymm1, ymm3, ymm5) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + vfmadd231ps(ymm1, ymm3, ymm8) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vfmadd231ps(ymm0, ymm3, ymm13) + vfmadd231ps(ymm1, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vhaddps(xmm2,xmm0,xmm4) + + vhaddps( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) + + vhaddps( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) + + vhaddps(xmm2,xmm0,xmm5) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm5, xmm5) + vmulps(xmm0, xmm6, xmm6) + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm5) + vmovups(xmm5, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm5, mem(rcx)) + + label(.SDONE) + + add(imm(4*4), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.SLOOP3X4J) // iterate again if jj != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 2; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rd_zen_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } +} + +void bli_sgemmsup_rd_zen_asm_1x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter32 = k0 / 32; + uint64_t k_left32 = k0 % 32; + uint64_t k_iter8 = k_left32 / 8; + uint64_t k_left1 = k_left32 % 8; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(a), rdx) // load address of a. + + mov(var(b), r14) // load address of b. + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + mov(var(c), r12) // load address of c + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.SLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + // zen2 can execute 4 vxorpd ipc with + // a latency of 1 cycle + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm13, ymm13, ymm13) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + mov(var(k_iter32), rsi) // i = k_iter32; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKITER8) // if i == 0, jump to code that + // contains the k_iter8 loop. + + label(.SLOOPKITER32) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + + // ---------------------------------- iteration 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + + // ---------------------------------- iteration 2 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + + // ---------------------------------- iteration 3 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER32) // iterate again if i != 0. + + label(.SCONSIDKITER8) + + mov(var(k_iter8), rsi) // i = k_iter8; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter8 loop. + + label(.SLOOPKITER8) // EDGE LOOP (ymm) + + vmovups(mem(rax ), ymm0) + add(imm(8*4), rax) // a += 4*cs_b = 4*8; + + vmovups(mem(rbx ), ymm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovups(mem(rbx, r11, 1), ymm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovups(mem(rbx, r11, 2), ymm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovups(mem(rbx, r13, 1), ymm3) + add(imm(8*4), rbx) // b += 4*rs_b = 4*8; + vfmadd231ps(ymm0, ymm3, ymm13) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER8) // iterate again if i != 0. + + label(.SCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.SLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovss(mem(rax ), xmm0) + add(imm(1*4), rax) // a += 1*cs_b = 1*8; + + vmovss(mem(rbx ), xmm3) + vfmadd231ps(ymm0, ymm3, ymm4) + + vmovss(mem(rbx, r11, 1), xmm3) + vfmadd231ps(ymm0, ymm3, ymm7) + + vmovss(mem(rbx, r11, 2), xmm3) + vfmadd231ps(ymm0, ymm3, ymm10) + + vmovss(mem(rbx, r13, 1), xmm3) + add(imm(1*4), rbx) // b += 1*rs_b = 1*8; + vfmadd231ps(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT1) // iterate again if i != 0. + + label(.SPOSTACCUM) + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddps( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddps( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddps( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddps( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vhaddps(xmm2,xmm0,xmm4) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + + label(.SDONE) + + add(imm(4*4), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.SLOOP3X4J) // iterate again if jj != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter32] "m" (k_iter32), + [k_iter8] "m" (k_iter8), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 1; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rd_zen_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_sdotxv_ex + ( + conja, conjb, k0, + alpha, ai, cs_a0, bj, rs_b0, + beta, cij, cntx, NULL + ); + } + } +} + diff --git a/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16.c b/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16.c new file mode 100644 index 0000000000..6c9f8cabe1 --- /dev/null +++ b/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16.c @@ -0,0 +1,8745 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_zen_ref ) + +void bli_sgemmsup_rv_zen_asm_5x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm13, ymm13) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm5) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm7) + vmovups(ymm7, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm9) + vmovups(ymm9, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm11) + vmovups(ymm11, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm13) + vmovups(ymm13, mem(rcx, rsi, 8)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1) //c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + /********************************************/ + vextractf128(imm(0x0), ymm12, xmm0)//e0-e3 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm12, xmm0)//e4-e7 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + /*********************************************/ + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + vextractf128(imm(0x0), ymm13, xmm0)//e0-e3 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0,xmm0, xmm1) + vshufps(imm(0x02), xmm0,xmm0, xmm2) + vshufps(imm(0x03), xmm0,xmm0, xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm13, xmm0)//e4-e7 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0,xmm0, xmm1) + vshufps(imm(0x02), xmm0,xmm0, xmm2) + vshufps(imm(0x03), xmm0,xmm0, xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx, rsi, 8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1) //c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + /********************************************/ + vextractf128(imm(0x0), ymm12, xmm0)//e0-e3 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm12, xmm0)//e4-e7 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + /*********************************************/ + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + vextractf128(imm(0x0), ymm13, xmm0)//e0-e3 + vshufps(imm(0x01), xmm0,xmm0, xmm1) + vshufps(imm(0x02), xmm0,xmm0, xmm2) + vshufps(imm(0x03), xmm0,xmm0, xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm13, xmm0)//e4-e7 + vshufps(imm(0x01), xmm0,xmm0, xmm1) + vshufps(imm(0x02), xmm0,xmm0, xmm2) + vshufps(imm(0x03), xmm0,xmm0, xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_4x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm5) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm7) + vmovups(ymm7, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm9) + vmovups(ymm9, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm11) + vmovups(ymm11, mem(rcx, rsi, 8)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1) //c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + + jmp(.SDONE) // jump to end. + + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx, rsi, 8)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1) //c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_3x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm5) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm7) + vmovups(ymm7, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm9) + vmovups(ymm9, mem(rcx, rsi, 8)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm6, ymm4, ymm2) //a2b2a3b3 a6b6a7b7 + vperm2f128(imm(0x01),ymm0,ymm0,ymm11) + vperm2f128(imm(0x01),ymm2,ymm2,ymm12) + + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 + vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm11) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm12) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + /********************************************/ + vextractf128(imm(0x0), ymm8, xmm0)//c0-c3 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm11) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm11, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm8, xmm0)//e4-e7 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e4 + vfmadd231ps(xmm6, xmm3, xmm1)//e5 + vfmadd231ps(xmm8, xmm3, xmm2)//e6 + vfmadd231ps(xmm10, xmm3, xmm14)//e7 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + /*********************************************/ + vunpcklps(ymm7, ymm5, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm7, ymm5, ymm2) //a2b2a3b3 a6b6a7b7 + vperm2f128(imm(0x01),ymm0,ymm0,ymm11) + vperm2f128(imm(0x01),ymm2,ymm2,ymm12) + + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 + vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm11) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm12) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + /********************************************/ + vextractf128(imm(0x0), ymm9, xmm0)//c0-c3 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm9, xmm0)//e4-e7 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + jmp(.SDONE) // jump to end. + + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx, rsi, 8)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm6, ymm4, ymm2) //a2b2a3b3 a6b6a7b7 + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vperm2f128(imm(0x01),ymm0,ymm0,ymm0) + vperm2f128(imm(0x01),ymm2,ymm2,ymm2) + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a2b2 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + /********************************************/ + vextractf128(imm(0x0), ymm8, xmm0)//c0-c3 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm8, xmm0)//c4-c7 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + /*********************************************/ + vunpcklps(ymm7, ymm5, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm7, ymm5, ymm2) //a2b2a3b3 a6b6a7b7 + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vperm2f128(imm(0x01),ymm0,ymm0,ymm0) + vperm2f128(imm(0x01),ymm2,ymm2,ymm2) + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a2b2 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + /********************************************/ + vextractf128(imm(0x0), ymm9, xmm0)//c0-c3 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm9, xmm0)//c4-c7 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_2x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm5) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm7) + vmovups(ymm7, mem(rcx, rsi, 8)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm6, ymm4, ymm2) //a2b2a3b3 a6b6a7b7 + vperm2f128(imm(0x01),ymm0,ymm0,ymm11) + vperm2f128(imm(0x01),ymm2,ymm2,ymm12) + + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 + vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm11) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm12) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + vunpcklps(ymm7, ymm5, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm7, ymm5, ymm2) //a2b2a3b3 a6b6a7b7 + vperm2f128(imm(0x01),ymm0,ymm0,ymm11) + vperm2f128(imm(0x01),ymm2,ymm2,ymm12) + + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 + vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm11) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm12) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx, rsi, 8)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm6, ymm4, ymm2) //a2b2a3b3 a6b6a7b7 + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vperm2f128(imm(0x01),ymm0,ymm0,ymm0) + vperm2f128(imm(0x01),ymm2,ymm2,ymm2) + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a2b2 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + vunpcklps(ymm7, ymm5, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm7, ymm5, ymm2) //a2b2a3b3 a6b6a7b7 + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vperm2f128(imm(0x01),ymm0,ymm0,ymm0) + vperm2f128(imm(0x01),ymm2,ymm2,ymm2) + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a2b2 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_1x16 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm5) + vmovups(ymm5, mem(rcx, rsi, 8)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vextractf128(imm(0x0), ymm4, xmm0)//c0-c3 + vmovss(mem(rcx),xmm7) + vmovss(mem(rcx, rsi, 1),xmm6) + vmovss(mem(rcx, rsi, 2),xmm11) + vmovss(mem(rcx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm7, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm11, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + vextractf128(imm(0x1), ymm4, xmm0)//e4-e7 + vmovss(mem(rcx),xmm4) + vmovss(mem(rcx, rsi, 1),xmm6) + vmovss(mem(rcx, rsi, 2),xmm8) + vmovss(mem(rcx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e4 + vfmadd231ps(xmm6, xmm3, xmm1)//e5 + vfmadd231ps(xmm8, xmm3, xmm2)//e6 + vfmadd231ps(xmm10, xmm3, xmm14)//e7 + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + vextractf128(imm(0x0), ymm5, xmm0)//c0-c3 + vmovss(mem(rcx),xmm4) + vmovss(mem(rcx, rsi, 1),xmm6) + vmovss(mem(rcx, rsi, 2),xmm11) + vmovss(mem(rcx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm11, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + vextractf128(imm(0x1), ymm5, xmm0)//e4-e7 + vmovss(mem(rcx),xmm4) + vmovss(mem(rcx, rsi, 1),xmm6) + vmovss(mem(rcx, rsi, 2),xmm8) + vmovss(mem(rcx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e4 + vfmadd231ps(xmm6, xmm3, xmm1)//e5 + vfmadd231ps(xmm8, xmm3, xmm2)//e6 + vfmadd231ps(xmm10, xmm3, xmm14)//e7 + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi, 8)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vextractf128(imm(0x0), ymm4, xmm0)//c0-c3 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm4, xmm0)//e4-e7 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + vextractf128(imm(0x0), ymm5, xmm0)//c0-c3 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + vextractf128(imm(0x1), ymm5, xmm0)//e4-e7 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + + label(.SDONE) + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_6x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*16), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*16), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*16), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*16), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*16), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm14, ymm14) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /****6x8 tile is transposed and saved in col major as 8x6*****/ + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vpermilpd(imm(1),xmm0,xmm5)//e1f1 + vpermilpd(imm(1),xmm2,xmm6)//e5f5 + vfmadd231ps(mem(rdx), xmm3, xmm0) + vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + lea(mem(rdx, rsi, 1), rdx) + vfmadd231ps(mem(rdx), xmm3, xmm5) + vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm6) + vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 ) + vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 1), rdx) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vpermilpd(imm(1),xmm0,xmm5) + vpermilpd(imm(1),xmm2,xmm6) + vfmadd231ps(mem(rdx), xmm3, xmm0) + vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) + lea(mem(rdx, rsi, 1), rdx) + vfmadd231ps(mem(rdx), xmm3, xmm5) + vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm6) + vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 ) + vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm10, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm12, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + /******************top right tile 8x2***************************/ + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 1), rdx) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_5x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm12, ymm12) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1) //c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + /********************************************/ + vextractf128(imm(0x0), ymm12, xmm0)//e0-e3 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm12, xmm0)//e4-e7 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm10, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm12, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1) //c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + /********************************************/ + vextractf128(imm(0x0), ymm12, xmm0)//e0-e3 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm12, xmm0)//e4-e7 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_4x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1) //c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm10, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1) //c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_3x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + + jmp(.SDONE) // jump to end. + + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm6, ymm4, ymm2) //a2b2a3b3 a6b6a7b7 + vperm2f128(imm(0x01),ymm0,ymm0,ymm11) + vperm2f128(imm(0x01),ymm2,ymm2,ymm12) + + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 + vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm11) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm12) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + /********************************************/ + vextractf128(imm(0x0), ymm8, xmm0)//c0-c3 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm11) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm11, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rcx += cs_c + + vextractf128(imm(0x1), ymm8, xmm0)//c0-c3 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm11) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm11, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + jmp(.SDONE) // jump to end. + + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + + jmp(.SDONE) // jump to end. + + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm6, ymm4, ymm2) //a2b2a3b3 a6b6a7b7 + vperm2f128(imm(0x01),ymm0,ymm0,ymm11) + vperm2f128(imm(0x01),ymm2,ymm2,ymm12) + + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 + vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm11) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm12) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + /********************************************/ + vextractf128(imm(0x0), ymm8, xmm0)//c0-c3 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm8, xmm0)//c4-c7 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + label(.SDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + // now avoid loading C if beta == 0 + + vxorps(xmm0,xmm0,xmm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm6, ymm4, ymm2) //a2b2a3b3 a6b6a7b7 + vperm2f128(imm(0x01),ymm0,ymm0,ymm11) + vperm2f128(imm(0x01),ymm2,ymm2,ymm12) + + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 + vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm11) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm12) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + vmovups(ymm6, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm6, ymm4, ymm2) //a2b2a3b3 a6b6a7b7 + vperm2f128(imm(0x01),ymm0,ymm0,ymm11) + vperm2f128(imm(0x01),ymm2,ymm2,ymm12) + + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 + vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm11) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm12) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + // now avoid loading C if beta == 0 + + vxorps(xmm0,xmm0,xmm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /********************************************/ + vextractf128(imm(0x0), ymm4, xmm0)//c0-c3 + vmovss(mem(rcx),xmm8) + vmovss(mem(rcx, rsi, 1),xmm6) + vmovss(mem(rcx, rsi, 2),xmm11) + vmovss(mem(rcx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm8, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm11, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm4, xmm0)//e4-e7 + vmovss(mem(rcx),xmm4) + vmovss(mem(rcx, rsi, 1),xmm6) + vmovss(mem(rcx, rsi, 2),xmm8) + vmovss(mem(rcx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e4 + vfmadd231ps(xmm6, xmm3, xmm1)//e5 + vfmadd231ps(xmm8, xmm3, xmm2)//e6 + vfmadd231ps(xmm10, xmm3, xmm14)//e7 + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vextractf128(imm(0x0), ymm4, xmm0)//c0-c3 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm4, xmm0)//c4-c7 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_6x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(xmm0, xmm0, xmm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm8) + vmovups(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm10) + vmovups(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm12) + vmovups(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm14) + vmovups(xmm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vunpcklpd(xmm1, xmm0, xmm2)//a0b0c0d0 + vunpckhpd(xmm1, xmm0, xmm5)//a1b1c1d1 + + vfmadd231ps(mem(rcx), xmm3, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm5) + vmovups(xmm2, mem(rcx)) + vmovups(xmm5, mem(rcx, rsi, 1)) + lea(mem(rcx, rsi, 2), rcx) // rcx += 2*cs_c + + vunpckhps(xmm6, xmm4, xmm0)//a2b2a3b3 + vunpckhps(xmm10, xmm8, xmm1)//c2d2c3d3 + vunpcklpd(xmm1, xmm0, xmm7)//a2b2c2d2 + vunpckhpd(xmm1, xmm0, xmm9)//a3b3c3d3 + + vfmadd231ps(mem(rcx), xmm3, xmm7) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm9) + vmovups(xmm7, mem(rcx)) + vmovups(xmm9, mem(rcx, rsi, 1)) + + vunpcklps(xmm14, xmm12, xmm0)//e0f0e1f1 + vunpckhps(xmm14, xmm12, xmm1)//e2f2e3f3 + vmovsd(mem(rdx),xmm2) + vmovsd(mem(rdx, rsi, 1),xmm4) + vmovsd(mem(rdx, rsi, 2),xmm6) + vmovsd(mem(rdx, rax, 1),xmm8) + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//e1f1 + vshufpd(imm(0x01), xmm1, xmm1, xmm7)//e3f3 + vfmadd231ps(xmm2, xmm3, xmm0) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm6, xmm3, xmm1) + vfmadd231ps(xmm8, xmm3, xmm7) + vmovsd(xmm0, mem(rdx)) //e0f0 + vmovsd(xmm5, mem(rdx, rsi, 1)) //e1f1 + vmovsd(xmm1, mem(rdx, rsi, 2)) //e2f2 + vmovsd(xmm7, mem(rdx, rax, 1)) //e3f3 + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vunpcklpd(xmm1, xmm0, xmm2)//a0b0c0d0 + vunpckhpd(xmm1, xmm0, xmm5)//a1b1c1d1 + + vmovups(xmm2, mem(rcx)) + vmovups(xmm5, mem(rcx, rsi, 1)) + lea(mem(rcx, rsi, 2), rcx) // rcx += 2*cs_c + + vunpckhps(xmm6, xmm4, xmm0)//a2b2a3b3 + vunpckhps(xmm10, xmm8, xmm1)//c2d2c3d3 + vunpcklpd(xmm1, xmm0, xmm7)//a2b2c2d2 + vunpckhpd(xmm1, xmm0, xmm9)//a3b3c3d3 + + vmovups(xmm7, mem(rcx)) + vmovups(xmm9, mem(rcx, rsi, 1)) + + vunpcklps(xmm14, xmm12, xmm0)//e0f0e1f1 + vunpckhps(xmm14, xmm12, xmm1)//e2f2e3f3 + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//e1f1 + vshufpd(imm(0x01), xmm1, xmm1, xmm7)//e3f3 + vmovsd(xmm0, mem(rdx)) //e0f0 + vmovsd(xmm5, mem(rdx, rsi, 1)) //e1f1 + vmovsd(xmm1, mem(rdx, rsi, 2)) //e2f2 + vmovsd(xmm7, mem(rdx, rax, 1)) //e3f3 + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_5x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(r9, rsi) // rsi = rs_b; + sal(imm(5), rsi) // rsi = 16*rs_b; + lea(mem(rax, rsi, 1), rdx) // rdx = b + 16*rs_b; + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm8) + vmovups(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm10) + vmovups(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm12) + vmovups(xmm12, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vunpcklpd(xmm1, xmm0, xmm2)//a0b0c0d0 + vunpckhpd(xmm1, xmm0, xmm5)//a1b1c1d1 + + vfmadd231ps(mem(rcx), xmm3, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm5) + vmovups(xmm2, mem(rcx)) + vmovups(xmm5, mem(rcx, rsi, 1)) + lea(mem(rcx, rsi, 2), rcx) // rcx += 2*cs_c + + vunpckhps(xmm6, xmm4, xmm0)//a2b2a3b3 + vunpckhps(xmm10, xmm8, xmm1)//c2d2c3d3 + vunpcklpd(xmm1, xmm0, xmm7)//a2b2c2d2 + vunpckhpd(xmm1, xmm0, xmm9)//a3b3c3d3 + + vfmadd231ps(mem(rcx), xmm3, xmm7) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm9) + vmovups(xmm7, mem(rcx)) + vmovups(xmm9, mem(rcx, rsi, 1)) + + vmovss(mem(rdx),xmm2) + vmovss(mem(rdx, rsi, 1),xmm4) + vmovss(mem(rdx, rsi, 2),xmm6) + vmovss(mem(rdx, rax, 1),xmm8) + vshufps(imm(0x01), xmm12, xmm12,xmm1) + vshufps(imm(0x02), xmm12, xmm12,xmm5) + vshufps(imm(0x03), xmm12, xmm12,xmm7) + vfmadd231ps(xmm2, xmm3, xmm12) + vfmadd231ps(xmm4, xmm3, xmm1) + vfmadd231ps(xmm6, xmm3, xmm5) + vfmadd231ps(xmm8, xmm3, xmm7) + vmovss(xmm12, mem(rdx)) //e0 + vmovss(xmm1, mem(rdx, rsi, 1)) //e1 + vmovss(xmm5, mem(rdx, rsi, 2)) //e2 + vmovss(xmm7, mem(rdx, rax, 1)) //e3 + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm12, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vunpcklpd(xmm1, xmm0, xmm2)//a0b0c0d0 + vunpckhpd(xmm1, xmm0, xmm5)//a1b1c1d1 + + vmovups(xmm2, mem(rcx)) + vmovups(xmm5, mem(rcx, rsi, 1)) + lea(mem(rcx, rsi, 2), rcx) // rcx += 2*cs_c + + vunpckhps(xmm6, xmm4, xmm0)//a2b2a3b3 + vunpckhps(xmm10, xmm8, xmm1)//c2d2c3d3 + vunpcklpd(xmm1, xmm0, xmm7)//a2b2c2d2 + vunpckhpd(xmm1, xmm0, xmm9)//a3b3c3d3 + vmovups(xmm7, mem(rcx)) + vmovups(xmm9, mem(rcx, rsi, 1)) + + vshufps(imm(0x01), xmm12, xmm12,xmm1) + vshufps(imm(0x02), xmm12, xmm12,xmm5) + vshufps(imm(0x03), xmm12, xmm12,xmm7) + vmovss(xmm12, mem(rdx)) //e0 + vmovss(xmm1, mem(rdx, rsi, 1)) //e1 + vmovss(xmm5, mem(rdx, rsi, 2)) //e2 + vmovss(xmm7, mem(rdx, rax, 1)) //e3 + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_4x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), xmm3, xmm8) + vmovups(xmm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), xmm3, xmm10) + vmovups(xmm10, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vunpcklpd(xmm1, xmm0, xmm2)//a0b0c0d0 + vunpckhpd(xmm1, xmm0, xmm5)//a1b1c1d1 + + vfmadd231ps(mem(rcx), xmm3, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm5) + vmovups(xmm2, mem(rcx)) + vmovups(xmm5, mem(rcx, rsi, 1)) + lea(mem(rcx, rsi, 2), rcx) // rcx += 2*cs_c + + vunpckhps(xmm6, xmm4, xmm0)//a2b2a3b3 + vunpckhps(xmm10, xmm8, xmm1)//c2d2c3d3 + vunpcklpd(xmm1, xmm0, xmm7)//a2b2c2d2 + vunpckhpd(xmm1, xmm0, xmm9)//a3b3c3d3 + + vfmadd231ps(mem(rcx), xmm3, xmm7) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm9) + vmovups(xmm7, mem(rcx)) + vmovups(xmm9, mem(rcx, rsi, 1)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm10, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vunpcklpd(xmm1, xmm0, xmm2)//a0b0c0d0 + vunpckhpd(xmm1, xmm0, xmm5)//a1b1c1d1 + vmovups(xmm2, mem(rcx)) + vmovups(xmm5, mem(rcx, rsi, 1)) + lea(mem(rcx, rsi, 2), rcx) // rcx += 2*cs_c + + vunpckhps(xmm6, xmm4, xmm0)//a2b2a3b3 + vunpckhps(xmm10, xmm8, xmm1)//c2d2c3d3 + vunpcklpd(xmm1, xmm0, xmm7)//a2b2c2d2 + vunpckhpd(xmm1, xmm0, xmm9)//a3b3c3d3 + vmovups(xmm7, mem(rcx)) + vmovups(xmm9, mem(rcx, rsi, 1)) + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(xmm0, xmm0, xmm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm8) + vmovups(xmm8, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(xmm6, xmm4, xmm0)//e0f0e1f1 + vunpckhps(xmm6, xmm4, xmm1)//e2f2e3f3 + vmovsd(mem(rcx),xmm2) + vmovsd(mem(rcx, rsi, 1),xmm4) + vmovsd(mem(rcx, rsi, 2),xmm6) + vmovsd(mem(rcx, rax, 1),xmm10) + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//e1f1 + vshufpd(imm(0x01), xmm1, xmm1, xmm7)//e3f3 + vfmadd231ps(xmm2, xmm3, xmm0) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm6, xmm3, xmm1) + vfmadd231ps(xmm10, xmm3, xmm7) + vmovsd(xmm0, mem(rcx)) //e0f0 + vmovsd(xmm5, mem(rcx, rsi, 1)) //e1f1 + vmovsd(xmm1, mem(rcx, rsi, 2)) //e2f2 + vmovsd(xmm7, mem(rcx, rax, 1)) //e3f3 + + vmovss(mem(rdx),xmm2) + vmovss(mem(rdx, rsi, 1),xmm4) + vmovss(mem(rdx, rsi, 2),xmm6) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm8, xmm8,xmm1) + vshufps(imm(0x02), xmm8, xmm8,xmm5) + vshufps(imm(0x03), xmm8, xmm8,xmm7) + vfmadd231ps(xmm2, xmm3, xmm8) + vfmadd231ps(xmm4, xmm3, xmm1) + vfmadd231ps(xmm6, xmm3, xmm5) + vfmadd231ps(xmm10, xmm3, xmm7) + vmovss(xmm8, mem(rdx)) //e0 + vmovss(xmm1, mem(rdx, rsi, 1)) //e1 + vmovss(xmm5, mem(rdx, rsi, 2)) //e2 + vmovss(xmm7, mem(rdx, rax, 1)) //e3 + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovups(xmm8, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0)//e0f0e1f1 + vunpckhps(xmm6, xmm4, xmm1)//e2f2e3f3 + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//e1f1 + vshufpd(imm(0x01), xmm1, xmm1, xmm7)//e3f3 + vmovsd(xmm0, mem(rcx)) //e0f0 + vmovsd(xmm5, mem(rcx, rsi, 1)) //e1f1 + vmovsd(xmm1, mem(rcx, rsi, 2)) //e2f2 + vmovsd(xmm7, mem(rcx, rax, 1)) //e3f3 + + vshufps(imm(0x01), xmm8, xmm8,xmm1) + vshufps(imm(0x02), xmm8, xmm8,xmm5) + vshufps(imm(0x03), xmm8, xmm8,xmm7) + vmovss(xmm8, mem(rdx)) //e0 + vmovss(xmm1, mem(rdx, rsi, 1)) //e1 + vmovss(xmm5, mem(rdx, rsi, 2)) //e2 + vmovss(xmm7, mem(rdx, rax, 1)) //e3 + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(xmm0, xmm0, xmm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(xmm6, xmm4, xmm0)//e0f0e1f1 + vunpckhps(xmm6, xmm4, xmm1)//e2f2e3f3 + vmovsd(mem(rcx),xmm2) + vmovsd(mem(rcx, rsi, 1),xmm4) + vmovsd(mem(rcx, rsi, 2),xmm6) + vmovsd(mem(rcx, rax, 1),xmm10) + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//e1f1 + vshufpd(imm(0x01), xmm1, xmm1, xmm7)//e3f3 + vfmadd231ps(xmm2, xmm3, xmm0) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm6, xmm3, xmm1) + vfmadd231ps(xmm10, xmm3, xmm7) + vmovsd(xmm0, mem(rcx)) //e0f0 + vmovsd(xmm5, mem(rcx, rsi, 1)) //e1f1 + vmovsd(xmm1, mem(rcx, rsi, 2)) //e2f2 + vmovsd(xmm7, mem(rcx, rax, 1)) //e3f3 + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + vmovups(xmm6, mem(rcx)) + + jmp(.SDONE) // jump to end. + + + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0)//e0f0e1f1 + vunpckhps(xmm6, xmm4, xmm1)//e2f2e3f3 + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//e1f1 + vshufpd(imm(0x01), xmm1, xmm1, xmm7)//e3f3 + vmovsd(xmm0, mem(rcx)) //e0f0 + vmovsd(xmm5, mem(rcx, rsi, 1)) //e1f1 + vmovsd(xmm1, mem(rcx, rsi, 2)) //e2f2 + vmovsd(xmm7, mem(rcx, rax, 1)) //e3f3 + + label(.SDONE) + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vmovss(mem(rcx),xmm2) + vmovss(mem(rcx, rsi, 1),xmm6) + vmovss(mem(rcx, rsi, 2),xmm8) + vmovss(mem(rcx, rax, 1),xmm10) + vshufps(imm(0x01), xmm4, xmm4,xmm1) + vshufps(imm(0x02), xmm4, xmm4,xmm5) + vshufps(imm(0x03), xmm4, xmm4,xmm7) + vfmadd231ps(xmm2, xmm3, xmm4) + vfmadd231ps(xmm6, xmm3, xmm1) + vfmadd231ps(xmm8, xmm3, xmm5) + vfmadd231ps(xmm10, xmm3, xmm7) + vmovss(xmm4, mem(rcx)) //e0 + vmovss(xmm1, mem(rcx, rsi, 1)) //e1 + vmovss(xmm5, mem(rcx, rsi, 2)) //e2 + vmovss(xmm7, mem(rcx, rax, 1)) //e3 + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vshufps(imm(0x01), xmm4, xmm4,xmm1) + vshufps(imm(0x02), xmm4, xmm4,xmm5) + vshufps(imm(0x03), xmm4, xmm4,xmm7) + vmovss(xmm4, mem(rcx)) //e0 + vmovss(xmm1, mem(rcx, rsi, 1)) //e1 + vmovss(xmm5, mem(rcx, rsi, 2)) //e2 + vmovss(xmm7, mem(rcx, rax, 1)) //e3 + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 1*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm4)//c*beta+(a0a1) + vmovsd(xmm4, mem(rcx))//a0a1 + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) + vmovsd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) + vmovsd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm14) + vmovsd(xmm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vunpcklps(xmm14, xmm12, xmm2)//e0f0 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rdi, 2),xmm6) + vmovsd(mem(rcx, rdi, 4),xmm8) + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//a1b1 + vshufpd(imm(0x01), xmm1, xmm1, xmm7)//c1d1 + vshufpd(imm(0x01), xmm2, xmm2, xmm9)//e1f1 + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vfmadd231ps(xmm8, xmm3, xmm2) + vmovsd(xmm0, mem(rcx)) //a0b0 + vmovsd(xmm1, mem(rcx, rdi, 2)) //c0d0 + vmovsd(xmm2, mem(rcx, rdi, 4)) //e0f0 + lea(mem(rcx, rsi, 1), rcx) + + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rdi, 2),xmm6) + vmovsd(mem(rcx, rdi, 4),xmm8) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm6, xmm3, xmm7) + vfmadd231ps(xmm8, xmm3, xmm9) + vmovsd(xmm5, mem(rcx)) //a1b1 + vmovsd(xmm7, mem(rcx, rdi, 2)) //c1d1 + vmovsd(xmm9, mem(rcx, rdi, 4)) //e1f1 + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + label(.SROWSTORBZ) + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vunpcklps(xmm14, xmm12, xmm2)//e0f0 + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//a1b1 + vshufpd(imm(0x01), xmm1, xmm1, xmm7)//c1d1 + vshufpd(imm(0x01), xmm2, xmm2, xmm9)//e1f1 + vmovsd(xmm0, mem(rcx)) //a0b0 + vmovsd(xmm1, mem(rcx, rdi, 2)) //c0d0 + vmovsd(xmm2, mem(rcx, rdi, 4)) //e0f0 + lea(mem(rcx, rsi, 1), rcx) + vmovsd(xmm5, mem(rcx)) //e0f0 + vmovsd(xmm7, mem(rcx, rdi, 2)) //e1f1 + vmovsd(xmm9, mem(rcx, rdi, 4)) //e0f0 + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_5x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm4)//c*beta+(a0a1) + vmovsd(xmm4, mem(rcx))//a0a1 + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) + vmovsd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) + vmovsd(xmm12, mem(rcx)) + add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rdi, 2),xmm6) + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//a1b1 + vshufpd(imm(0x01), xmm1, xmm1, xmm7)//c1d1 + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) //a0b0 + vmovsd(xmm1, mem(rcx, rdi, 2)) //c0d0 + + vmovss(mem(rcx, rdi, 4),xmm4) + vshufps(imm(0x01), xmm12, xmm12, xmm9)//e1 + vfmadd231ps(xmm4, xmm3, xmm12) + vmovss(xmm12,mem(rcx,rdi,4))//e0 + + lea(mem(rcx, rsi, 1), rcx) + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rdi, 2),xmm6) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm6, xmm3, xmm7) + vmovsd(xmm5, mem(rcx)) //a1b1 + vmovsd(xmm7, mem(rcx, rdi, 2)) //c1d1 + + vmovss( mem(rcx, rdi, 4),xmm4) + vfmadd231ps(xmm4, xmm3, xmm9) + vmovss(xmm9,mem(rcx,rdi,4))//e1 + + jmp(.SDONE) // jump to end. + + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + label(.SROWSTORBZ) + + vmovsd(xmm4, mem(rcx))//a0a1 + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm12, mem(rcx)) + add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//a1b1 + vshufpd(imm(0x01), xmm1, xmm1, xmm7)//c1d1 + vmovsd(xmm0, mem(rcx)) //a0b0 + vmovsd(xmm1, mem(rcx, rdi, 2)) //c0d0 + vshufps(imm(0x01), xmm12, xmm12, xmm9)//e1 + vmovss(xmm12,mem(rcx,rdi,4))//e0 + + lea(mem(rcx, rsi, 1), rcx) + vmovsd(xmm5, mem(rcx)) //a1b1 + vmovsd(xmm7, mem(rcx, rdi, 2)) //c1d1 + vmovss(xmm9,mem(rcx,rdi,4))//e1 + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_4x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + // now avoid loading C if beta == 0 + + vxorps(xmm0,xmm0,xmm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + + + label(.SROWSTORED) + + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm4)//c*beta+(a0a1) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) + vmovsd(xmm10, mem(rcx)) + + + jmp(.SDONE) // jump to end. + + + label(.SCOLSTORED) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rdi, 2),xmm6) + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//a1b1 + vshufpd(imm(0x01), xmm1, xmm1, xmm7)//c1d1 + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) //a0b0 + vmovsd(xmm1, mem(rcx, rdi, 2)) //c0d0 + + lea(mem(rcx, rsi, 1), rcx) + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rdi, 2),xmm6) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm6, xmm3, xmm7) + vmovsd(xmm5, mem(rcx)) //a1b1 + vmovsd(xmm7, mem(rcx, rdi, 2)) //c1d1 + + jmp(.SDONE) // jump to end. + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + label(.SROWSTORBZ) + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm10, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//a1b1 + vshufpd(imm(0x01), xmm1, xmm1, xmm7)//c1d1 + vmovsd(xmm0, mem(rcx)) //a0b0 + vmovsd(xmm1, mem(rcx, rdi, 2)) //c0d0 + lea(mem(rcx, rsi, 1), rcx) + vmovsd(xmm5, mem(rcx)) //a1b1 + vmovsd(xmm7, mem(rcx, rdi, 2)) //c1d1 + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm8) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm8) + vmovsd(xmm8, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vmovsd(mem(rcx),xmm4) + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//a1b1 + vfmadd231ps(xmm4, xmm3, xmm0) + vmovsd(xmm0, mem(rcx)) //a0b0 + vmovss(mem(rcx,rdi,2),xmm4) + vshufps(imm(0x01), xmm8, xmm8, xmm9)//c1 + vfmadd231ps(xmm4, xmm3, xmm8) + vmovss(xmm8,mem(rcx,rdi,2))//c0 + + lea(mem(rcx, rsi, 1), rcx) + vmovsd(mem(rcx),xmm4) + vfmadd231ps(xmm4, xmm3, xmm5) + vmovsd(xmm5, mem(rcx)) //a1b1 + + vmovss(mem(rcx,rdi,2),xmm4) + vfmadd231ps(xmm4, xmm3, xmm9) + vmovss(xmm9,mem(rcx,rdi,2))//c1 + + jmp(.SDONE) // jump to end. + + + label(.SBETAZERO) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm8, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//a1b1 + vmovsd(xmm0, mem(rcx)) //a0b0 + vshufps(imm(0x01), xmm8, xmm8, xmm9)//c1 + vmovss(xmm8,mem(rcx,rdi,2))//c0 + lea(mem(rcx, rsi, 1), rcx) + vmovsd(xmm5, mem(rcx)) //a1b1 + vmovss(xmm9,mem(rcx,rdi,2))//c1 + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + + jmp(.SDONE) // jump to end. + + + label(.SCOLSTORED) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vmovsd(mem(rcx),xmm4) + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//a1b1 + vfmadd231ps(xmm4, xmm3, xmm0) + vmovsd(xmm0, mem(rcx)) //a0b0 + lea(mem(rcx, rsi, 1), rcx) + vmovsd(mem(rcx),xmm4) + vfmadd231ps(xmm4, xmm3, xmm5) + vmovsd(xmm5, mem(rcx)) //a1b1 + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vshufpd(imm(0x01), xmm0, xmm0, xmm5)//a1b1 + vmovsd(xmm0, mem(rcx)) //a0b0 + vmovsd(xmm5, mem(rcx, rsi, 1)) //a1b1 + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + begin_asm() + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm4) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vshufps(imm(0x01), xmm4, xmm4, xmm9)//c1 + vmovss(mem(rcx),xmm6) + vfmadd231ps(xmm6, xmm3, xmm4) + vmovss(xmm4,mem(rcx))//c0 + vmovss(mem(rcx,rsi,1),xmm6) + vfmadd231ps(xmm6, xmm3, xmm9) + vmovss(xmm9,mem(rcx,rsi,1))//c1 + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + vmovsd(xmm4, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vshufps(imm(0x01), xmm4, xmm4, xmm9)//c1 + vmovss(xmm4,mem(rcx))//c0 + vmovss(xmm9,mem(rcx,rsi,1))//c1 + + label(.SDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +// ----------------------------------------------------------------------------- + +// NOTE: Normally, for any "?x1" kernel, we would call the reference kernel. +// However, at least one other subconfiguration (zen) uses this kernel set, so +// we need to be able to call a set of "?x1" kernels that we know will actually +// exist regardless of which subconfiguration these kernels were used by. Thus, +// the compromise employed here is to inline the reference kernel so it gets +// compiled as part of the zen kernel set, and hence can unconditionally be +// called by other kernels within that kernel set. +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, mdim ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + for ( dim_t i = 0; i < mdim; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + /* for ( dim_t j = 0; j < 1; ++j ) */ \ + { \ + ctype* restrict cij = ci /*[ j*cs_c ]*/ ; \ + ctype* restrict bj = b /*[ j*cs_b ]*/ ; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(d,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ +} + +GENTFUNC( float, s, gemmsup_r_zen_ref_6x1, 6 ) +GENTFUNC( float, s, gemmsup_r_zen_ref_5x1, 5 ) +GENTFUNC( float, s, gemmsup_r_zen_ref_4x1, 4 ) +GENTFUNC( float, s, gemmsup_r_zen_ref_3x1, 3 ) +GENTFUNC( float, s, gemmsup_r_zen_ref_2x1, 2 ) +GENTFUNC( float, s, gemmsup_r_zen_ref_1x1, 1 ) + diff --git a/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16m.c b/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16m.c new file mode 100644 index 0000000000..41dbbd699e --- /dev/null +++ b/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16m.c @@ -0,0 +1,2395 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ +void bli_sgemmsup_rv_zen_asm_6x16m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t n_left = n0 % 16; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 6x?m kernels, as needed. + if (n_left ) + { + float* cij = c; + float* bj = b; + float* ai = a; + + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_zen_asm_6x8m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_zen_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_zen_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 1 == n_left ) + { + dim_t ps_a0 = bli_auxinfo_ps_a( data ); + if ( ps_a0 == 6 * rs_a0 ) + { + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + else + { + const dim_t mr = 6; + + // Since A is packed into row panels, we must use a loop over + // gemv. + dim_t m_iter = ( m0 + mr - 1 ) / mr; + dim_t m_left = m0 % mr; + + float* restrict ai_ii = ai; + float* restrict cij_ii = cij; + + for ( dim_t ii = 0; ii < m_iter; ii += 1 ) + { + dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) + ? mr : m_left ); + + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, + beta, cij_ii, rs_c0, cntx, NULL + ); + cij_ii += mr*rs_c0; ai_ii += ps_a0; + } + } + } + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(dt) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(dt) + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X16I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm13, ymm13) + vmulps(ymm0, ymm14, ymm14) + vmulps(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm5) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm7) + vmovups(ymm7, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm9) + vmovups(ymm9, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm11) + vmovups(ymm11, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm13) + vmovups(ymm13, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm15) + vmovups(ymm15, mem(rcx, rsi, 8)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. + + + label(.SCOLSTORED) + + /*|-----------------| |-----|----| + | | | | 8x4 | 8x2| + | 4x8 | 4x8 | | | | + | | | |-----|----| + |-----------------| | 8x4 | 8x2| + | 2x8 | 2x8 | | | | + |------------------ |----------|*/ + + /****6x16 tile is transposed and saved in col major as 6x16*****/ + /****top left tile 4x8 transposed to top left tile 8x4**********/ + vunpcklps(ymm6, ymm4, ymm0)//a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1)//c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2)//a1b1c0d0 a5b5c4d4 + vblendps(imm(0xcc), ymm2, ymm0, ymm0)//a0b0c0d0 a4b4c4d4 + vblendps(imm(0x33), ymm2, ymm1, ymm1)//a1b1c1d1 a5b5c5d5 + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + /***bottom left tile - 2x8 is transposed to top right tile 8x2**********/ + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + lea(mem(rdx, rsi, 4), rax) // rax += 4*cs_c + + vmovlpd(mem(rax), xmm1, xmm1) + vmovhpd(mem(rax, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rax)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rax, rsi, 1)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 2), rdx) // rdx += 2*cs_c + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma43..gamma53 ) + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rsi, 1)) // store ( gamma47..gamma57 ) + + lea(mem(rdx, rsi, 2), rdx) // rdx += 2*cs_c + + /***top right tile 4x8 is transposed to bottom left tile 8x4**********/ + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + /*** bottom right 2x8 is transposed to bottom right tile 8x2*******/ + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + lea(mem(rdx, rsi, 4), rax) // rax += 4*cs_c + + vmovlpd(mem(rax), xmm1, xmm1) + vmovhpd(mem(rax, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rax)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rax, rsi, 1)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 2), rdx) // rdx += 2*cs_c + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma43..gamma53 ) + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rsi, 1)) // store ( gamma47..gamma57 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vmovups(ymm14, mem(rcx)) + vmovups(ymm15, mem(rcx, rsi, 8)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + + label(.SCOLSTORBZ) + /****6x16 tile going to save into 16x6 tile in C*****/ + /******************top left tile 8x4***************************/ + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + /******************top right tile 8x2***************************/ + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 1), rdx) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) + lea(mem(rdx, rsi, 1), rdx) + lea(mem(rdx, rsi, 4), rdx) // rdx += 8*cs_c + + /******************bottom left tile 8x4***************************/ + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + /******************bottom right tile 8x2***************************/ + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 1), rdx) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) + + label(.SDONE) + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X16I) // iterate again if ii != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 16; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict ai = a + m_iter*ps_a; + float* restrict bj = b; + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_zen_asm_1x16, + bli_sgemmsup_rv_zen_asm_2x16, + bli_sgemmsup_rv_zen_asm_3x16, + bli_sgemmsup_rv_zen_asm_4x16, + bli_sgemmsup_rv_zen_asm_5x16 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + + } +} + +void bli_sgemmsup_rv_zen_asm_6x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm1, ymm1, ymm1) // zero ymm1 since we only use the lower + vxorps(ymm4, ymm4, ymm4) // half (xmm1), and nans/infs may slow us down. + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm14, ymm14, ymm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + + vmovups(mem(rbx), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovups(mem(rbx), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovups(mem(rbx), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + + vmovups(mem(rbx), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm14, ymm14) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /****6x8 tile is transposed and saved in col major as 8x6*****/ + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vpermilps(imm(0xe),xmm0,xmm5) + vpermilps(imm(0xe),xmm2,xmm6) + vfmadd231ps(mem(rdx), xmm3, xmm0) + vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + lea(mem(rdx, rsi, 1), rdx) + vfmadd231ps(mem(rdx), xmm3, xmm5) + vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm6) + vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 ) + vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 1), rdx) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vpermilps(imm(0xe),xmm0,xmm5) + vpermilps(imm(0xe),xmm2,xmm6) + vfmadd231ps(mem(rdx), xmm3, xmm0) + vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) + lea(mem(rdx, rsi, 1), rdx) + vfmadd231ps(mem(rdx), xmm3, xmm5) + vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm6) + vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 ) + vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + vmovups(ymm6, mem(rcx)) + add(rdi, rcx) + vmovups(ymm8, mem(rcx)) + add(rdi, rcx) + vmovups(ymm10, mem(rcx)) + add(rdi, rcx) + vmovups(ymm12, mem(rcx)) + add(rdi, rcx) + vmovups(ymm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + /******************top right tile 8x2***************************/ + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 1), rdx) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) + + + label(.SDONE) + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X8I) // iterate again if ii != 0. + + label(.SRETURN) + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict ai = a + m_iter*ps_a; + float* restrict bj = b; + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_zen_asm_1x8, + bli_sgemmsup_rv_zen_asm_2x8, + bli_sgemmsup_rv_zen_asm_3x8, + bli_sgemmsup_rv_zen_asm_4x8, + bli_sgemmsup_rv_zen_asm_5x8 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + } +} + +void bli_sgemmsup_rv_zen_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + // During preamble and loops: + // r12 = rcx = c // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorps(xmm1, xmm1, xmm1) + vxorps(xmm4, xmm4, xmm4) + vxorps(xmm6, xmm6, xmm6) + vxorps(xmm8, xmm8, xmm8) + vxorps(xmm10, xmm10, xmm10) + vxorps(xmm12, xmm12, xmm12) + vxorps(xmm14, xmm14, xmm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm8) + vmovups(xmm8, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm10) + vmovups(xmm10, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm12) + vmovups(xmm12, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm14) + vmovups(xmm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /****6x4 tile is transposed and saved in col major as 4x6*****/ + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vfmadd231ps(mem(rcx), xmm3, xmm1) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + + vunpckhps(xmm6, xmm4, xmm0) + vunpckhps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vfmadd231ps(mem(rcx), xmm3, xmm1) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + + vunpcklps(xmm14, xmm12, xmm0) + vpermilps(imm(0x4e), xmm0, xmm5) + vmovq(mem(rdx),xmm4) + vfmadd231ps(xmm4, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + + lea(mem(rdx, rsi, 1), rdx) + vmovq(mem(rdx),xmm4) + vfmadd231ps(xmm4, xmm3, xmm5) + vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 ) + + lea(mem(rdx, rsi, 1), rdx) + vunpckhps(xmm14, xmm12, xmm0) + vpermilps(imm(0x4e), xmm0, xmm5) + vfmadd231ps(mem(rdx), xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + + lea(mem(rdx, rsi, 1), rdx) + vfmadd231ps(mem(rdx), xmm3, xmm5) + vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + vmovups(xmm6, mem(rcx)) + add(rdi, rcx) + vmovups(xmm8, mem(rcx)) + add(rdi, rcx) + vmovups(xmm10, mem(rcx)) + add(rdi, rcx) + vmovups(xmm12, mem(rcx)) + add(rdi, rcx) + vmovups(xmm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(xmm6, xmm4, xmm0) + vunpckhps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + + vunpcklps(xmm14, xmm12, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) + lea(mem(rdx, rsi, 1), rdx) + vunpckhps(xmm14, xmm12, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) + + label(.SDONE) + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X4I) // iterate again if ii != 0. + + label(.SRETURN) + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict ai = a + m_iter*ps_a; + float* restrict bj = b; + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_zen_asm_1x4, + bli_sgemmsup_rv_zen_asm_2x4, + bli_sgemmsup_rv_zen_asm_3x4, + bli_sgemmsup_rv_zen_asm_4x4, + bli_sgemmsup_rv_zen_asm_5x4 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + } +} + +void bli_sgemmsup_rv_zen_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + // During preamble and loops: + // r12 = rcx = c // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X2I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorps(xmm1, xmm1, xmm1) + vxorps(xmm4, xmm4, xmm4) + vxorps(xmm6, xmm6, xmm6) + vxorps(xmm8, xmm8, xmm8) + vxorps(xmm10, xmm10, xmm10) + vxorps(xmm12, xmm12, xmm12) + vxorps(xmm14, xmm14, xmm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + vmovq(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovq(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovq(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + + vmovq(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovq(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovlpd(xmm4, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovlpd(xmm6, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm8) + vmovlpd(xmm8, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm10) + vmovlpd(xmm10, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm12) + vmovlpd(xmm12, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm14) + vmovlpd(xmm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /****6x2 tile is transposed and saved in col major as 2x6*****/ + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vshufps(imm(0x44), xmm1, xmm0, xmm2) //01-00-01-00 + vshufps(imm(0xee), xmm1, xmm0, xmm4) //11-10-11-10 + + vfmadd231ps(mem(rcx), xmm3, xmm2) + vmovupd(xmm2, mem(rcx)) // store ( gamma00..gamma30 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) // store ( gamma01..gamma31 ) + + vunpcklps(xmm14, xmm12, xmm0)//eof0e1f1 + vpermilps(imm(0x4e),xmm0,xmm5) + vmovq(mem(rdx), xmm4) + vfmadd231ps(xmm4, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + lea(mem(rdx, rsi, 1), rdx) + vmovq(mem(rdx), xmm4) + vfmadd231ps(xmm4, xmm3, xmm5) + vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovlpd(xmm4, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm6, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm8, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm10, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm12, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vshufps(imm(0x44), xmm1, xmm0, xmm2) //01-00-01-00 + vshufps(imm(0xee), xmm1, xmm0, xmm4) //11-10-11-10 + + vmovupd(xmm2, mem(rcx)) // store ( gamma00..gamma30 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vmovupd(xmm4, mem(rcx)) // store ( gamma01..gamma31 ) + + vunpcklps(xmm14, xmm12, xmm0)//eof0e1f1 + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) + + label(.SDONE) + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X2I) // iterate again if ii != 0. + + label(.SRETURN) + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict ai = a + m_iter*ps_a; + float* restrict bj = b; + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_zen_asm_1x2, + bli_sgemmsup_rv_zen_asm_2x2, + bli_sgemmsup_rv_zen_asm_3x2, + bli_sgemmsup_rv_zen_asm_4x2, + bli_sgemmsup_rv_zen_asm_5x2 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + } +} diff --git a/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16n.c b/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16n.c new file mode 100644 index 0000000000..a7ab770cb2 --- /dev/null +++ b/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16n.c @@ -0,0 +1,3887 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +//GEMMSUP_KER_PROT( float, f, semmsup_r_zen_ref ) + +void bli_sgemmsup_rv_zen_asm_6x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t m_left = m0 % 6; + if ( m_left ) + { + float* restrict cij = c; + float* restrict bj = b; + float* restrict ai = a; + + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m0 ) + { + sgemmsup_ker_ft ker_fp1 = NULL; + sgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m0 == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_sgemmsup_rv_zen_asm_4x16n; + ker_fp2 = bli_sgemmsup_rv_zen_asm_3x16n; + } + else if ( m0 == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_sgemmsup_rv_zen_asm_4x16n; + ker_fp2 = bli_sgemmsup_rv_zen_asm_4x16n; + } + else // if ( m0 == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_sgemmsup_rv_zen_asm_4x16n; + ker_fp2 = bli_sgemmsup_rv_zen_asm_5x16n; + } + + ker_fp1 + ( + conja, conjb, mr1, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_zen_asm_1x16n, + bli_sgemmsup_rv_zen_asm_2x16n, + bli_sgemmsup_rv_zen_asm_3x16n, + bli_sgemmsup_rv_zen_asm_4x16n, + bli_sgemmsup_rv_zen_asm_5x16n + }; + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + ker_fp + ( + conja, conjb, m_left, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 /16; + uint64_t n_left = n0 % 16; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.SLOOP6X16J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done pre-fetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rbx, 11*4)) // pre-fetch line of next upanel of b + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm13, ymm13) + vmulps(ymm0, ymm14, ymm14) + vmulps(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm5) + vmovups(ymm5, mem(rcx, rsi,8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm7) + vmovups(ymm7, mem(rcx, rsi,8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm9) + vmovups(ymm9, mem(rcx, rsi,8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm11) + vmovups(ymm11, mem(rcx, rsi,8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm13) + vmovups(ymm13, mem(rcx, rsi,8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm15) + vmovups(ymm15, mem(rcx, rsi,8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1) //c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + /***bottom left tile - 2x8 is transposed to top right tile 8x2**********/ + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + lea(mem(rdx, rsi, 4), rax) // rax += 4*cs_c + + vmovlpd(mem(rax), xmm1, xmm1) + vmovhpd(mem(rax, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rax)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rax, rsi, 1)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 2), rdx) // rdx += 2*cs_c + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma43..gamma53 ) + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rsi, 1)) // store ( gamma47..gamma57 ) + + lea(mem(rdx, rsi, 2), rdx) // rdx += 2*cs_c + + /***top right tile 4x8 is transposed to bottom left tile 8x4**********/ + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + /*** bottom right 2x8 is transposed to bottom right tile 8x2*******/ + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + lea(mem(rdx, rsi, 4), rax) // rax += 4*cs_c + + vmovlpd(mem(rax), xmm1, xmm1) + vmovhpd(mem(rax, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rax)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rax, rsi, 1)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 2), rdx) // rdx += 2*cs_c + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma43..gamma53 ) + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rsi, 1)) // store ( gamma47..gamma57 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx, rsi, 8)) + add(rdi, rcx) + + vmovups(ymm14, mem(rcx)) + vmovups(ymm15, mem(rcx, rsi, 8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + /****6x16 tile going to save into 16x6 tile in C*****/ + /******************top left tile 8x4***************************/ + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + /******************top right tile 8x2***************************/ + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 1), rdx) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) + lea(mem(rdx, rsi, 1), rdx) + lea(mem(rdx, rsi, 4), rdx) // rdx += 8*cs_c + + /******************bottom left tile 8x4***************************/ + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + /******************bottom right tile 8x2***************************/ + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 1), rdx) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) + + label(.SDONE) + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + + //add(imm(4*16), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b4), rbx) // load ps_b4 + lea(mem(r14, rbx, 1), r14) // a_ii = r14 += ps_b4 + + dec(r11) // jj -= 1; + jne(.SLOOP6X16J) // iterate again if jj != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 6; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + n_iter * ps_b; + + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + bli_sgemmsup_rv_zen_asm_6x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_zen_asm_6x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_zen_asm_6x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 1 + const dim_t nr_cur = 1; + + bli_sgemmsup_r_zen_ref_6x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + } +} + +void bli_sgemmsup_rv_zen_asm_5x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 /16; + uint64_t n_left = n0 % 16; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.SLOOP6X16J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + //lea(mem(rbx, r10, 8), rdx) // use rdx for prefetching b. + //lea(mem(rdx, r10, 8), rdx) // rdx = b + 16*rs_b; + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm13, ymm13) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm5) + vmovups(ymm5, mem(rcx, rsi,8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm7) + vmovups(ymm7, mem(rcx, rsi,8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm9) + vmovups(ymm9, mem(rcx, rsi,8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm11) + vmovups(ymm11, mem(rcx, rsi,8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm13) + vmovups(ymm13, mem(rcx, rsi,8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1) //c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + /********************************************/ + vextractf128(imm(0x0), ymm12, xmm0)//e0-e3 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm12, xmm0)//e4-e7 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + /*********************************************/ + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + vextractf128(imm(0x0), ymm13, xmm0)//e0-e3 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0,xmm0, xmm1) + vshufps(imm(0x02), xmm0,xmm0, xmm2) + vshufps(imm(0x03), xmm0,xmm0, xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm13, xmm0)//e4-e7 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0,xmm0, xmm1) + vshufps(imm(0x02), xmm0,xmm0, xmm2) + vshufps(imm(0x03), xmm0,xmm0, xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi,8)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx, rsi,8)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx, rsi,8)) + add(rdi, rcx) + + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx, rsi,8)) + add(rdi, rcx) + + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx, rsi,8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1) //c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + /********************************************/ + vextractf128(imm(0x0), ymm12, xmm0)//e0-e3 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm12, xmm0)//e4-e7 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + /*********************************************/ + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + vextractf128(imm(0x0), ymm13, xmm0)//e0-e3 + vshufps(imm(0x01), xmm0,xmm0, xmm1) + vshufps(imm(0x02), xmm0,xmm0, xmm2) + vshufps(imm(0x03), xmm0,xmm0, xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm13, xmm0)//e4-e7 + vshufps(imm(0x01), xmm0,xmm0, xmm1) + vshufps(imm(0x02), xmm0,xmm0, xmm2) + vshufps(imm(0x03), xmm0,xmm0, xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + label(.SDONE) + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + + //add(imm(4*16), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b4), rbx) // load ps_b4 + lea(mem(r14, rbx, 1), r14) // a_ii = r14 += ps_b4 + + dec(r11) // jj -= 1; + jne(.SLOOP6X16J) // iterate again if jj != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 5; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + n_iter * ps_b; + + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_zen_asm_5x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_zen_asm_5x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_zen_asm_5x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 1 + const dim_t nr_cur = 1; + + bli_sgemmsup_r_zen_ref_5x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + + } +} + +void bli_sgemmsup_rv_zen_asm_4x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 /16; + uint64_t n_left = n0 % 16; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.SLOOP4X16J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + + mov(var(a), rax) // load address of a. + mov(r14, rbx) // reset rbx to current upanel of b. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm5) + vmovups(ymm5, mem(rcx, rsi,8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm7) + vmovups(ymm7, mem(rcx, rsi,8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm9) + vmovups(ymm9, mem(rcx, rsi,8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm11) + vmovups(ymm11, mem(rcx, rsi,8)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1) //c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi,8)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx, rsi,8)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx, rsi,8)) + add(rdi, rcx) + + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx, rsi,8)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1) //c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + label(.SDONE) + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + //add(imm(4*16), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b4), rbx) // load ps_b4 + lea(mem(r14, rbx, 1), r14) // a_ii = r14 += ps_b4 + + dec(r11) // jj -= 1; + jne(.SLOOP4X16J) // iterate again if jj != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + + const dim_t mr_cur = 4; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + n_iter * ps_b; + + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_zen_asm_4x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_zen_asm_4x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_zen_asm_4x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + const dim_t nr_cur = 1; + bli_sgemmsup_r_zen_ref_4x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + + } +} + +void bli_sgemmsup_rv_zen_asm_3x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 /16; + uint64_t n_left = n0 % 16; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.SLOOP4X16J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + mov(var(a), rax) // load address of a. + + mov(r14, rbx) // reset rbx to current upanel of b. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm5) + vmovups(ymm5, mem(rcx, rsi,8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm7) + vmovups(ymm7, mem(rcx, rsi,8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm9) + vmovups(ymm9, mem(rcx, rsi,8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm6, ymm4, ymm2) //a2b2a3b3 a6b6a7b7 + vperm2f128(imm(0x01),ymm0,ymm0,ymm11) + vperm2f128(imm(0x01),ymm2,ymm2,ymm12) + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 + vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm11) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm12) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + /********************************************/ + vextractf128(imm(0x0), ymm8, xmm0)//c0-c3 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm11) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm11, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm8, xmm0)//e4-e7 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e4 + vfmadd231ps(xmm6, xmm3, xmm1)//e5 + vfmadd231ps(xmm8, xmm3, xmm2)//e6 + vfmadd231ps(xmm10, xmm3, xmm14)//e7 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + /*********************************************/ + vunpcklps(ymm7, ymm5, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm7, ymm5, ymm2) //a2b2a3b3 a6b6a7b7 + vperm2f128(imm(0x01),ymm0,ymm0,ymm11) + vperm2f128(imm(0x01),ymm2,ymm2,ymm12) + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 + vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm11) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm12) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + /********************************************/ + vextractf128(imm(0x0), ymm9, xmm0)//c0-c3 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm9, xmm0)//e4-e7 + vmovss(mem(rdx),xmm4) + vmovss(mem(rdx, rsi, 1),xmm6) + vmovss(mem(rdx, rsi, 2),xmm8) + vmovss(mem(rdx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm8, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi,8)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx, rsi,8)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx, rsi,8)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm6, ymm4, ymm2) //a2b2a3b3 a6b6a7b7 + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vperm2f128(imm(0x01),ymm0,ymm0,ymm0) + vperm2f128(imm(0x01),ymm2,ymm2,ymm2) + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a2b2 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + /********************************************/ + vextractf128(imm(0x0), ymm8, xmm0)//c0-c3 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm8, xmm0)//c4-c7 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + /*********************************************/ + vunpcklps(ymm7, ymm5, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm7, ymm5, ymm2) //a2b2a3b3 a6b6a7b7 + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vperm2f128(imm(0x01),ymm0,ymm0,ymm0) + vperm2f128(imm(0x01),ymm2,ymm2,ymm2) + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a2b2 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + /********************************************/ + vextractf128(imm(0x0), ymm9, xmm0)//c0-c3 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + + vextractf128(imm(0x1), ymm9, xmm0)//c4-c7 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rdx)) + vmovss(xmm1, mem(rdx, rsi, 1)) + vmovss(xmm2, mem(rdx, rsi, 2)) + vmovss(xmm14, mem(rdx, rax, 1)) + + label(.SDONE) + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + lea(mem(r12, rsi, 8), r12) + //add(imm(4*16), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b4), rbx) // load ps_b4 + lea(mem(r14, rbx, 1), r14) // a_ii = r14 += ps_b4 + + dec(r11) // jj -= 1; + jne(.SLOOP4X16J) // iterate again if jj != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + n_iter * ps_b; + + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_zen_asm_3x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_zen_asm_3x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_zen_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + const dim_t nr_cur = 1; + bli_sgemmsup_r_zen_ref_3x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + + } +} + +void bli_sgemmsup_rv_zen_asm_2x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 /16; + uint64_t n_left = n0 % 16; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.SLOOP2X16J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm5) + vmovups(ymm5, mem(rcx, rsi,8)) + add(rdi, rcx) + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm7) + vmovups(ymm7, mem(rcx, rsi,8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm6, ymm4, ymm2) //a2b2a3b3 a6b6a7b7 + vperm2f128(imm(0x01),ymm0,ymm0,ymm11) + vperm2f128(imm(0x01),ymm2,ymm2,ymm12) + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 + vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm11) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm12) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + vunpcklps(ymm7, ymm5, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm7, ymm5, ymm2) //a2b2a3b3 a6b6a7b7 + vperm2f128(imm(0x01),ymm0,ymm0,ymm11) + vperm2f128(imm(0x01),ymm2,ymm2,ymm12) + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm2) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 + vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 + vmovsd(mem(rcx),xmm4) + vmovsd(mem(rcx, rsi, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm11) + vfmadd231ps(xmm6, xmm3, xmm1) + vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(mem(rcx, rsi, 2),xmm4) + vmovsd(mem(rcx, rax, 1),xmm6) + vfmadd231ps(xmm4, xmm3, xmm12) + vfmadd231ps(xmm6, xmm3, xmm10) + vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi,8)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx, rsi,8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm6, ymm4, ymm2) //a2b2a3b3 a6b6a7b7 + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vperm2f128(imm(0x01),ymm0,ymm0,ymm0) + vperm2f128(imm(0x01),ymm2,ymm2,ymm2) + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a2b2 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + vunpcklps(ymm7, ymm5, ymm0) //a0b0a1b1 a2b2a3b3 + vunpckhps(ymm7, ymm5, ymm2) //a2b2a3b3 a6b6a7b7 + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vperm2f128(imm(0x01),ymm0,ymm0,ymm0) + vperm2f128(imm(0x01),ymm2,ymm2,ymm2) + vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a2b2 + vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 + vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) + vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) + vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) + vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) + + label(.SDONE) + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + lea(mem(r12, rsi, 8), r12) + //add(imm(4*16), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b4), rbx) // load ps_b4 + lea(mem(r14, rbx, 1), r14) // a_ii = r14 += ps_b4 + + dec(r11) // jj -= 1; + jne(.SLOOP2X16J) // iterate again if jj != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 2; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + n_iter * ps_b; + + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_zen_asm_2x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_zen_asm_2x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_zen_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + const dim_t nr_cur = 1; + bli_sgemmsup_r_zen_ref_2x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_sgemmsup_rv_zen_asm_1x16n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 /16; + uint64_t n_left = n0 % 16; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + //lea(mem(, r11, 4), r11) // cs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + mov(var(n_iter), r11) // jj = n_iter; + + label(.SLOOP1X16J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rbx, 11*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi,8), ymm3, ymm5) + vmovups(ymm5, mem(rcx, rsi,8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vextractf128(imm(0x0), ymm4, xmm0)//c0-c3 + vmovss(mem(rcx),xmm7) + vmovss(mem(rcx, rsi, 1),xmm6) + vmovss(mem(rcx, rsi, 2),xmm11) + vmovss(mem(rcx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm7, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm11, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + vextractf128(imm(0x1), ymm4, xmm0)//e4-e7 + vmovss(mem(rcx),xmm4) + vmovss(mem(rcx, rsi, 1),xmm6) + vmovss(mem(rcx, rsi, 2),xmm8) + vmovss(mem(rcx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e4 + vfmadd231ps(xmm6, xmm3, xmm1)//e5 + vfmadd231ps(xmm8, xmm3, xmm2)//e6 + vfmadd231ps(xmm10, xmm3, xmm14)//e7 + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + vextractf128(imm(0x0), ymm5, xmm0)//c0-c3 + vmovss(mem(rcx),xmm4) + vmovss(mem(rcx, rsi, 1),xmm6) + vmovss(mem(rcx, rsi, 2),xmm11) + vmovss(mem(rcx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e0 + vfmadd231ps(xmm6, xmm3, xmm1)//e1 + vfmadd231ps(xmm11, xmm3, xmm2)//e2 + vfmadd231ps(xmm10, xmm3, xmm14)//e3 + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + vextractf128(imm(0x1), ymm5, xmm0)//e4-e7 + vmovss(mem(rcx),xmm4) + vmovss(mem(rcx, rsi, 1),xmm6) + vmovss(mem(rcx, rsi, 2),xmm8) + vmovss(mem(rcx, rax, 1),xmm10) + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vfmadd231ps(xmm4, xmm3, xmm0)//e4 + vfmadd231ps(xmm6, xmm3, xmm1)//e5 + vfmadd231ps(xmm8, xmm3, xmm2)//e6 + vfmadd231ps(xmm10, xmm3, xmm14)//e7 + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi,8)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + vextractf128(imm(0x0), ymm4, xmm0)//c0-c3 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm4, xmm0)//e4-e7 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + vextractf128(imm(0x0), ymm5, xmm0)//c0-c3 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c + vextractf128(imm(0x1), ymm5, xmm0)//e4-e7 + vshufps(imm(0x01), xmm0, xmm0,xmm1) + vshufps(imm(0x02), xmm0, xmm0,xmm2) + vshufps(imm(0x03), xmm0, xmm0,xmm14) + vmovss(xmm0, mem(rcx)) + vmovss(xmm1, mem(rcx, rsi, 1)) + vmovss(xmm2, mem(rcx, rsi, 2)) + vmovss(xmm14, mem(rcx, rax, 1)) + + label(.SDONE) + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + lea(mem(r12, rsi, 8), r12) + //add(imm(4*16), r14) // b_jj = r14 += 8*cs_b + mov(var(ps_b4), rbx) // load ps_b4 + lea(mem(r14, rbx, 1), r14) // a_ii = r14 += ps_b4 + + dec(r11) // jj -= 1; + jne(.SLOOP1X16J) // iterate again if jj != 0. + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 1; + const dim_t j_edge = n0 - ( dim_t )n_left; + + float* restrict cij = c + j_edge*cs_c; + float* restrict ai = a; + float* restrict bj = b + n_iter * ps_b; + + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_zen_asm_1x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_zen_asm_1x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_zen_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + const dim_t nr_cur = 1; + + bli_sgemmsup_r_zen_ref_1x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + + } +} + diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 110c6f17ce..c9651554d7 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,6 +33,13 @@ */ +// -- level-1m -- +PACKM_KER_PROT(double, d, packm_8xk_gen_zen) +PACKM_KER_PROT(double, d, packm_6xk_gen_zen) +PACKM_KER_PROT(double, d, packm_8xk_nn_zen) +PACKM_KER_PROT(double, d, packm_6xk_nn_zen) + + // -- level-1v -- // amaxv (intrinsics) @@ -42,17 +50,17 @@ AMAXV_KER_PROT( double, d, amaxv_zen_int ) AXPYV_KER_PROT( float, s, axpyv_zen_int ) AXPYV_KER_PROT( double, d, axpyv_zen_int ) - // axpyv (intrinsics unrolled x10) - AXPYV_KER_PROT( float, s, axpyv_zen_int10 ) - AXPYV_KER_PROT( double, d, axpyv_zen_int10 ) +// axpyv (intrinsics unrolled x10) +AXPYV_KER_PROT( float, s, axpyv_zen_int10 ) +AXPYV_KER_PROT( double, d, axpyv_zen_int10 ) // dotv (intrinsics) DOTV_KER_PROT( float, s, dotv_zen_int ) DOTV_KER_PROT( double, d, dotv_zen_int ) - // dotv (intrinsics, unrolled x10) - DOTV_KER_PROT( float, s, dotv_zen_int10 ) - DOTV_KER_PROT( double, d, dotv_zen_int10 ) +// dotv (intrinsics, unrolled x10) +DOTV_KER_PROT( float, s, dotv_zen_int10 ) +DOTV_KER_PROT( double, d, dotv_zen_int10 ) // dotxv (intrinsics) DOTXV_KER_PROT( float, s, dotxv_zen_int ) @@ -62,17 +70,144 @@ DOTXV_KER_PROT( double, d, dotxv_zen_int ) SCALV_KER_PROT( float, s, scalv_zen_int ) SCALV_KER_PROT( double, d, scalv_zen_int ) - // scalv (intrinsics unrolled x10) - SCALV_KER_PROT( float, s, scalv_zen_int10 ) - SCALV_KER_PROT( double, d, scalv_zen_int10 ) +// scalv (intrinsics unrolled x10) +SCALV_KER_PROT( float, s, scalv_zen_int10 ) +SCALV_KER_PROT( double, d, scalv_zen_int10 ) +SCALV_KER_PROT( scomplex, c, scalv_zen_int10 ) + +// swapv (intrinsics) +SWAPV_KER_PROT(float, s, swapv_zen_int8 ) +SWAPV_KER_PROT(double, d, swapv_zen_int8 ) + +// copyv (intrinsics) +COPYV_KER_PROT( float, s, copyv_zen_int ) +COPYV_KER_PROT( double, d, copyv_zen_int ) + +// +SETV_KER_PROT(float, s, setv_zen_int) +SETV_KER_PROT(double, d, setv_zen_int) + +// swapv (intrinsics) +SWAPV_KER_PROT(float, s, swapv_zen_int8 ) +SWAPV_KER_PROT(double, d, swapv_zen_int8 ) + // -- level-1f -- // axpyf (intrinsics) AXPYF_KER_PROT( float, s, axpyf_zen_int_8 ) AXPYF_KER_PROT( double, d, axpyf_zen_int_8 ) +AXPYF_KER_PROT( float, s, axpyf_zen_int_5 ) +AXPYF_KER_PROT( double, d, axpyf_zen_int_5 ) + +AXPYF_KER_PROT( double, d, axpyf_zen_int_16x4 ) +AXPYF_KER_PROT( scomplex, c, axpyf_zen_int_4 ) // dotxf (intrinsics) DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) DOTXF_KER_PROT( double, d, dotxf_zen_int_8 ) +// -- level-3 sup -------------------------------------------------------------- + +// semmsup_rv + +//GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x16 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x16 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_4x16 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_3x16 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x16 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_1x16 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x8 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x8 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_4x8 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_3x8 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x8 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_1x8 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_4x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_3x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_1x4 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x2 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x2 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_4x2 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_3x2 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x2 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_1x2 ) + +GEMMSUP_KER_PROT( float, s, gemmsup_r_zen_ref_6x1 ) +GEMMSUP_KER_PROT( float, s, gemmsup_r_zen_ref_5x1 ) +GEMMSUP_KER_PROT( float, s, gemmsup_r_zen_ref_4x1 ) +GEMMSUP_KER_PROT( float, s, gemmsup_r_zen_ref_3x1 ) +GEMMSUP_KER_PROT( float, s, gemmsup_r_zen_ref_2x1 ) +GEMMSUP_KER_PROT( float, s, gemmsup_r_zen_ref_1x1 ) + +// gemmsup_rv (mkernel in m dim) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x16m ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x8m ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x4m ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x2m ) +// gemmsup_rv (mkernel in n dim) + +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x16n ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x16n ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_4x16n ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_3x16n ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x16n ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_1x16n ) + +// gemmsup_rd +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_2x8) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_2x16) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_1x8) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_1x16) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_6x4) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_2x4) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_1x4) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_6x2) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_3x2) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_2x2) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_1x2) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_6x16m) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_6x8m) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_6x4m) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_6x2m) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_6x16n) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_3x16n) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_2x16n) +GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_1x16n) + +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_3x8m ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_3x4m ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_3x2m ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_2x8 ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_1x8 ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_2x4 ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_1x4 ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_2x2 ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_1x2 ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x4m ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x2m ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x4 ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x4 ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x2 ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x2 ) + +// gemmsup_rv (mkernel in n dim) + + +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_3x8n ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_2x8n ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_1x8n ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_3x4 ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_3x2 ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x4n ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x4n ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x4n ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x2 ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x1 ) + diff --git a/kernels/zen2/bli_kernels_zen2.h b/kernels/zen2/bli_kernels_zen2.h new file mode 100644 index 0000000000..db3bf2c26c --- /dev/null +++ b/kernels/zen2/bli_kernels_zen2.h @@ -0,0 +1,40 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// -- level-1f -- + +AXPYF_KER_PROT( float, s, axpyf_zen_int_5 ) +AXPYF_KER_PROT( double, d, axpyf_zen_int_5 ) + diff --git a/kernels/zen3/.gitignore b/kernels/zen3/.gitignore new file mode 100644 index 0000000000..5e7d2734cf --- /dev/null +++ b/kernels/zen3/.gitignore @@ -0,0 +1,4 @@ +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/mpi_test/Makefile b/mpi_test/Makefile index 8bf871b997..00ca01e47d 100644 --- a/mpi_test/Makefile +++ b/mpi_test/Makefile @@ -134,7 +134,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) diff --git a/ref_kernels/1/bli_amaxv_ref.c b/ref_kernels/1/bli_amaxv_ref.c index ca584213ef..169180f3b1 100644 --- a/ref_kernels/1/bli_amaxv_ref.c +++ b/ref_kernels/1/bli_amaxv_ref.c @@ -97,7 +97,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ encountered, then treat it the same as if it were a valid value that was smaller than any previously seen. This behavior mimics that of LAPACK's ?lange(). */ \ - if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \ + if ( abs_chi1_max < abs_chi1 || ( bli_isnan( abs_chi1 ) && !bli_isnan( abs_chi1_max ) ) ) \ { \ abs_chi1_max = abs_chi1; \ i_max_l = i; \ @@ -129,7 +129,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ encountered, then treat it the same as if it were a valid value that was smaller than any previously seen. This behavior mimics that of LAPACK's ?lange(). */ \ - if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \ + if ( abs_chi1_max < abs_chi1 || ( bli_isnan( abs_chi1 ) && !bli_isnan( abs_chi1_max ) ) ) \ { \ abs_chi1_max = abs_chi1; \ i_max_l = i; \ diff --git a/ref_kernels/1m/bli_packm_cxk_1er_ref.c b/ref_kernels/1m/bli_packm_cxk_1er_ref.c index e26381e8a0..03ec46d147 100644 --- a/ref_kernels/1m/bli_packm_cxk_1er_ref.c +++ b/ref_kernels/1m/bli_packm_cxk_1er_ref.c @@ -44,9 +44,9 @@ void PASTEMAC3(ch,opname,arch,suf) \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -261,9 +261,9 @@ void PASTEMAC3(ch,opname,arch,suf) \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -494,9 +494,9 @@ void PASTEMAC3(ch,opname,arch,suf) \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -743,9 +743,9 @@ void PASTEMAC3(ch,opname,arch,suf) \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -1008,9 +1008,9 @@ void PASTEMAC3(ch,opname,arch,suf) \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -1289,9 +1289,9 @@ void PASTEMAC3(ch,opname,arch,suf) \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -1586,9 +1586,9 @@ void PASTEMAC3(ch,opname,arch,suf) \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -1899,9 +1899,9 @@ void PASTEMAC3(ch,opname,arch,suf) \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ diff --git a/ref_kernels/1m/bli_packm_cxk_3mis_ref.c b/ref_kernels/1m/bli_packm_cxk_3mis_ref.c deleted file mode 100644 index 8768b78ac5..0000000000 --- a/ref_kernels/1m/bli_packm_cxk_3mis_ref.c +++ /dev/null @@ -1,1954 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_2xk_3mis, 2, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_4xk_3mis, 4, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_6xk_3mis, 6, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_8xk_3mis, 8, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_10xk_3mis, 10, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_12xk_3mis, 12, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_14xk_3mis, 14, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14), *(pi1_rpi +14) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15), *(pi1_rpi +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14), *(pi1_rpi +14) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15), *(pi1_rpi +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14), *(pi1_rpi +14) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15), *(pi1_rpi +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14), *(pi1_rpi +14) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15), *(pi1_rpi +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_16xk_3mis, 16, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - diff --git a/ref_kernels/1m/bli_packm_cxk_4mi_ref.c b/ref_kernels/1m/bli_packm_cxk_4mi_ref.c deleted file mode 100644 index ab5375b3db..0000000000 --- a/ref_kernels/1m/bli_packm_cxk_4mi_ref.c +++ /dev/null @@ -1,1450 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_2xk_4mi, 2, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_4xk_4mi, 4, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_6xk_4mi, 6, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_8xk_4mi, 8, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_10xk_4mi, 10, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_12xk_4mi, 12, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_14xk_4mi, 14, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_16xk_4mi, 16, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - diff --git a/ref_kernels/1m/bli_packm_cxk_bb_ref.c b/ref_kernels/1m/bli_packm_cxk_bb_ref.c new file mode 100644 index 0000000000..e7498a735d --- /dev/null +++ b/ref_kernels/1m/bli_packm_cxk_bb_ref.c @@ -0,0 +1,656 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// -- 6xk, duplication factor 2 ------------------------------------------------ + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, mnr, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + pack_t schema, \ + dim_t cdim, \ + dim_t n, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ + cntx_t* restrict cntx \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict alpha1 = a; \ + ctype* restrict pi1 = p; \ +\ + const dim_t dfac = 2; \ +\ + /* Handle the packing of B (column panel schemas) separately from packing + of A (row panel schemas). */ \ + if ( bli_is_col_packed( schema ) ) \ + { \ + if ( cdim == mnr ) \ + { \ + if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ + { \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,copyjs)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 0*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 1*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 1*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 2*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 2*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 3*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 3*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 4*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 4*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 5*inca), *(pi1 + 10) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 5*inca), *(pi1 + 11) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + else /* if ( bli_is_noconj( conja ) ) */ \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,copys)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 0*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 1*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 1*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 2*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 2*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 3*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 3*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 4*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 4*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 5*inca), *(pi1 + 10) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 5*inca), *(pi1 + 11) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + } \ + else /* if ( !PASTEMAC(ch,eq1)( *kappa_cast ) ) */ \ + { \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 10) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 11) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + else /* if ( bli_is_noconj( conja ) ) */ \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 10) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 11) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + } \ + } \ + else /* if ( cdim < mnr ) */ \ + { \ + PASTEMAC(ch,scal2bbs_mxn) \ + ( \ + conja, \ + cdim, \ + n, \ + kappa, \ + a, inca, lda, \ + p, dfac, ldp \ + ); \ +\ + /* if ( cdim < mnr ) */ \ + { \ + const dim_t i = cdim; \ + const dim_t m_edge = mnr - cdim; \ + const dim_t n_edge = n_max; \ + ctype* restrict p_cast = p; \ + ctype* restrict p_edge = p_cast + (i )*dfac; \ +\ + PASTEMAC(ch,set0bbs_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, dfac, ldp \ + ); \ + } \ + } \ +\ + if ( n < n_max ) \ + { \ + const dim_t j = n; \ + const dim_t m_edge = mnr; \ + const dim_t n_edge = n_max - n; \ + ctype* restrict p_cast = p; \ + ctype* restrict p_edge = p_cast + (j )*ldp; \ +\ + PASTEMAC(ch,set0bbs_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, dfac, ldp \ + ); \ + } \ + } \ + else /* if ( bli_is_row_packed( schema ) ) */ \ + { \ + if ( cdim == mnr ) \ + { \ + if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ + { \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,copyjs)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 5*inca), *(pi1 + 5) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + else /* if ( bli_is_noconj( conja ) ) */ \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,copys)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 5*inca), *(pi1 + 5) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + } \ + else /* if ( !PASTEMAC(ch,eq1)( *kappa_cast ) ) */ \ + { \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + else /* if ( bli_is_noconj( conja ) ) */ \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + } \ + } \ + else /* if ( cdim < mnr ) */ \ + { \ + PASTEMAC(ch,scal2s_mxn) \ + ( \ + conja, \ + cdim, \ + n, \ + kappa, \ + a, inca, lda, \ + p, 1, ldp \ + ); \ +\ + /* if ( cdim < mnr ) */ \ + { \ + const dim_t i = cdim; \ + const dim_t m_edge = mnr - cdim; \ + const dim_t n_edge = n_max; \ + ctype* restrict p_cast = p; \ + ctype* restrict p_edge = p_cast + (i )*1; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ + } \ +\ + if ( n < n_max ) \ + { \ + const dim_t j = n; \ + const dim_t m_edge = mnr; \ + const dim_t n_edge = n_max - n; \ + ctype* restrict p_cast = p; \ + ctype* restrict p_edge = p_cast + (j )*ldp; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC3( packm_6xk_bb2, 6, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) + +// -- 6xk, duplication factor 4 ------------------------------------------------ + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, mnr, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + pack_t schema, \ + dim_t cdim, \ + dim_t n, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ + cntx_t* restrict cntx \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict alpha1 = a; \ + ctype* restrict pi1 = p; \ +\ + const dim_t dfac = 4; \ +\ + /* Handle the packing of B (column panel schemas) separately from packing + of A (row panel schemas). */ \ + if ( bli_is_col_packed( schema ) ) \ + { \ + if ( cdim == mnr ) \ + { \ + if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ + { \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,copyjs)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 0*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 0*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 0*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 1*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 1*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 1*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 1*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 2*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 2*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 2*inca), *(pi1 + 10) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 2*inca), *(pi1 + 11) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 3*inca), *(pi1 + 12) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 3*inca), *(pi1 + 13) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 3*inca), *(pi1 + 14) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 3*inca), *(pi1 + 15) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 4*inca), *(pi1 + 16) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 4*inca), *(pi1 + 17) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 4*inca), *(pi1 + 18) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 4*inca), *(pi1 + 19) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 5*inca), *(pi1 + 20) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 5*inca), *(pi1 + 21) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 5*inca), *(pi1 + 22) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 5*inca), *(pi1 + 23) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + else /* if ( bli_is_noconj( conja ) ) */ \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,copys)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 0*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 0*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 0*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 1*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 1*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 1*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 1*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 2*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 2*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 2*inca), *(pi1 + 10) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 2*inca), *(pi1 + 11) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 3*inca), *(pi1 + 12) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 3*inca), *(pi1 + 13) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 3*inca), *(pi1 + 14) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 3*inca), *(pi1 + 15) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 4*inca), *(pi1 + 16) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 4*inca), *(pi1 + 17) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 4*inca), *(pi1 + 18) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 4*inca), *(pi1 + 19) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 5*inca), *(pi1 + 20) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 5*inca), *(pi1 + 21) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 5*inca), *(pi1 + 22) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 5*inca), *(pi1 + 23) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + } \ + else /* if ( !PASTEMAC(ch,eq1)( *kappa_cast ) ) */ \ + { \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 10) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 11) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 12) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 13) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 14) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 15) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 16) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 17) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 18) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 19) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 20) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 21) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 22) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 23) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + else /* if ( bli_is_noconj( conja ) ) */ \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 10) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 11) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 12) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 13) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 14) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 15) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 16) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 17) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 18) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 19) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 20) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 21) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 22) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 23) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + } \ + } \ + else /* if ( cdim < mnr ) */ \ + { \ + PASTEMAC(ch,scal2bbs_mxn) \ + ( \ + conja, \ + cdim, \ + n, \ + kappa, \ + a, inca, lda, \ + p, dfac, ldp \ + ); \ +\ + /* if ( cdim < mnr ) */ \ + { \ + const dim_t i = cdim; \ + const dim_t m_edge = mnr - cdim; \ + const dim_t n_edge = n_max; \ + ctype* restrict p_cast = p; \ + ctype* restrict p_edge = p_cast + (i )*dfac; \ +\ + PASTEMAC(ch,set0bbs_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, dfac, ldp \ + ); \ + } \ + } \ +\ + if ( n < n_max ) \ + { \ + const dim_t j = n; \ + const dim_t m_edge = mnr; \ + const dim_t n_edge = n_max - n; \ + ctype* restrict p_cast = p; \ + ctype* restrict p_edge = p_cast + (j )*ldp; \ +\ + PASTEMAC(ch,set0bbs_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, dfac, ldp \ + ); \ + } \ + } \ + else /* if ( bli_is_row_packed( schema ) ) */ \ + { \ + if ( cdim == mnr ) \ + { \ + if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ + { \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,copyjs)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 5*inca), *(pi1 + 5) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + else /* if ( bli_is_noconj( conja ) ) */ \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,copys)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 5*inca), *(pi1 + 5) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + } \ + else /* if ( !PASTEMAC(ch,eq1)( *kappa_cast ) ) */ \ + { \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + else /* if ( bli_is_noconj( conja ) ) */ \ + { \ + for ( dim_t k = n; k != 0; --k ) \ + { \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + } \ + } \ + else /* if ( cdim < mnr ) */ \ + { \ + PASTEMAC(ch,scal2s_mxn) \ + ( \ + conja, \ + cdim, \ + n, \ + kappa, \ + a, inca, lda, \ + p, 1, ldp \ + ); \ +\ + /* if ( cdim < mnr ) */ \ + { \ + const dim_t i = cdim; \ + const dim_t m_edge = mnr - cdim; \ + const dim_t n_edge = n_max; \ + ctype* restrict p_cast = p; \ + ctype* restrict p_edge = p_cast + (i )*1; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ + } \ +\ + if ( n < n_max ) \ + { \ + const dim_t j = n; \ + const dim_t m_edge = mnr; \ + const dim_t n_edge = n_max - n; \ + ctype* restrict p_cast = p; \ + ctype* restrict p_edge = p_cast + (j )*ldp; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC3( packm_6xk_bb4, 6, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) + diff --git a/ref_kernels/1m/bli_packm_cxk_ref.c b/ref_kernels/1m/bli_packm_cxk_ref.c index 33a1e9b437..c98f1b2503 100644 --- a/ref_kernels/1m/bli_packm_cxk_ref.c +++ b/ref_kernels/1m/bli_packm_cxk_ref.c @@ -40,12 +40,13 @@ void PASTEMAC3(ch,opname,arch,suf) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -188,12 +189,13 @@ INSERT_GENTFUNC_BASIC3( packm_2xk, 2, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) void PASTEMAC3(ch,opname,arch,suf) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -344,12 +346,13 @@ INSERT_GENTFUNC_BASIC3( packm_3xk, 3, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) void PASTEMAC3(ch,opname,arch,suf) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -498,12 +501,13 @@ INSERT_GENTFUNC_BASIC3( packm_4xk, 4, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) void PASTEMAC3(ch,opname,arch,suf) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -641,12 +645,13 @@ INSERT_GENTFUNC_BASIC3( packm_6xk, 6, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) void PASTEMAC3(ch,opname,arch,suf) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -819,12 +824,13 @@ INSERT_GENTFUNC_BASIC3( packm_8xk, 8, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) void PASTEMAC3(ch,opname,arch,suf) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -978,12 +984,13 @@ INSERT_GENTFUNC_BASIC3( packm_10xk, 10, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) void PASTEMAC3(ch,opname,arch,suf) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -1145,12 +1152,13 @@ INSERT_GENTFUNC_BASIC3( packm_12xk, 12, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) void PASTEMAC3(ch,opname,arch,suf) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -1320,12 +1328,13 @@ INSERT_GENTFUNC_BASIC3( packm_14xk, 14, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) void PASTEMAC3(ch,opname,arch,suf) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ @@ -1503,12 +1512,13 @@ INSERT_GENTFUNC_BASIC3( packm_16xk, 16, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) void PASTEMAC3(ch,opname,arch,suf) \ ( \ conj_t conja, \ + pack_t schema, \ dim_t cdim, \ dim_t n, \ dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t inca, inc_t lda, \ + ctype* restrict p, inc_t ldp, \ cntx_t* restrict cntx \ ) \ { \ diff --git a/ref_kernels/1m/bli_packm_cxk_rih_ref.c b/ref_kernels/1m/bli_packm_cxk_rih_ref.c deleted file mode 100644 index e0e626d0f5..0000000000 --- a/ref_kernels/1m/bli_packm_cxk_rih_ref.c +++ /dev/null @@ -1,2498 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_2xk_rih, 2, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_4xk_rih, 4, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 5*inca2), *(pi1_r + 5) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_6xk_rih, 6, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 7*inca2), *(pi1_r + 7) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_8xk_rih, 8, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 9*inca2), *(pi1_r + 9) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_10xk_rih, 10, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +11*inca2), *(pi1_r +11) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +11*inca2), *(pi1_r +11) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +11*inca2), *(pi1_r +11) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +10*inca2), -*(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +11*inca2), -*(alpha1_i +11*inca2), *(pi1_r +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_12xk_rih, 12, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +13*inca2), *(pi1_r +13) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +13*inca2), *(pi1_r +13) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +13*inca2), *(pi1_r +13) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +10*inca2), -*(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +11*inca2), -*(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +12*inca2), -*(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +13*inca2), -*(alpha1_i +13*inca2), *(pi1_r +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_14xk_rih, 14, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - void* restrict kappa, \ - void* restrict a, inc_t inca, inc_t lda, \ - void* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +13*inca2), *(pi1_r +13) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +14*inca2), *(pi1_r +14) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +15*inca2), *(pi1_r +15) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +14*inca), *(pi1_r +14) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +15*inca), *(pi1_r +15) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +14*inca), *(pi1_r +14) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +15*inca), *(pi1_r +15) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +13*inca2), *(pi1_r +13) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +14*inca2), *(pi1_r +14) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +15*inca2), *(pi1_r +15) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +13*inca2), *(pi1_r +13) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +14*inca2), *(pi1_r +14) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +15*inca2), *(pi1_r +15) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +14*inca), *(pi1_r +14) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +15*inca), *(pi1_r +15) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +14*inca), *(pi1_r +14) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +15*inca), *(pi1_r +15) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +10*inca2), -*(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +11*inca2), -*(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +12*inca2), -*(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +13*inca2), -*(alpha1_i +13*inca2), *(pi1_r +13) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +14*inca2), -*(alpha1_i +14*inca2), *(pi1_r +14) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +15*inca2), -*(alpha1_i +15*inca2), *(pi1_r +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +14*inca), *(pi1_r +14) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +15*inca), *(pi1_r +15) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +14*inca), *(pi1_r +14) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +15*inca), *(pi1_r +15) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_16xk_rih, 16, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - diff --git a/ref_kernels/3/bb/bli_gemmbb_ref.c b/ref_kernels/3/bb/bli_gemmbb_ref.c new file mode 100644 index 0000000000..4c75c064ce --- /dev/null +++ b/ref_kernels/3/bb/bli_gemmbb_ref.c @@ -0,0 +1,141 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// An implementation that indexes through B with the assumption that all +// elements were broadcast (duplicated) by a factor of NP/NR. + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, \ + ctype* restrict b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ +\ + const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \ + const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ +\ + const inc_t cs_a = packmr; \ +\ + const inc_t rs_b = packnr; \ +\ + /* Assume that the degree of duplication is equal to packnr / nr. */ \ + const inc_t cs_b = packnr / nr; \ +\ + ctype ab[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const inc_t rs_ab = 1; \ + const inc_t cs_ab = mr; \ +\ + dim_t l, j, i; \ +\ + ctype ai; \ + ctype bj; \ +\ +\ + /* Initialize the accumulator elements in ab to zero. */ \ + for ( i = 0; i < m * n; ++i ) \ + { \ + PASTEMAC(ch,set0s)( *(ab + i) ); \ + } \ +\ + /* Perform a series of k rank-1 updates into ab. */ \ + for ( l = 0; l < k; ++l ) \ + { \ + ctype* restrict abij = ab; \ +\ + /* In an optimized implementation, these two loops over MR and NR + are typically fully unrolled. */ \ + for ( j = 0; j < n; ++j ) \ + { \ + bj = *(b + j*cs_b); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + ai = *(a + i); \ +\ + PASTEMAC(ch,dots)( ai, bj, *abij ); \ +\ + abij += rs_ab; \ + } \ + } \ +\ + a += cs_a; \ + b += rs_b; \ + } \ +\ + /* Scale the result in ab by alpha. */ \ + for ( i = 0; i < m * n; ++i ) \ + { \ + PASTEMAC(ch,scals)( *alpha, *(ab + i) ); \ + } \ +\ + /* If beta is zero, overwrite c with the scaled result in ab. Otherwise, + scale by beta and then add the scaled redult in ab. */ \ + if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,copys_mxn)( m, \ + n, \ + ab, rs_ab, cs_ab, \ + c, rs_c, cs_c ); \ + } \ + else \ + { \ + PASTEMAC(ch,xpbys_mxn)( m, \ + n, \ + ab, rs_ab, cs_ab, \ + beta, \ + c, rs_c, cs_c ); \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmbb, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) + diff --git a/ref_kernels/3/bb/bli_gemmtrsmbb_ref.c b/ref_kernels/3/bb/bli_gemmtrsmbb_ref.c new file mode 100644 index 0000000000..dd4e1f153d --- /dev/null +++ b/ref_kernels/3/bb/bli_gemmtrsmbb_ref.c @@ -0,0 +1,140 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// An implementation that indexes through B with the assumption that all +// elements were broadcast (duplicated) by a factor of NP/NR. + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf, trsmkerid ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a1x, \ + ctype* restrict a11, \ + ctype* restrict bx1, \ + ctype* restrict b11, \ + ctype* restrict c11, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ +\ + const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ +\ + const inc_t rs_b = packnr; \ +\ + /* Assume that the degree of duplication is equal to packnr / nr. */ \ + const inc_t cs_b = packnr / nr; \ +/* +printf( "bli_gemmtrsmbb_ref(): cs_b = %d\n", (int)cs_b ); \ +printf( "bli_gemmtrsmbb_ref(): k nr = %d %d\n", (int)k, (int)nr ); \ +*/ \ +\ + ctype* minus_one = PASTEMAC(ch,m1); \ +\ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ + PASTECH(ch,trsm_ukr_ft) \ + trsm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, trsmkerid, cntx ); \ +\ +/* +PASTEMAC(d,fprintm)( stdout, "gemmtrsm_ukr: b01", k, nr, \ + (double*)bx1, rs_b, cs_b, "%5.2f", "" ); \ +PASTEMAC(d,fprintm)( stdout, "gemmtrsm_ukr: b11", mr, 2*nr, \ + (double*)b11, rs_b, 1, "%5.2f", "" ); \ +*/ \ +\ + /* lower: b11 = alpha * b11 - a10 * b01; */ \ + /* upper: b11 = alpha * b11 - a12 * b21; */ \ + gemm_ukr \ + ( \ + mr, \ + nr, \ + k, \ + minus_one, \ + a1x, \ + bx1, \ + alpha, \ + b11, rs_b, cs_b, \ + data, \ + cntx \ + ); \ +/* +PASTEMAC(d,fprintm)( stdout, "gemmtrsm_ukr: b11 after gemm", mr, 2*nr, \ + (double*)b11, rs_b, 1, "%5.2f", "" ); \ +*/ \ +\ + /* b11 = inv(a11) * b11; + c11 = b11; */ \ + trsm_ukr \ + ( \ + a11, \ + b11, \ + c11, rs_c, cs_c, \ + data, \ + cntx \ + ); \ +/* +PASTEMAC(d,fprintm)( stdout, "gemmtrsm_ukr: b11 after trsm", mr, 2*nr, \ + (double*)b11, rs_b, 1, "%5.2f", "" ); \ +*/ \ +\ + /* Broadcast the elements of the updated b11 submatrix to their + duplicated neighbors. */ \ + PASTEMAC(ch,bcastbbs_mxn) \ + ( \ + mr, \ + nr, \ + b11, rs_b, cs_b \ + ); \ +\ +/* +PASTEMAC(d,fprintm)( stdout, "gemmtrsm_ukr: b0111p_r after", k+3, 8, \ + ( double* )b01, 2*PASTEMAC(ch,packnr), 2, "%4.1f", "" ); \ +PASTEMAC(d,fprintm)( stdout, "gemmtrsm_ukr: b0111p_i after", k+3, 8, \ + ( double* )b01 + 1, 2*PASTEMAC(ch,packnr), 2, "%4.1f", "" ); \ +*/ \ +} + +INSERT_GENTFUNC_BASIC3( gemmtrsmbb_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, BLIS_TRSM_L_UKR ) +INSERT_GENTFUNC_BASIC3( gemmtrsmbb_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, BLIS_TRSM_U_UKR ) + diff --git a/ref_kernels/3/bb/bli_trsmbb_ref.c b/ref_kernels/3/bb/bli_trsmbb_ref.c new file mode 100644 index 0000000000..e3f5500ccb --- /dev/null +++ b/ref_kernels/3/bb/bli_trsmbb_ref.c @@ -0,0 +1,214 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// An implementation that indexes through B with the assumption that all +// elements were broadcast (duplicated) by a factor of NP/NR. + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf, diagop ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + ctype* restrict a, \ + ctype* restrict b, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ +\ + const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \ + const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ +\ + const dim_t m = mr; \ + const dim_t n = nr; \ +\ + const inc_t rs_a = 1; \ + const inc_t cs_a = packmr; \ +\ + const inc_t rs_b = packnr; \ +\ + /* Assume that the degree of duplication is equal to packnr / nr. */ \ + const inc_t cs_b = packnr / nr; \ +\ + dim_t iter, i, j, l; \ + dim_t n_behind; \ +\ + for ( iter = 0; iter < m; ++iter ) \ + { \ + i = iter; \ + n_behind = i; \ +\ + ctype* restrict alpha11 = a + (i )*rs_a + (i )*cs_a; \ + ctype* restrict a10t = a + (i )*rs_a + (0 )*cs_a; \ + ctype* restrict B0 = b + (0 )*rs_b + (0 )*cs_b; \ + ctype* restrict b1 = b + (i )*rs_b + (0 )*cs_b; \ +\ + /* b1 = b1 - a10t * B0; */ \ + /* b1 = b1 / alpha11; */ \ + for ( j = 0; j < n; ++j ) \ + { \ + ctype* restrict b01 = B0 + (0 )*rs_b + (j )*cs_b; \ + ctype* restrict beta11 = b1 + (0 )*rs_b + (j )*cs_b; \ + ctype* restrict gamma11 = c + (i )*rs_c + (j )*cs_c; \ + ctype beta11c = *beta11; \ + ctype rho11; \ +\ + /* beta11 = beta11 - a10t * b01; */ \ + PASTEMAC(ch,set0s)( rho11 ); \ + for ( l = 0; l < n_behind; ++l ) \ + { \ + ctype* restrict alpha10 = a10t + (l )*cs_a; \ + ctype* restrict beta01 = b01 + (l )*rs_b; \ +\ + PASTEMAC(ch,axpys)( *alpha10, *beta01, rho11 ); \ + } \ + PASTEMAC(ch,subs)( rho11, beta11c ); \ +\ + /* beta11 = beta11 / alpha11; */ \ + /* NOTE: When preinversion is enabled, the INVERSE of alpha11 + (1.0/alpha11) is stored during packing instead alpha11 so we + can multiply rather than divide. When preinversion is disabled, + alpha11 is stored and division happens below explicitly. */ \ + PASTEMAC(ch,scals)( *alpha11, beta11c ); \ +\ + /* Output final result to matrix c. */ \ + PASTEMAC(ch,copys)( beta11c, *gamma11 ); \ +\ + /* Store the local value back to b11. */ \ + PASTEMAC(ch,copys)( beta11c, *beta11 ); \ + } \ + } \ +} + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION +INSERT_GENTFUNC_BASIC3( trsmbb_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, scals ) +#else +INSERT_GENTFUNC_BASIC3( trsmbb_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, invscals ) +#endif + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf, diagop ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + ctype* restrict a, \ + ctype* restrict b, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ +\ + const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \ + const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ +\ + const dim_t m = mr; \ + const dim_t n = nr; \ +\ + const inc_t rs_a = 1; \ + const inc_t cs_a = packmr; \ +\ + const inc_t rs_b = packnr; \ +\ + /* Assume that the degree of duplication is equal to packnr / nr. */ \ + const inc_t cs_b = packnr / nr; \ +\ + dim_t iter, i, j, l; \ + dim_t n_behind; \ +\ + for ( iter = 0; iter < m; ++iter ) \ + { \ + i = m - iter - 1; \ + n_behind = iter; \ +\ + ctype* restrict alpha11 = a + (i )*rs_a + (i )*cs_a; \ + ctype* restrict a12t = a + (i )*rs_a + (i+1)*cs_a; \ + ctype* restrict b1 = b + (i )*rs_b + (0 )*cs_b; \ + ctype* restrict B2 = b + (i+1)*rs_b + (0 )*cs_b; \ +\ + /* b1 = b1 - a12t * B2; */ \ + /* b1 = b1 / alpha11; */ \ + for ( j = 0; j < n; ++j ) \ + { \ + ctype* restrict beta11 = b1 + (0 )*rs_b + (j )*cs_b; \ + ctype* restrict b21 = B2 + (0 )*rs_b + (j )*cs_b; \ + ctype* restrict gamma11 = c + (i )*rs_c + (j )*cs_c; \ + ctype beta11c = *beta11; \ + ctype rho11; \ +\ + /* beta11 = beta11 - a12t * b21; */ \ + PASTEMAC(ch,set0s)( rho11 ); \ + for ( l = 0; l < n_behind; ++l ) \ + { \ + ctype* restrict alpha12 = a12t + (l )*cs_a; \ + ctype* restrict beta21 = b21 + (l )*rs_b; \ +\ + PASTEMAC(ch,axpys)( *alpha12, *beta21, rho11 ); \ + } \ + PASTEMAC(ch,subs)( rho11, beta11c ); \ +\ + /* beta11 = beta11 / alpha11; */ \ + /* NOTE: When preinversion is enabled, the INVERSE of alpha11 + (1.0/alpha11) is stored during packing instead alpha11 so we + can multiply rather than divide. When preinversion is disabled, + alpha11 is stored and division happens below explicitly. */ \ + PASTEMAC(ch,diagop)( *alpha11, beta11c ); \ +\ + /* Output final result to matrix c. */ \ + PASTEMAC(ch,copys)( beta11c, *gamma11 ); \ +\ + /* Store the local value back to b11. */ \ + PASTEMAC(ch,copys)( beta11c, *beta11 ); \ + } \ + } \ +} + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION +INSERT_GENTFUNC_BASIC3( trsmbb_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, scals ) +#else +INSERT_GENTFUNC_BASIC3( trsmbb_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, invscals ) +#endif + diff --git a/ref_kernels/3/bli_gemm_ref.c b/ref_kernels/3/bli_gemm_ref.c index 931fe994b3..51ff9df4bd 100644 --- a/ref_kernels/3/bli_gemm_ref.c +++ b/ref_kernels/3/bli_gemm_ref.c @@ -44,6 +44,8 @@ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -107,8 +109,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ if ( PASTEMAC(ch,eq0)( *beta ) ) \ { \ - for ( dim_t i = 0; i < mr; ++i ) \ - for ( dim_t j = 0; j < nr; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ PASTEMAC(ch,copys) \ ( \ ab[ i*rs_ab + j*cs_ab ], \ @@ -117,8 +119,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ else \ { \ - for ( dim_t i = 0; i < mr; ++i ) \ - for ( dim_t j = 0; j < nr; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ PASTEMAC(ch,xpbys) \ ( \ ab[ i*rs_ab + j*cs_ab ], \ @@ -133,8 +135,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ if ( PASTEMAC(ch,eq0)( *beta ) ) \ { \ - for ( dim_t j = 0; j < nr; ++j ) \ - for ( dim_t i = 0; i < mr; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ PASTEMAC(ch,copys) \ ( \ ab[ i*rs_ab + j*cs_ab ], \ @@ -143,8 +145,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ else \ { \ - for ( dim_t j = 0; j < nr; ++j ) \ - for ( dim_t i = 0; i < mr; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ PASTEMAC(ch,xpbys) \ ( \ ab[ i*rs_ab + j*cs_ab ], \ @@ -171,6 +173,8 @@ GENTFUNC( dcomplex, z, gemm, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 4 ) \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -188,9 +192,6 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \ const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ \ const inc_t cs_a = packmr; \ \ diff --git a/ref_kernels/3/bli_gemmsup_ref.c b/ref_kernels/3/bli_gemmsup_ref.c new file mode 100644 index 0000000000..0c3773c1c0 --- /dev/null +++ b/ref_kernels/3/bli_gemmsup_ref.c @@ -0,0 +1,832 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// -- Row storage case --------------------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* NOTE: This microkernel can actually handle arbitrarily large + values of m, n, and k. */ \ +\ + if ( bli_is_noconj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_noconj( conja ) && bli_is_conj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,axpyjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_conj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dotjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_conj( conja ) && bli_is_conj( conjb ) ) */ \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* Conjugate the result to simulate conj(a^T) * conj(b). */ \ + PASTEMAC(ch,conjs)( ab ); \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_r, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) + +// +// -- Column storage case ------------------------------------------------------ +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* NOTE: This microkernel can actually handle arbitrarily large + values of m, n, and k. */ \ +\ + if ( bli_is_noconj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_noconj( conja ) && bli_is_conj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,axpyjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_conj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dotjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_conj( conja ) && bli_is_conj( conjb ) ) */ \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* Conjugate the result to simulate conj(a^T) * conj(b). */ \ + PASTEMAC(ch,conjs)( ab ); \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_c, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) + +// +// -- General storage case ----------------------------------------------------- +// + +INSERT_GENTFUNC_BASIC2( gemmsup_g, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) + + + + + + + + +#if 0 + +// +// -- Row storage case --------------------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + const dim_t mn = m * n; \ +\ + ctype ab[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const inc_t rs_ab = n; \ + const inc_t cs_ab = 1; \ +\ +\ + /* Assumptions: m <= mr, n <= nr so that the temporary array ab is + sufficiently large enough to hold the m x n microtile. + + The ability to handle m < mr and n < nr is being provided so that + optimized ukernels can call one of these reference implementations + for their edge cases, if they choose. When they do so, they will + need to call the function directly, by its configuration-mangled + name, since it will have been overwritten in the context when + the optimized ukernel functions are registered. */ \ +\ +\ + /* Initialize the accumulator elements in ab to zero. */ \ + for ( dim_t i = 0; i < mn; ++i ) \ + { \ + PASTEMAC(ch,set0s)( ab[i] ); \ + } \ +\ + /* Perform a series of k rank-1 updates into ab. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + /* Traverse ab by rows; assume cs_ab = 1. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + PASTEMAC(ch,dots) \ + ( \ + a[ i*rs_a ], \ + b[ j*cs_b ], \ + ab[ i*rs_ab + j*cs_ab ] \ + ); \ + } \ + } \ +\ + a += cs_a; \ + b += rs_b; \ + } \ +\ + /* Scale the result in ab by alpha. */ \ + for ( dim_t i = 0; i < mn; ++i ) \ + { \ + PASTEMAC(ch,scals)( *alpha, ab[i] ); \ + } \ +\ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c with the + result in ab. Otherwise, scale by beta and accumulate ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + /* Traverse ab and c by rows; assume cs_a = cs_a = 1. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + PASTEMAC(ch,adds) \ + ( \ + ab[ i*rs_ab + j*1 ], \ + c[ i*rs_c + j*1 ] \ + ) \ + } \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ +\ + /* Traverse ab and c by rows; assume cs_a = cs_a = 1. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + PASTEMAC(ch,copys) \ + ( \ + ab[ i*rs_ab + j*1 ], \ + c[ i*rs_c + j*1 ] \ + ) \ + } \ + } \ + else /* beta != 0 && beta != 1 */ \ + { \ + /* Traverse ab and c by rows; assume cs_a = cs_a = 1. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + PASTEMAC(ch,xpbys) \ + ( \ + ab[ i*rs_ab + j*1 ], \ + *beta, \ + c[ i*rs_c + j*1 ] \ + ) \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_r, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) + +// +// -- Column storage case ------------------------------------------------------ +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + const dim_t mn = m * n; \ +\ + ctype ab[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const inc_t rs_ab = 1; \ + const inc_t cs_ab = m; \ +\ +\ + /* Assumptions: m <= mr, n <= nr so that the temporary array ab is + sufficiently large enough to hold the m x n microtile. + + The ability to handle m < mr and n < nr is being provided so that + optimized ukernels can call one of these reference implementations + for their edge cases, if they choose. When they do so, they will + need to call the function directly, by its configuration-mangled + name, since it will have been overwritten in the context when + the optimized ukernel functions are registered. */ \ +\ +\ + /* Initialize the accumulator elements in ab to zero. */ \ + for ( dim_t i = 0; i < mn; ++i ) \ + { \ + PASTEMAC(ch,set0s)( ab[i] ); \ + } \ +\ + /* Perform a series of k rank-1 updates into ab. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + /* Traverse ab by columns; assume rs_ab = 1. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + PASTEMAC(ch,dots) \ + ( \ + a[ i*rs_a ], \ + b[ j*cs_b ], \ + ab[ i*rs_ab + j*cs_ab ] \ + ); \ + } \ + } \ +\ + a += cs_a; \ + b += rs_b; \ + } \ +\ + /* Scale the result in ab by alpha. */ \ + for ( dim_t i = 0; i < mn; ++i ) \ + { \ + PASTEMAC(ch,scals)( *alpha, ab[i] ); \ + } \ +\ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c with the + result in ab. Otherwise, scale by beta and accumulate ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + /* Traverse ab and c by columns; assume rs_a = rs_a = 1. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + PASTEMAC(ch,adds) \ + ( \ + ab[ i*1 + j*cs_ab ], \ + c[ i*1 + j*cs_c ] \ + ) \ + } \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + /* Traverse ab and c by columns; assume rs_a = rs_a = 1. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + PASTEMAC(ch,copys) \ + ( \ + ab[ i*1 + j*cs_ab ], \ + c[ i*1 + j*cs_c ] \ + ) \ + } \ + } \ + else /* beta != 0 && beta != 1 */ \ + { \ + /* Traverse ab and c by columns; assume rs_a = rs_a = 1. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + PASTEMAC(ch,xpbys) \ + ( \ + ab[ i*1 + j*cs_ab ], \ + *beta, \ + c[ i*1 + j*cs_c ] \ + ) \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_c, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) + +// +// -- General storage case ----------------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + const dim_t mn = m * n; \ +\ + ctype ab[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const inc_t rs_ab = 1; \ + const inc_t cs_ab = m; \ +\ +\ + /* Assumptions: m <= mr, n <= nr so that the temporary array ab is + sufficiently large enough to hold the m x n microtile. + + The ability to handle m < mr and n < nr is being provided so that + optimized ukernels can call one of these reference implementations + for their edge cases, if they choose. When they do so, they will + need to call the function directly, by its configuration-mangled + name, since it will have been overwritten in the context when + the optimized ukernel functions are registered. */ \ +\ +\ + /* Initialize the accumulator elements in ab to zero. */ \ + for ( dim_t i = 0; i < mn; ++i ) \ + { \ + PASTEMAC(ch,set0s)( ab[i] ); \ + } \ +\ + /* Perform a series of k rank-1 updates into ab. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + /* General storage: doesn't matter how we traverse ab. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + PASTEMAC(ch,dots) \ + ( \ + a[ i*rs_a ], \ + b[ j*cs_b ], \ + ab[ i*rs_ab + j*cs_ab ] \ + ); \ + } \ + } \ +\ + a += cs_a; \ + b += rs_b; \ + } \ +\ + /* Scale the result in ab by alpha. */ \ + for ( dim_t i = 0; i < mn; ++i ) \ + { \ + PASTEMAC(ch,scals)( *alpha, ab[i] ); \ + } \ +\ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c with the + result in ab. Otherwise, scale by beta and accumulate ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + /* General storage: doesn't matter how we traverse ab and c. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + PASTEMAC(ch,adds) \ + ( \ + ab[ i*rs_ab + j*cs_ab ], \ + c[ i*rs_c + j*cs_c ] \ + ) \ + } \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + /* General storage: doesn't matter how we traverse ab and c. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + PASTEMAC(ch,copys) \ + ( \ + ab[ i*rs_ab + j*cs_ab ], \ + c[ i*rs_c + j*cs_c ] \ + ) \ + } \ + } \ + else /* beta != 0 && beta != 1 */ \ + { \ + /* General storage: doesn't matter how we traverse ab and c. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + PASTEMAC(ch,xpbys) \ + ( \ + ab[ i*rs_ab + j*cs_ab ], \ + *beta, \ + c[ i*rs_c + j*cs_c ] \ + ) \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_g, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) + +#endif diff --git a/ref_kernels/3/bli_gemmtrsm_ref.c b/ref_kernels/3/bli_gemmtrsm_ref.c index 2b756963e4..30fc3fcd6d 100644 --- a/ref_kernels/3/bli_gemmtrsm_ref.c +++ b/ref_kernels/3/bli_gemmtrsm_ref.c @@ -39,6 +39,8 @@ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a1x, \ @@ -51,6 +53,9 @@ void PASTEMAC3(ch,opname,arch,suf) \ ) \ { \ const num_t dt = PASTEMAC(ch,type); \ +\ + const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ \ const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ \ @@ -63,11 +68,35 @@ void PASTEMAC3(ch,opname,arch,suf) \ gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ PASTECH(ch,trsm_ukr_ft) \ trsm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, trsmkerid, cntx ); \ +\ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + /* FGVZ: Should we be querying the preference of BLIS_GEMMTRSM_?_UKR + instead? */ \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : nr ); \ + const inc_t cs_ct = ( col_pref ? mr : 1 ); \ +\ + const bool use_ct = ( m < mr || n < nr ); \ +\ + ctype* restrict c11_use = c11; \ + inc_t rs_c_use = rs_c; \ + inc_t cs_c_use = cs_c; \ +\ + if ( use_ct ) \ + { \ + c11_use = ct; \ + rs_c_use = rs_ct; \ + cs_c_use = cs_ct; \ + } \ \ /* lower: b11 = alpha * b11 - a10 * b01; */ \ /* upper: b11 = alpha * b11 - a12 * b21; */ \ gemm_ukr \ ( \ + m, \ + n, \ k, \ minus_one, \ a1x, \ @@ -84,10 +113,20 @@ void PASTEMAC3(ch,opname,arch,suf) \ ( \ a11, \ b11, \ - c11, rs_c, cs_c, \ + c11_use, rs_c_use, cs_c_use, \ data, \ cntx \ ); \ +\ + if ( use_ct ) \ + { \ + PASTEMAC(ch,copys_mxn) \ + ( \ + m, n, \ + ct, rs_ct, cs_ct, \ + c11, rs_c, cs_c \ + ); \ + } \ \ /* PASTEMAC(d,fprintm)( stdout, "gemmtrsm_ukr: b0111p_r after", k+3, 8, \ diff --git a/ref_kernels/3/bli_trsm_ref.c b/ref_kernels/3/bli_trsm_ref.c index 0cfa74c468..786f1129d0 100644 --- a/ref_kernels/3/bli_trsm_ref.c +++ b/ref_kernels/3/bli_trsm_ref.c @@ -39,126 +39,7 @@ // An implementation that attempts to facilitate emission of vectorized // instructions via constant loop bounds + #pragma omp simd directives. -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname, arch, suf, mr, nr ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t rs_a = 1; \ - const inc_t cs_a = mr; \ -\ - const inc_t rs_b = nr; \ - const inc_t cs_b = 1; \ -\ - PRAGMA_SIMD \ - for ( dim_t i = 0; i < mr; ++i ) \ - { \ - /* b1 = b1 - a10t * B0; */ \ - /* b1 = b1 / alpha11; */ \ - for ( dim_t j = 0; j < nr; ++j ) \ - { \ - ctype beta11c = b[i*rs_b + j*cs_b]; \ - ctype rho11; \ -\ - /* beta11 = beta11 - a10t * b01; */ \ - PASTEMAC(ch,set0s)( rho11 ); \ - for ( dim_t l = 0; l < i; ++l ) \ - { \ - PASTEMAC(ch,axpys)( a[i*rs_a + l*cs_a], \ - b[l*rs_b + j*cs_b], rho11 ); \ - } \ - PASTEMAC(ch,subs)( rho11, beta11c ); \ -\ - /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scals)( a[i*rs_a + i*cs_a], beta11c ); \ -\ - /* Output final result to matrix c. */ \ - PASTEMAC(ch,copys)( beta11c, c[i*rs_c + j*cs_c] ); \ -\ - /* Store the local value back to b11. */ \ - PASTEMAC(ch,copys)( beta11c, b[i*rs_b + j*cs_b] ); \ - } \ - } \ -} - -//INSERT_GENTFUNC_BASIC2( trsm_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) -GENTFUNC( float, s, trsm_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 16 ) -GENTFUNC( double, d, trsm_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 8 ) -GENTFUNC( scomplex, c, trsm_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 8 ) -GENTFUNC( dcomplex, z, trsm_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 4 ) - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname, arch, suf, mr, nr ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t rs_a = 1; \ - const inc_t cs_a = mr; \ -\ - const inc_t rs_b = nr; \ - const inc_t cs_b = 1; \ -\ - PRAGMA_SIMD \ - for ( dim_t iter = 0; iter < mr; ++iter ) \ - { \ - dim_t i = mr - iter - 1; \ -\ - /* b1 = b1 - a12t * B2; */ \ - /* b1 = b1 / alpha11; */ \ - for ( dim_t j = 0; j < nr; ++j ) \ - { \ - ctype beta11c = b[i*rs_b + j*cs_b]; \ - ctype rho11; \ -\ - /* beta11 = beta11 - a12t * b21; */ \ - PASTEMAC(ch,set0s)( rho11 ); \ - for ( dim_t l = 0; l < iter; ++l ) \ - { \ - PASTEMAC(ch,axpys)( a[i*rs_a + (i+1+l)*cs_a], \ - b[(i+1+l)*rs_b + j*cs_b], rho11 ); \ - } \ - PASTEMAC(ch,subs)( rho11, beta11c ); \ -\ - /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scals)( a[i*rs_a + i*cs_a], beta11c ); \ -\ - /* Output final result to matrix c. */ \ - PASTEMAC(ch,copys)( beta11c, c[i*rs_c + j*cs_c] ); \ -\ - /* Store the local value back to b11. */ \ - PASTEMAC(ch,copys)( beta11c, b[i*rs_b + j*cs_b] ); \ - } \ - } \ -} - -//INSERT_GENTFUNC_BASIC2( trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) -GENTFUNC( float, s, trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 16 ) -GENTFUNC( double, d, trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 8 ) -GENTFUNC( scomplex, c, trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 8 ) -GENTFUNC( dcomplex, z, trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 4 ) +// (Deleted. See 'old' directory.) #else @@ -166,7 +47,7 @@ GENTFUNC( dcomplex, z, trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 4 ) // and makes no use of #pragma omp simd. #undef GENTFUNC -#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +#define GENTFUNC( ctype, ch, opname, arch, suf, diagop ) \ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ @@ -229,11 +110,11 @@ void PASTEMAC3(ch,opname,arch,suf) \ PASTEMAC(ch,subs)( rho11, beta11c ); \ \ /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scals)( *alpha11, beta11c ); \ + /* NOTE: When preinversion is enabled, the INVERSE of alpha11 + (1.0/alpha11) is stored during packing instead alpha11 so we + can multiply rather than divide. When preinversion is disabled, + alpha11 is stored and division happens below explicitly. */ \ + PASTEMAC(ch,diagop)( *alpha11, beta11c ); \ \ /* Output final result to matrix c. */ \ PASTEMAC(ch,copys)( beta11c, *gamma11 ); \ @@ -244,11 +125,15 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ } -INSERT_GENTFUNC_BASIC2( trsm_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION +INSERT_GENTFUNC_BASIC3( trsm_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, scals ) +#else +INSERT_GENTFUNC_BASIC3( trsm_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, invscals ) +#endif #undef GENTFUNC -#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +#define GENTFUNC( ctype, ch, opname, arch, suf, diagop ) \ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ @@ -311,11 +196,11 @@ void PASTEMAC3(ch,opname,arch,suf) \ PASTEMAC(ch,subs)( rho11, beta11c ); \ \ /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scals)( *alpha11, beta11c ); \ + /* NOTE: When preinversion is enabled, the INVERSE of alpha11 + (1.0/alpha11) is stored during packing instead alpha11 so we + can multiply rather than divide. When preinversion is disabled, + alpha11 is stored and division happens below explicitly. */ \ + PASTEMAC(ch,diagop)( *alpha11, beta11c ); \ \ /* Output final result to matrix c. */ \ PASTEMAC(ch,copys)( beta11c, *gamma11 ); \ @@ -326,6 +211,10 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ } -INSERT_GENTFUNC_BASIC2( trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION +INSERT_GENTFUNC_BASIC3( trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, scals ) +#else +INSERT_GENTFUNC_BASIC3( trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, invscals ) +#endif #endif diff --git a/ref_kernels/3/old/bli_trsm_simd_ref.c b/ref_kernels/3/old/bli_trsm_simd_ref.c new file mode 100644 index 0000000000..e656df96cc --- /dev/null +++ b/ref_kernels/3/old/bli_trsm_simd_ref.c @@ -0,0 +1,165 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#if 1 + +// An implementation that attempts to facilitate emission of vectorized +// instructions via constant loop bounds + #pragma omp simd directives. + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf, mr, nr ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + ctype* restrict a, \ + ctype* restrict b, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = mr; \ +\ + const inc_t rs_b = nr; \ + const inc_t cs_b = 1; \ +\ + PRAGMA_SIMD \ + for ( dim_t i = 0; i < mr; ++i ) \ + { \ + /* b1 = b1 - a10t * B0; */ \ + /* b1 = b1 / alpha11; */ \ + for ( dim_t j = 0; j < nr; ++j ) \ + { \ + ctype beta11c = b[i*rs_b + j*cs_b]; \ + ctype rho11; \ +\ + /* beta11 = beta11 - a10t * b01; */ \ + PASTEMAC(ch,set0s)( rho11 ); \ + for ( dim_t l = 0; l < i; ++l ) \ + { \ + PASTEMAC(ch,axpys)( a[i*rs_a + l*cs_a], \ + b[l*rs_b + j*cs_b], rho11 ); \ + } \ + PASTEMAC(ch,subs)( rho11, beta11c ); \ +\ + /* beta11 = beta11 / alpha11; */ \ + /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead + of alpha11, so we can multiply rather than divide. We store + the inverse of alpha11 intentionally to avoid expensive + division instructions within the micro-kernel. */ \ + PASTEMAC(ch,scals)( a[i*rs_a + i*cs_a], beta11c ); \ +\ + /* Output final result to matrix c. */ \ + PASTEMAC(ch,copys)( beta11c, c[i*rs_c + j*cs_c] ); \ +\ + /* Store the local value back to b11. */ \ + PASTEMAC(ch,copys)( beta11c, b[i*rs_b + j*cs_b] ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC2( trsm_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) +GENTFUNC( float, s, trsm_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 16 ) +GENTFUNC( double, d, trsm_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 8 ) +GENTFUNC( scomplex, c, trsm_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 8 ) +GENTFUNC( dcomplex, z, trsm_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 4 ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf, mr, nr ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + ctype* restrict a, \ + ctype* restrict b, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = mr; \ +\ + const inc_t rs_b = nr; \ + const inc_t cs_b = 1; \ +\ + PRAGMA_SIMD \ + for ( dim_t iter = 0; iter < mr; ++iter ) \ + { \ + dim_t i = mr - iter - 1; \ +\ + /* b1 = b1 - a12t * B2; */ \ + /* b1 = b1 / alpha11; */ \ + for ( dim_t j = 0; j < nr; ++j ) \ + { \ + ctype beta11c = b[i*rs_b + j*cs_b]; \ + ctype rho11; \ +\ + /* beta11 = beta11 - a12t * b21; */ \ + PASTEMAC(ch,set0s)( rho11 ); \ + for ( dim_t l = 0; l < iter; ++l ) \ + { \ + PASTEMAC(ch,axpys)( a[i*rs_a + (i+1+l)*cs_a], \ + b[(i+1+l)*rs_b + j*cs_b], rho11 ); \ + } \ + PASTEMAC(ch,subs)( rho11, beta11c ); \ +\ + /* beta11 = beta11 / alpha11; */ \ + /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead + of alpha11, so we can multiply rather than divide. We store + the inverse of alpha11 intentionally to avoid expensive + division instructions within the micro-kernel. */ \ + PASTEMAC(ch,scals)( a[i*rs_a + i*cs_a], beta11c ); \ +\ + /* Output final result to matrix c. */ \ + PASTEMAC(ch,copys)( beta11c, c[i*rs_c + j*cs_c] ); \ +\ + /* Store the local value back to b11. */ \ + PASTEMAC(ch,copys)( beta11c, b[i*rs_b + j*cs_b] ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC2( trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) +GENTFUNC( float, s, trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 16 ) +GENTFUNC( double, d, trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 8 ) +GENTFUNC( scomplex, c, trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 8 ) +GENTFUNC( dcomplex, z, trsm_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 4 ) + +#else + +#endif diff --git a/ref_kernels/bli_cntx_ref.c b/ref_kernels/bli_cntx_ref.c index 7b5aa43ef8..33e74ecaa8 100644 --- a/ref_kernels/bli_cntx_ref.c +++ b/ref_kernels/bli_cntx_ref.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -47,7 +47,7 @@ // -- Level-3 native micro-kernel prototype redefinitions ---------------------- -// -- prototypes for completely generic level-3 microkernels -- +// -- Prototypes for completely generic level-3 microkernels -- #undef gemm_ukr_name #define gemm_ukr_name GENARNAME(gemm) @@ -66,46 +66,7 @@ // -- Level-3 virtual micro-kernel prototype redefinitions --------------------- -// -- 3mh -- - -#undef gemm3mh_ukr_name -#define gemm3mh_ukr_name GENARNAME(gemm3mh) - -// -- 3m1 -- - -#undef gemm3m1_ukr_name -#define gemm3m1_ukr_name GENARNAME(gemm3m1) -#undef gemmtrsm3m1_l_ukr_name -#define gemmtrsm3m1_l_ukr_name GENARNAME(gemmtrsm3m1_l) -#undef gemmtrsm3m1_u_ukr_name -#define gemmtrsm3m1_u_ukr_name GENARNAME(gemmtrsm3m1_u) -#undef trsm3m1_l_ukr_name -#define trsm3m1_l_ukr_name GENARNAME(trsm3m1_l) -#undef trsm3m1_u_ukr_name -#define trsm3m1_u_ukr_name GENARNAME(trsm3m1_u) - -// -- 4mh -- - -#undef gemm4mh_ukr_name -#define gemm4mh_ukr_name GENARNAME(gemm4mh) - -// -- 4mb -- - -#undef gemm4mb_ukr_name -#define gemm4mb_ukr_name GENARNAME(gemm4mb) - -// -- 4m1 -- - -#undef gemm4m1_ukr_name -#define gemm4m1_ukr_name GENARNAME(gemm4m1) -#undef gemmtrsm4m1_l_ukr_name -#define gemmtrsm4m1_l_ukr_name GENARNAME(gemmtrsm4m1_l) -#undef gemmtrsm4m1_u_ukr_name -#define gemmtrsm4m1_u_ukr_name GENARNAME(gemmtrsm4m1_u) -#undef trsm4m1_l_ukr_name -#define trsm4m1_l_ukr_name GENARNAME(trsm4m1_l) -#undef trsm4m1_u_ukr_name -#define trsm4m1_u_ukr_name GENARNAME(trsm4m1_u) +// -- Prototypes for induced method level-3 microkernels -- // -- 1m -- @@ -124,6 +85,26 @@ // template. #include "bli_l3_ind_ukr.h" +// -- Level-3 small/unpacked micro-kernel prototype definitions ---------------- + +// NOTE: This results in redundant prototypes for gemmsup_r and gemmsup_c +// kernels, but since they will be identical the compiler won't complain. + +#undef gemmsup_rv_ukr_name +#define gemmsup_rv_ukr_name GENARNAME(gemmsup_r) +#undef gemmsup_rg_ukr_name +#define gemmsup_rg_ukr_name GENARNAME(gemmsup_r) +#undef gemmsup_cv_ukr_name +#define gemmsup_cv_ukr_name GENARNAME(gemmsup_c) +#undef gemmsup_cg_ukr_name +#define gemmsup_cg_ukr_name GENARNAME(gemmsup_c) + +#undef gemmsup_gx_ukr_name +#define gemmsup_gx_ukr_name GENARNAME(gemmsup_g) + +// Include the small/unpacked kernel API template. +#include "bli_l3_sup_ker.h" + // -- Level-1m (packm/unpackm) kernel prototype redefinitions ------------------ #undef packm_2xk_ker_name @@ -164,59 +145,6 @@ #undef unpackm_16xk_ker_name #define unpackm_16xk_ker_name GENARNAME(unpackm_16xk) -#undef packm_2xk_3mis_ker_name -#define packm_2xk_3mis_ker_name GENARNAME(packm_2xk_3mis) -#undef packm_4xk_3mis_ker_name -#define packm_4xk_3mis_ker_name GENARNAME(packm_4xk_3mis) -#undef packm_6xk_3mis_ker_name -#define packm_6xk_3mis_ker_name GENARNAME(packm_6xk_3mis) -#undef packm_8xk_3mis_ker_name -#define packm_8xk_3mis_ker_name GENARNAME(packm_8xk_3mis) -#undef packm_10xk_3mis_ker_name -#define packm_10xk_3mis_ker_name GENARNAME(packm_10xk_3mis) -#undef packm_12xk_3mis_ker_name -#define packm_12xk_3mis_ker_name GENARNAME(packm_12xk_3mis) -#undef packm_14xk_3mis_ker_name -#define packm_14xk_3mis_ker_name GENARNAME(packm_14xk_3mis) -#undef packm_16xk_3mis_ker_name -#define packm_16xk_3mis_ker_name GENARNAME(packm_16xk_3mis) - -#undef packm_2xk_4mi_ker_name -#define packm_2xk_4mi_ker_name GENARNAME(packm_2xk_4mi) -#undef packm_3xk_4mi_ker_name -#define packm_3xk_4mi_ker_name GENARNAME(packm_3xk_4mi) -#undef packm_4xk_4mi_ker_name -#define packm_4xk_4mi_ker_name GENARNAME(packm_4xk_4mi) -#undef packm_6xk_4mi_ker_name -#define packm_6xk_4mi_ker_name GENARNAME(packm_6xk_4mi) -#undef packm_8xk_4mi_ker_name -#define packm_8xk_4mi_ker_name GENARNAME(packm_8xk_4mi) -#undef packm_10xk_4mi_ker_name -#define packm_10xk_4mi_ker_name GENARNAME(packm_10xk_4mi) -#undef packm_12xk_4mi_ker_name -#define packm_12xk_4mi_ker_name GENARNAME(packm_12xk_4mi) -#undef packm_14xk_4mi_ker_name -#define packm_14xk_4mi_ker_name GENARNAME(packm_14xk_4mi) -#undef packm_16xk_4mi_ker_name -#define packm_16xk_4mi_ker_name GENARNAME(packm_16xk_4mi) - -#undef packm_2xk_rih_ker_name -#define packm_2xk_rih_ker_name GENARNAME(packm_2xk_rih) -#undef packm_4xk_rih_ker_name -#define packm_4xk_rih_ker_name GENARNAME(packm_4xk_rih) -#undef packm_6xk_rih_ker_name -#define packm_6xk_rih_ker_name GENARNAME(packm_6xk_rih) -#undef packm_8xk_rih_ker_name -#define packm_8xk_rih_ker_name GENARNAME(packm_8xk_rih) -#undef packm_10xk_rih_ker_name -#define packm_10xk_rih_ker_name GENARNAME(packm_10xk_rih) -#undef packm_12xk_rih_ker_name -#define packm_12xk_rih_ker_name GENARNAME(packm_12xk_rih) -#undef packm_14xk_rih_ker_name -#define packm_14xk_rih_ker_name GENARNAME(packm_14xk_rih) -#undef packm_16xk_rih_ker_name -#define packm_16xk_rih_ker_name GENARNAME(packm_16xk_rih) - #undef packm_2xk_1er_ker_name #define packm_2xk_1er_ker_name GENARNAME(packm_2xk_1er) #undef packm_4xk_1er_ker_name @@ -295,16 +223,33 @@ // -- Macros to help concisely instantiate bli_func_init() --------------------- #define gen_func_init_co( func_p, opname ) \ -\ +{ \ bli_func_init( func_p, NULL, NULL, \ - PASTEMAC(c,opname), PASTEMAC(z,opname) ) + PASTEMAC(c,opname), PASTEMAC(z,opname) ); \ +} #define gen_func_init( func_p, opname ) \ -\ +{ \ bli_func_init( func_p, PASTEMAC(s,opname), PASTEMAC(d,opname), \ - PASTEMAC(c,opname), PASTEMAC(z,opname) ) + PASTEMAC(c,opname), PASTEMAC(z,opname) ); \ +} +#define gen_sup_func_init( func0_p, func1_p, opname ) \ +{ \ + bli_func_init( func0_p, PASTEMAC(s,opname), PASTEMAC(d,opname), \ + PASTEMAC(c,opname), PASTEMAC(z,opname) ); \ + bli_func_init( func1_p, PASTEMAC(s,opname), PASTEMAC(d,opname), \ + PASTEMAC(c,opname), PASTEMAC(z,opname) ); \ +} +// -- Helper function for 1m --------------------------------------------------- + +void GENBAINAME(cntx_init_blkszs) + ( + ind_t method, + num_t dt, + cntx_t* cntx + ); // ----------------------------------------------------------------------------- @@ -314,9 +259,11 @@ void GENBARNAME(cntx_init) ) { blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; func_t* funcs; mbool_t* mbools; dim_t i; + void** vfuncs; // -- Clear the context ---------------------------------------------------- @@ -363,6 +310,11 @@ void GENBARNAME(cntx_init) funcs = bli_cntx_l3_vir_ukrs_buf( cntx ); + // NOTE: We set the virtual micro-kernel slots to contain the addresses + // of the native micro-kernels. In general, the ukernels in the virtual + // ukernel slots are always called, and if the function called happens to + // be a virtual micro-kernel, it will then know to find its native ukernel + // (i.e., in the native ukernel slots). gen_func_init( &funcs[ BLIS_GEMM_UKR ], gemm_ukr_name ); gen_func_init( &funcs[ BLIS_GEMMTRSM_L_UKR ], gemmtrsm_l_ukr_name ); gen_func_init( &funcs[ BLIS_GEMMTRSM_U_UKR ], gemmtrsm_u_ukr_name ); @@ -381,13 +333,104 @@ void GENBARNAME(cntx_init) gen_func_init( &funcs[ BLIS_TRSM_L_UKR ], trsm_l_ukr_name ); gen_func_init( &funcs[ BLIS_TRSM_U_UKR ], trsm_u_ukr_name ); - bli_mbool_init( &mbools[ BLIS_GEMM_UKR ], TRUE, TRUE, TRUE, TRUE ); + // s d c z + bli_mbool_init( &mbools[ BLIS_GEMM_UKR ], TRUE, TRUE, TRUE, TRUE ); bli_mbool_init( &mbools[ BLIS_GEMMTRSM_L_UKR ], FALSE, FALSE, FALSE, FALSE ); bli_mbool_init( &mbools[ BLIS_GEMMTRSM_U_UKR ], FALSE, FALSE, FALSE, FALSE ); bli_mbool_init( &mbools[ BLIS_TRSM_L_UKR ], FALSE, FALSE, FALSE, FALSE ); bli_mbool_init( &mbools[ BLIS_TRSM_U_UKR ], FALSE, FALSE, FALSE, FALSE ); + // -- Set level-3 small/unpacked thresholds -------------------------------- + + // NOTE: The default thresholds are set to zero so that the sup framework + // does not activate by default. Note that the semantic meaning of the + // thresholds is that the sup code path is executed if a dimension is + // strictly less than its corresponding threshold. So actually, the + // thresholds specify the minimum dimension size that will still dispatch + // the non-sup/large code path. This "strictly less than" behavior was + // chosen over "less than or equal to" so that threshold values of 0 would + // effectively disable sup (even for matrix dimensions of 0). + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], 0, 0, 0, 0 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 0, 0, 0, 0 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 0, 0, 0, 0 ); + + // Initialize the context with the default thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + + // -- Set level-3 small/unpacked handlers ---------------------------------- + + vfuncs = bli_cntx_l3_sup_handlers_buf( cntx ); + + // Initialize all of the function pointers to NULL; + for ( i = 0; i < BLIS_NUM_LEVEL3_OPS; ++i ) vfuncs[ i ] = NULL; + + // The level-3 sup handlers are oapi-based, so we only set one slot per + // operation. + + // Set the gemm slot to the default gemm sup handler. + vfuncs[ BLIS_GEMM ] = bli_gemmsup_ref; + vfuncs[ BLIS_GEMMT ] = bli_gemmtsup_ref; + + + // -- Set level-3 small/unpacked micro-kernels and preferences ------------- + + funcs = bli_cntx_l3_sup_kers_buf( cntx ); + mbools = bli_cntx_l3_sup_kers_prefs_buf( cntx ); + +#if 0 + // Adhere to the small/unpacked ukernel mappings: + // - rv -> rrr, rcr + // - rg -> rrc, rcc + // - cv -> ccr, ccc + // - cg -> crr, crc + gen_sup_func_init( &funcs[ BLIS_RRR ], + &funcs[ BLIS_RCR ], gemmsup_rv_ukr_name ); + gen_sup_func_init( &funcs[ BLIS_RRC ], + &funcs[ BLIS_RCC ], gemmsup_rg_ukr_name ); + gen_sup_func_init( &funcs[ BLIS_CCR ], + &funcs[ BLIS_CCC ], gemmsup_cv_ukr_name ); + gen_sup_func_init( &funcs[ BLIS_CRR ], + &funcs[ BLIS_CRC ], gemmsup_cg_ukr_name ); +#endif + gen_func_init( &funcs[ BLIS_RRR ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_RRC ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_RCR ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_RCC ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_CRR ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_CRC ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_CCR ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_CCC ], gemmsup_rv_ukr_name ); + + // Register the general-stride/generic ukernel to the "catch-all" slot + // associated with the BLIS_XXX enum value. This slot will be queried if + // *any* operand is stored with general stride. + gen_func_init( &funcs[ BLIS_XXX ], gemmsup_gx_ukr_name ); + + + // Set the l3 sup ukernel storage preferences. + // s d c z + bli_mbool_init( &mbools[ BLIS_RRR ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_RRC ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_RCR ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_RCC ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_CRR ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_CRC ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_CCR ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_CCC ], TRUE, TRUE, TRUE, TRUE ); + + bli_mbool_init( &mbools[ BLIS_XXX ], TRUE, TRUE, TRUE, TRUE ); + + // -- Set level-1f kernels ------------------------------------------------- funcs = bli_cntx_l1f_kers_buf( cntx ); @@ -461,14 +504,6 @@ void GENBARNAME(cntx_init) // -- Set miscellaneous fields --------------------------------------------- bli_cntx_set_method( BLIS_NAT, cntx ); - - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS, cntx ); - bli_cntx_set_schema_c_panel( BLIS_NOT_PACKED, cntx ); - - //bli_cntx_set_anti_pref( FALSE, cntx ); - - //bli_cntx_set_membrk( bli_membrk_query(), cntx ); } // ----------------------------------------------------------------------------- @@ -476,7 +511,6 @@ void GENBARNAME(cntx_init) void GENBAINAME(cntx_init) ( ind_t method, - num_t dt, cntx_t* cntx ) { @@ -493,41 +527,7 @@ void GENBAINAME(cntx_init) funcs = bli_cntx_l3_vir_ukrs_buf( cntx ); - // 3mh, 4mh, and 4mb do not not support trsm. - bli_func_init_null( &funcs[ BLIS_GEMMTRSM_L_UKR ] ); - bli_func_init_null( &funcs[ BLIS_GEMMTRSM_U_UKR ] ); - bli_func_init_null( &funcs[ BLIS_TRSM_L_UKR ] ); - bli_func_init_null( &funcs[ BLIS_TRSM_U_UKR ] ); - - if ( method == BLIS_3MH ) - { - gen_func_init_co( &funcs[ BLIS_GEMM_UKR ], gemm3mh_ukr_name ); - } - else if ( method == BLIS_3M1 ) - { - gen_func_init_co( &funcs[ BLIS_GEMM_UKR ], gemm3m1_ukr_name ); - gen_func_init_co( &funcs[ BLIS_GEMMTRSM_L_UKR ], gemmtrsm3m1_l_ukr_name ); - gen_func_init_co( &funcs[ BLIS_GEMMTRSM_U_UKR ], gemmtrsm3m1_u_ukr_name ); - gen_func_init_co( &funcs[ BLIS_TRSM_L_UKR ], trsm3m1_l_ukr_name ); - gen_func_init_co( &funcs[ BLIS_TRSM_U_UKR ], trsm3m1_u_ukr_name ); - } - else if ( method == BLIS_4MH ) - { - gen_func_init_co( &funcs[ BLIS_GEMM_UKR ], gemm4mh_ukr_name ); - } - else if ( method == BLIS_4M1B ) - { - gen_func_init_co( &funcs[ BLIS_GEMM_UKR ], gemm4mb_ukr_name ); - } - else if ( method == BLIS_4M1A ) - { - gen_func_init_co( &funcs[ BLIS_GEMM_UKR ], gemm4m1_ukr_name ); - gen_func_init_co( &funcs[ BLIS_GEMMTRSM_L_UKR ], gemmtrsm4m1_l_ukr_name ); - gen_func_init_co( &funcs[ BLIS_GEMMTRSM_U_UKR ], gemmtrsm4m1_u_ukr_name ); - gen_func_init_co( &funcs[ BLIS_TRSM_L_UKR ], trsm4m1_l_ukr_name ); - gen_func_init_co( &funcs[ BLIS_TRSM_U_UKR ], trsm4m1_u_ukr_name ); - } - else if ( method == BLIS_1M ) + if ( method == BLIS_1M ) { gen_func_init_co( &funcs[ BLIS_GEMM_UKR ], gemm1m_ukr_name ); gen_func_init_co( &funcs[ BLIS_GEMMTRSM_L_UKR ], gemmtrsm1m_l_ukr_name ); @@ -546,7 +546,14 @@ void GENBAINAME(cntx_init) // For 1m, we employ an optimization which requires that we copy the native // real domain gemm ukernel function pointers to the corresponding real - // domain slots in the virtual gemm ukernel func_t. + // domain slots in the virtual gemm ukernel func_t. This optimization allows + // us to, under certain conditions, adjust various parameters within the gemm + // macrokernel so that the real-domain macrokernel (which will query and use + // the real-domain virtual gemm ukernel) can be called instead of calling the + // complex-domain macrokernel and the corresponding complex-domain virtual + // microkernel. The non-optimized code path would require an extra level of + // function call overhead, which can be avoided in most cases (i.e., when + // beta has a zero imaginary component and C is either row- or column-stored). if ( method == BLIS_1M ) { func_t* gemm_nat_ukrs = bli_cntx_get_l3_nat_ukrs( BLIS_GEMM_UKR, cntx ); @@ -567,40 +574,7 @@ void GENBAINAME(cntx_init) bli_func_init_null( &funcs[ i ] ); } - if ( method == BLIS_3MH || method == BLIS_4MH ) - { - gen_func_init_co( &funcs[ BLIS_PACKM_2XK_KER ], packm_2xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_4XK_KER ], packm_4xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_6XK_KER ], packm_6xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_8XK_KER ], packm_8xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_10XK_KER ], packm_10xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_12XK_KER ], packm_12xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_14XK_KER ], packm_14xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_16XK_KER ], packm_16xk_rih_ker_name ); - } - else if ( method == BLIS_3M1 ) - { - gen_func_init_co( &funcs[ BLIS_PACKM_2XK_KER ], packm_2xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_4XK_KER ], packm_4xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_6XK_KER ], packm_6xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_8XK_KER ], packm_8xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_10XK_KER ], packm_10xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_12XK_KER ], packm_12xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_14XK_KER ], packm_14xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_16XK_KER ], packm_16xk_3mis_ker_name ); - } - else if ( method == BLIS_4M1A || method == BLIS_4M1B ) - { - gen_func_init_co( &funcs[ BLIS_PACKM_2XK_KER ], packm_2xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_4XK_KER ], packm_4xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_6XK_KER ], packm_6xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_8XK_KER ], packm_8xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_10XK_KER ], packm_10xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_12XK_KER ], packm_12xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_14XK_KER ], packm_14xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_16XK_KER ], packm_16xk_4mi_ker_name ); - } - else if ( method == BLIS_1M ) + if ( method == BLIS_1M ) { gen_func_init_co( &funcs[ BLIS_PACKM_2XK_KER ], packm_2xk_1er_ker_name ); gen_func_init_co( &funcs[ BLIS_PACKM_4XK_KER ], packm_4xk_1er_ker_name ); @@ -630,191 +604,75 @@ void GENBAINAME(cntx_init) // Modify the context with cache and register blocksizes (and multiples) // appropriate for the current induced method. - if ( method == BLIS_3MH ) + if ( method == BLIS_1M ) { - bli_cntx_set_ind_blkszs - ( - method, 6, - BLIS_NC, 1.0, 1.0, - BLIS_KC, 1.0, 1.0, - BLIS_MC, 1.0, 1.0, - BLIS_NR, 1.0, 1.0, - BLIS_MR, 1.0, 1.0, - BLIS_KR, 1.0, 1.0, - cntx - ); + //const bool is_pb = FALSE; + + // Call a helper function to initialize blocksizes for each complex + // datatype. + GENBAINAME(cntx_init_blkszs)( method, BLIS_SCOMPLEX, cntx ); + GENBAINAME(cntx_init_blkszs)( method, BLIS_DCOMPLEX, cntx ); } - else if ( method == BLIS_3M1 ) + else // if ( method == BLIS_NAT ) { - bli_cntx_set_ind_blkszs - ( - method, 6, - BLIS_NC, 1.0, 1.0, - BLIS_KC, 3.0, 3.0, - BLIS_MC, 1.0, 1.0, - BLIS_NR, 1.0, 1.0, - BLIS_MR, 1.0, 1.0, - BLIS_KR, 1.0, 1.0, - cntx - ); + // No change in blocksizes needed for native execution. } - else if ( method == BLIS_4MH ) +} + +// ----------------------------------------------------------------------------- + +void GENBAINAME(cntx_init_blkszs) + ( + ind_t method, + num_t dt, + cntx_t* cntx + ) +{ + // We MUST set the induced method in the context prior to calling + // bli_cntx_l3_vir_ukr_prefers_cols_dt() because that function queries + // the induced method. That function needs the induced method value in + // order to determine whether to evaluate the "prefers column storage" + // predicate using the storage preference of the kernel for dt, or + // the storage preference of the kernel for the real projection of + // dt. Failing to set the induced method here can lead to strange + // undefined behavior at runtime if the native complex kernel's + // storage preference happens to not equal that of the native real + // kernel. + bli_cntx_set_method( method, cntx ); + + // Initialize the blocksizes according to the micro-kernel preference as + // well as the algorithm. + if ( bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ) ) { + // This branch is used for algorithm 1m_c_bp. + bli_cntx_set_ind_blkszs ( - method, 6, + method, dt, 6, BLIS_NC, 1.0, 1.0, - BLIS_KC, 1.0, 1.0, - BLIS_MC, 1.0, 1.0, - BLIS_NR, 1.0, 1.0, - BLIS_MR, 1.0, 1.0, - BLIS_KR, 1.0, 1.0, - cntx - ); - } - else if ( method == BLIS_4M1B ) - { - bli_cntx_set_ind_blkszs - ( - method, 6, - BLIS_NC, 2.0, 2.0, - BLIS_KC, 1.0, 1.0, - BLIS_MC, 2.0, 2.0, + BLIS_KC, 2.0, 2.0, // halve kc... + BLIS_MC, 2.0, 2.0, // halve mc... BLIS_NR, 1.0, 1.0, - BLIS_MR, 1.0, 1.0, + BLIS_MR, 2.0, 1.0, // ...and mr (but NOT packmr) BLIS_KR, 1.0, 1.0, cntx ); } - else if ( method == BLIS_4M1A ) + else // if ( bli_cntx_l3_vir_ukr_prefers_rows_dt( dt, BLIS_GEMM_UKR, cntx ) ) { + // This branch is used for algorithm 1m_r_bp. + bli_cntx_set_ind_blkszs ( - method, 6, - BLIS_NC, 1.0, 1.0, - BLIS_KC, 2.0, 2.0, + method, dt, 6, + BLIS_NC, 2.0, 2.0, // halve nc... + BLIS_KC, 2.0, 2.0, // halve kc... BLIS_MC, 1.0, 1.0, - BLIS_NR, 1.0, 1.0, + BLIS_NR, 2.0, 1.0, // ...and nr (but NOT packnr) BLIS_MR, 1.0, 1.0, BLIS_KR, 1.0, 1.0, cntx ); } - else if ( method == BLIS_1M ) - { - const bool_t is_pb = FALSE; - - // We MUST set the induced method in the context prior to calling - // bli_cntx_l3_ukr_prefers_cols_dt() because that function queries - // the induced method. It needs the induced method value in order - // to determine whether to evaluate the "prefers column storage" - // predicate using the storage preference of the kernel for dt, or - // the storage preference of the kernel for the real projection of - // dt. Failing to set the induced method here can lead to strange - // undefined behavior at runtime if the native complex kernel's - // storage preference happens to not equal that of the native real - // kernel. - bli_cntx_set_method( method, cntx ); - - // Initialize the blocksizes according to the micro-kernel preference as - // well as the algorithm. - if ( bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ) ) - { - // This branch is used for algorithms 1m_c_bp, 1m_r_pb. - - // Set the pack_t schemas for the c_bp or r_pb algorithms. - if ( !is_pb ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_1E, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_1R, cntx ); - } - else // if ( is_pb ) - { - bli_cntx_set_schema_b_panel( BLIS_PACKED_ROW_PANELS_1R, cntx ); - bli_cntx_set_schema_a_block( BLIS_PACKED_COL_PANELS_1E, cntx ); - } - - bli_cntx_set_ind_blkszs - ( - method, 6, - BLIS_NC, 1.0, 1.0, - BLIS_KC, 2.0, 2.0, // halve kc... - BLIS_MC, 2.0, 2.0, // halve mc... - BLIS_NR, 1.0, 1.0, - BLIS_MR, 2.0, 1.0, // ...and mr (but NOT packmr) - BLIS_KR, 1.0, 1.0, - cntx - ); - } - else // if ( bli_cntx_l3_vir_ukr_prefers_rows_dt( dt, BLIS_GEMM_UKR, cntx ) ) - { - // This branch is used for algorithms 1m_r_bp, 1m_c_pb. - - // Set the pack_t schemas for the r_bp or c_pb algorithms. - if ( !is_pb ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_1R, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_1E, cntx ); - } - else // if ( is_pb ) - { - bli_cntx_set_schema_b_panel( BLIS_PACKED_ROW_PANELS_1E, cntx ); - bli_cntx_set_schema_a_block( BLIS_PACKED_COL_PANELS_1R, cntx ); - } - - bli_cntx_set_ind_blkszs - ( - method, 6, - BLIS_NC, 2.0, 2.0, // halve nc... - BLIS_KC, 2.0, 2.0, // halve kc... - BLIS_MC, 1.0, 1.0, - BLIS_NR, 2.0, 1.0, // ...and nr (but NOT packnr) - BLIS_MR, 1.0, 1.0, - BLIS_KR, 1.0, 1.0, - cntx - ); - } - } - else // if ( method == BLIS_NAT ) - { - // No change in blocksizes needed for native execution. - } - - - // -- Set misc. other fields ----------------------------------------------- - - if ( method == BLIS_3MH ) - { - // Schemas vary with _stage(). - } - else if ( method == BLIS_3M1 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_3MI, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_3MI, cntx ); - } - else if ( method == BLIS_4MH ) - { - // Schemas vary with _stage(). - } - else if ( method == BLIS_4M1A || method == BLIS_4M1B ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_4MI, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_4MI, cntx ); - } - else if ( method == BLIS_1M ) - { - //const bool_t is_pb = FALSE; - - // Set the anti-preference field to TRUE when executing a panel-block - // algorithm, and FALSE otherwise. This will cause higher-level generic - // code to establish (if needed) disagreement between the storage of C and - // the micro-kernel output preference so that the two will come back into - // agreement in the panel-block macro-kernel (which implemented in terms - // of the block-panel macro-kernel with some induced transpositions). - //bli_cntx_set_anti_pref( is_pb, cntx ); - } - else // if ( method == BLIS_NAT ) - { - } } diff --git a/ref_kernels/ind/bli_gemm1m_ref.c b/ref_kernels/ind/bli_gemm1m_ref.c index d4fefcc7ce..fbd15d695b 100644 --- a/ref_kernels/ind/bli_gemm1m_ref.c +++ b/ref_kernels/ind/bli_gemm1m_ref.c @@ -39,6 +39,8 @@ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -54,11 +56,14 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ PASTECH(chr,gemm_ukr_ft) \ rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ - const bool_t col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ - const bool_t row_pref = !col_pref; \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ + const bool row_pref = !col_pref; \ \ const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ +\ + const dim_t mr_r = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ + const dim_t nr_r = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ \ const dim_t k2 = 2 * k; \ \ @@ -84,7 +89,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ inc_t rs_c_use; \ inc_t cs_c_use; \ \ - bool_t using_ct; \ + bool using_ct; \ \ /* PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: a", mr, 2*k, \ @@ -118,6 +123,11 @@ void PASTEMAC3(ch,opname,arch,suf) \ else if ( bli_is_gen_stored( rs_c, cs_c ) ) using_ct = TRUE; \ else using_ct = FALSE; \ \ +\ + /* If we are not computing a full micro-tile, then we must write to + ct and then accumulate to c afterwards. */ \ + if ( mr != m || nr != n ) using_ct = TRUE; \ +\ \ if ( using_ct ) \ { \ @@ -149,6 +159,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ /* c = beta * c + alpha_r * a * b; */ \ rgemm_ukr \ ( \ + mr_r, \ + nr_r, \ k2, \ alpha_r, \ a_r, \ @@ -164,8 +176,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ /* Accumulate the final result in ct back to c. */ \ if ( PASTEMAC(ch,eq1)( *beta ) ) \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( j = 0; j < n; ++j ) \ + for ( i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \ *(c + i*rs_c + j*cs_c ) ); \ @@ -173,8 +185,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ else if ( PASTEMAC(ch,eq0)( *beta ) ) \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( j = 0; j < n; ++j ) \ + for ( i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \ *(c + i*rs_c + j*cs_c ) ); \ @@ -182,8 +194,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ else \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( j = 0; j < n; ++j ) \ + for ( i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \ *beta, \ @@ -215,6 +227,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ /* c = beta * c + alpha_r * a * b; */ \ rgemm_ukr \ ( \ + mr_r, \ + nr_r, \ k2, \ alpha_r, \ a_r, \ diff --git a/ref_kernels/ind/bli_gemm3m1_ref.c b/ref_kernels/ind/bli_gemm3m1_ref.c deleted file mode 100644 index a0a935a994..0000000000 --- a/ref_kernels/ind/bli_gemm3m1_ref.c +++ /dev/null @@ -1,336 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict beta, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - ctype_r ab_r[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - ctype_r ab_i[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - ctype_r ab_rpi[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - inc_t rs_ab; \ - inc_t cs_ab; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ - ctype_r* restrict a_rpi = ( ctype_r* )a + 2*is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ - ctype_r* restrict b_rpi = ( ctype_r* )b + 2*is_b; \ -\ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ -\ - ctype_r* restrict alpha_r = &PASTEMAC(ch,real)( *alpha ); \ - ctype_r* restrict alpha_i = &PASTEMAC(ch,imag)( *alpha ); \ -\ - const ctype_r beta_r = PASTEMAC(ch,real)( *beta ); \ - const ctype_r beta_i = PASTEMAC(ch,imag)( *beta ); \ -\ - void* a_next = bli_auxinfo_next_a( data ); \ - void* b_next = bli_auxinfo_next_b( data ); \ -\ - dim_t n_iter; \ - dim_t n_elem; \ -\ - inc_t incc, ldc; \ - inc_t incab, ldab; \ -\ - dim_t i, j; \ -\ -\ - /* SAFETY CHECK: The higher level implementation should never - allow an alpha with non-zero imaginary component to be passed - in, because it can't be applied properly using the 3m method. - If alpha is not real, then something is very wrong. */ \ - if ( !PASTEMAC(chr,eq0)( *alpha_i ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ -\ - /* An optimization: Set local strides and loop bounds based on the - strides of c, so that (a) the micro-kernel accesses ct the same - way it would if it were updating c directly, and (b) c is updated - contiguously. For c with general stride, we access ct the same way - we would as if it were column-stored. */ \ - if ( bli_is_row_stored( rs_c, cs_c ) ) \ - { \ - rs_ab = n; n_iter = m; incc = cs_c; \ - cs_ab = 1; n_elem = n; ldc = rs_c; \ - } \ - else /* column-stored or general stride */ \ - { \ - rs_ab = 1; n_iter = n; incc = rs_c; \ - cs_ab = m; n_elem = m; ldc = cs_c; \ - } \ - incab = 1; \ - ldab = n_elem; \ -\ -\ - /* The following gemm micro-kernel calls implement all "phases" of the - 3m method: - - c = beta * c; - c_r += + a_r * b_r - a_i * b_i; - c_i += (a_r + a_i)(b_r + b_i) - a_r * b_r - a_i * b_i; - - NOTE: Scaling by alpha_r is not shown above, but is implemented - below. */ \ -\ -\ - bli_auxinfo_set_next_ab( a_i, b_i, data ); \ -\ - /* ab_r = alpha_r * a_r * b_r; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_r, \ - b_r, \ - zero_r, \ - ab_r, rs_ab, cs_ab, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_rpi, b_rpi, data ); \ -\ - /* ab_i = alpha_r * a_i * b_i; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_i, \ - b_i, \ - zero_r, \ - ab_i, rs_ab, cs_ab, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_next, b_next, data ); \ -\ - /* ct_i = alpha_r * a_ri * b_ri; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_rpi, \ - b_rpi, \ - zero_r, \ - ab_rpi, rs_ab, cs_ab, \ - data, \ - cntx \ - ); \ -\ -\ - /* How we accumulate the intermediate matrix products stored in ab_r, - ab_i, and ab_rpi depends on the value of beta. */ \ - if ( !PASTEMAC(chr,eq0)( beta_i ) ) \ - { \ - /* c = beta * c; - c_r = c_r + ab_r - ab_i; - c_i = c_i + ab_rpi - ab_r - ab_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r alphabeta11_r = *(ab_r + i*incab + j*ldab); \ - const ctype_r alphabeta11_i = *(ab_i + i*incab + j*ldab); \ - const ctype_r alphabeta11_rpi = *(ab_rpi + i*incab + j*ldab); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ - ctype_r gamma11t_r; \ - ctype_r gamma11t_i; \ -\ - PASTEMAC(ch,copyris)( alphabeta11_r, \ - -alphabeta11_r, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,subris)( alphabeta11_i, \ - alphabeta11_i, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(chr,adds)( alphabeta11_rpi, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,xpbyris)( gamma11t_r, \ - gamma11t_i, \ - beta_r, \ - beta_i, \ - *gamma11_r, \ - *gamma11_i ); \ - } \ - } \ - else if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + ab_r - ab_i; - c_i = c_i + ab_rpi - ab_r - ab_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r alphabeta11_r = *(ab_r + i*incab + j*ldab); \ - const ctype_r alphabeta11_i = *(ab_i + i*incab + j*ldab); \ - const ctype_r alphabeta11_rpi = *(ab_rpi + i*incab + j*ldab); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ - ctype_r gamma11t_r; \ - ctype_r gamma11t_i; \ -\ - PASTEMAC(ch,copyris)( alphabeta11_r, \ - -alphabeta11_r, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,subris)( alphabeta11_i, \ - alphabeta11_i, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(chr,adds)( alphabeta11_rpi, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,addris)( gamma11t_r, \ - gamma11t_i, \ - *gamma11_r, \ - *gamma11_i ); \ - } \ - } \ - else if ( !PASTEMAC(chr,eq0)( beta_r ) ) \ - { \ - /* c_r = beta_r * c_r + ab_r - ab_i; - c_i = beta_r * c_i + ab_rpi - ab_r - ab_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r alphabeta11_r = *(ab_r + i*incab + j*ldab); \ - const ctype_r alphabeta11_i = *(ab_i + i*incab + j*ldab); \ - const ctype_r alphabeta11_rpi = *(ab_rpi + i*incab + j*ldab); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ - ctype_r gamma11t_r; \ - ctype_r gamma11t_i; \ -\ - PASTEMAC(ch,copyris)( alphabeta11_r, \ - -alphabeta11_r, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,subris)( alphabeta11_i, \ - alphabeta11_i, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(chr,adds)( alphabeta11_rpi, \ - gamma11t_i ); \ -\ - PASTEMAC(chr,xpbys)( gamma11t_r, beta_r, *gamma11_r ); \ - PASTEMAC(chr,xpbys)( gamma11t_i, beta_r, *gamma11_i ); \ - } \ - } \ - else /* if ( PASTEMAC(chr,eq0)( beta_r ) ) */ \ - { \ - /* c_r = ab_r - ab_i; - c_i = ab_rpi - ab_r - ab_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r alphabeta11_r = *(ab_r + i*incab + j*ldab); \ - const ctype_r alphabeta11_i = *(ab_i + i*incab + j*ldab); \ - const ctype_r alphabeta11_rpi = *(ab_rpi + i*incab + j*ldab); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ - ctype_r gamma11t_r; \ - ctype_r gamma11t_i; \ -\ - PASTEMAC(ch,copyris)( alphabeta11_r, \ - -alphabeta11_r, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,subris)( alphabeta11_i, \ - alphabeta11_i, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(chr,adds)( alphabeta11_rpi, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,copyris)( gamma11t_r, \ - gamma11t_i, \ - *gamma11_r, \ - *gamma11_i ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC2( gemm3m1, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - diff --git a/ref_kernels/ind/bli_gemm3mh_ref.c b/ref_kernels/ind/bli_gemm3mh_ref.c deleted file mode 100644 index 1f242bc255..0000000000 --- a/ref_kernels/ind/bli_gemm3mh_ref.c +++ /dev/null @@ -1,297 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict beta, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - ctype_r ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - inc_t rs_ct; \ - inc_t cs_ct; \ -\ - ctype_r* restrict a_cast = ( ctype_r* )a; \ -\ - ctype_r* restrict b_cast = ( ctype_r* )b; \ -\ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ -\ - ctype_r* restrict alpha_r = &PASTEMAC(ch,real)( *alpha ); \ - ctype_r* restrict alpha_i = &PASTEMAC(ch,imag)( *alpha ); \ -\ - const ctype_r beta_r = PASTEMAC(ch,real)( *beta ); \ - const ctype_r beta_i = PASTEMAC(ch,imag)( *beta ); \ -\ - const pack_t schema = bli_auxinfo_schema_a( data ); \ -\ - dim_t n_iter; \ - dim_t n_elem; \ -\ - inc_t incc, ldc; \ - inc_t incct, ldct; \ -\ - dim_t i, j; \ -\ -\ - /* SAFETY CHECK: The higher level implementation should never - allow an alpha with non-zero imaginary component to be passed - in, because it can't be applied properly using the 3mh method. - If alpha is not real, then something is very wrong. */ \ - if ( !PASTEMAC(chr,eq0)( *alpha_i ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ -\ - /* An optimization: Set local strides and loop bounds based on the - strides of c, so that (a) the micro-kernel accesses ct the same - way it would if it were updating c directly, and (b) c is updated - contiguously. For c with general stride, we access ct the same way - we would as if it were column-stored. */ \ - if ( bli_is_row_stored( rs_c, cs_c ) ) \ - { \ - rs_ct = n; n_iter = m; incc = cs_c; \ - cs_ct = 1; n_elem = n; ldc = rs_c; \ - } \ - else /* column-stored or general stride */ \ - { \ - rs_ct = 1; n_iter = n; incc = rs_c; \ - cs_ct = m; n_elem = m; ldc = cs_c; \ - } \ - incct = 1; \ - ldct = n_elem; \ -\ -\ - /* The following gemm micro-kernel call implements one "phase" of the - 3m method: - - c = beta * c; - c_r += + a_r * b_r - a_i * b_i; - c_i += (a_r + a_i)(b_r + b_i) - a_r * b_r - a_i * b_i; - - NOTE: Scaling by alpha_r is not shown above, but is implemented - below. */ \ -\ -\ - /* ct = alpha_r * a * b; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_cast, \ - b_cast, \ - zero_r, \ - ct, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ -/* -PASTEMAC(chr,fprintm)( stdout, "gemm3mh_ukr: ct", 4, 4, ct, rs_ct, cs_ct, "%4.1f", "" );*/ \ -\ - /* How we accumulate the intermediate matrix product stored in ct - depends on (a) the schemas of A and B (they are always the same), - and (b) the value of beta. */ \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( !PASTEMAC(chr,eq0)( beta_i ) ) \ - { \ - /* c = beta * c; - c_r = c_r + ct; - c_i = c_i - ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(ch,xpbyris)( gamma11t, \ - -gamma11t, \ - beta_r, \ - beta_i, \ - *gamma11_r, \ - *gamma11_i ); \ - } \ - } \ - else if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + ct; - c_i = c_i - ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t, *gamma11_r ); \ - PASTEMAC(chr,subs)( gamma11t, *gamma11_i ); \ - } \ - } \ - else if ( !PASTEMAC(chr,eq0)( beta_r ) ) \ - { \ - /* c_r = beta_r * c_r + ct; - c_i = beta_r * c_i - ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,xpbys)( gamma11t, beta_r, *gamma11_r ); \ - PASTEMAC(chr,xpbys)( -gamma11t, beta_r, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = ct; - c_i = -ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( gamma11t, *gamma11_r ); \ - PASTEMAC(chr,copys)( -gamma11t, *gamma11_i ); \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r - ct; - c_i = c_i - ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,subs)( gamma11t, *gamma11_r ); \ - PASTEMAC(chr,subs)( gamma11t, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = -ct; - c_i = -ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( -gamma11t, *gamma11_r ); \ - PASTEMAC(chr,copys)( -gamma11t, *gamma11_i ); \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + 0; - c_i = c_i + ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = 0; - c_i = ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,set0s)( *gamma11_r ); \ - PASTEMAC(chr,copys)( gamma11t, *gamma11_i ); \ - } \ - } \ - } \ -\ -/*PASTEMAC(ch,fprintm)( stdout, "gemm3mh_ukr: c", 4, 4, c, rs_c, cs_c, "%4.1f", "" ); \ -*/ \ -\ -/*PASTEMAC(chr,fprintm)( stdout, "gemm3mh_ukr: b1", k, n, b_cast, n, 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm3mh_ukr: a1", m, k, a_cast, 1, m, "%4.1f", "" );*/ \ -} - -INSERT_GENTFUNCCO_BASIC2( gemm3mh, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) diff --git a/ref_kernels/ind/bli_gemm4m1_ref.c b/ref_kernels/ind/bli_gemm4m1_ref.c deleted file mode 100644 index e214985156..0000000000 --- a/ref_kernels/ind/bli_gemm4m1_ref.c +++ /dev/null @@ -1,291 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict beta, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - ctype_r ct_r[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - ctype_r ct_i[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - inc_t rs_ct; \ - inc_t cs_ct; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ -\ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ -\ - ctype_r* restrict alpha_r = &PASTEMAC(ch,real)( *alpha ); \ - ctype_r* restrict alpha_i = &PASTEMAC(ch,imag)( *alpha ); \ -\ - ctype_r m_alpha_r = -(*alpha_r); \ -\ - const ctype_r beta_r = PASTEMAC(ch,real)( *beta ); \ - const ctype_r beta_i = PASTEMAC(ch,imag)( *beta ); \ -\ - void* a_next = bli_auxinfo_next_a( data ); \ - void* b_next = bli_auxinfo_next_b( data ); \ -\ - dim_t n_iter; \ - dim_t n_elem; \ -\ - inc_t incc, ldc; \ - inc_t incct, ldct; \ -\ - dim_t i, j; \ -\ -\ -/* -PASTEMAC(chr,fprintm)( stdout, "gemm4m1_ukr: ap_r", m, k, \ - a_r, 1, PASTEMAC(chr,packmr), "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm4m1_ukr: ap_i", m, k, \ - a_i, 1, PASTEMAC(chr,packmr), "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm4m1_ukr: bp_r", k, n, \ - b_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm4m1_ukr: bp_i", k, n, \ - b_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -\ -\ - /* SAFETY CHECK: The higher level implementation should never - allow an alpha with non-zero imaginary component to be passed - in, because it can't be applied properly using the 4m method. - If alpha is not real, then something is very wrong. */ \ - if ( !PASTEMAC(chr,eq0)( *alpha_i ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ -\ - /* An optimization: Set local strides and loop bounds based on the - strides of c, so that (a) the micro-kernel accesses ct the same - way it would if it were updating c directly, and (b) c is updated - contiguously. For c with general stride, we access ct the same way - we would as if it were column-stored. */ \ - if ( bli_is_row_stored( rs_c, cs_c ) ) \ - { \ - rs_ct = n; n_iter = m; incc = cs_c; \ - cs_ct = 1; n_elem = n; ldc = rs_c; \ - } \ - else /* column-stored or general stride */ \ - { \ - rs_ct = 1; n_iter = n; incc = rs_c; \ - cs_ct = m; n_elem = m; ldc = cs_c; \ - } \ - incct = 1; \ - ldct = n_elem; \ -\ -\ - /* The following gemm micro-kernel calls implement all "phases" of - the 4m method: - - c = beta * c; - c_r += a_r * b_r - a_i * b_i; - c_i += a_r * b_i + a_i * b_r; - - NOTE: Scaling by alpha_r is not shown above, but is implemented - below. */ \ -\ -\ - bli_auxinfo_set_next_ab( a_r, b_i, data ); \ -\ - /* ct_r = alpha_r * a_r * b_r; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_r, \ - b_r, \ - zero_r, \ - ct_r, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_i, b_r, data ); \ -\ - /* ct_i = alpha_r * a_r * b_i; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_r, \ - b_i, \ - zero_r, \ - ct_i, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_i, b_i, data ); \ -\ - /* ct_i += alpha_r * a_i * b_r; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_i, \ - b_r, \ - one_r, \ - ct_i, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_next, b_next, data ); \ -\ - /* ct_r += -alpha_r * a_i * b_i; */ \ - rgemm_ukr \ - ( \ - k, \ - &m_alpha_r, \ - a_i, \ - b_i, \ - one_r, \ - ct_r, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ -\ - /* How we accumulate the intermediate matrix product stored in ct_r - and ct_i depends on the value of beta. */ \ - if ( !PASTEMAC(chr,eq0)( beta_i ) ) \ - { \ - /* c = beta * c + ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(ch,xpbyris)( gamma11t_r, \ - gamma11t_i, \ - beta_r, \ - beta_i, \ - *gamma11_r, \ - *gamma11_i ); \ - } \ - } \ - else if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + ct_r; */ \ - /* c_i = c_i + ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t_r, *gamma11_r ); \ - PASTEMAC(chr,adds)( gamma11t_i, *gamma11_i ); \ - } \ - } \ - else if ( !PASTEMAC(chr,eq0)( beta_r ) ) \ - { \ - /* c_r = beta_r * c_r + ct_r; */ \ - /* c_i = beta_r * c_i + ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,xpbys)( gamma11t_r, beta_r, *gamma11_r ); \ - PASTEMAC(chr,xpbys)( gamma11t_i, beta_r, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = ct_r; */ \ - /* c_i = ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( gamma11t_r, *gamma11_r ); \ - PASTEMAC(chr,copys)( gamma11t_i, *gamma11_i ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC2( gemm4m1, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - diff --git a/ref_kernels/ind/bli_gemm4mb_ref.c b/ref_kernels/ind/bli_gemm4mb_ref.c deleted file mode 100644 index 12a6d46649..0000000000 --- a/ref_kernels/ind/bli_gemm4mb_ref.c +++ /dev/null @@ -1,345 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict beta, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - ctype_r ct_r[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - ctype_r ct_i[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - inc_t rs_ct; \ - inc_t cs_ct; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ -\ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ -\ - ctype_r* restrict alpha_r = &PASTEMAC(ch,real)( *alpha ); \ - ctype_r* restrict alpha_i = &PASTEMAC(ch,imag)( *alpha ); \ -\ - const ctype_r beta_r = PASTEMAC(ch,real)( *beta ); \ - const ctype_r beta_i = PASTEMAC(ch,imag)( *beta ); \ -\ - ctype_r m_alpha_r = -PASTEMAC(ch,real)( *alpha ); \ -\ - const pack_t schema_b = bli_auxinfo_schema_b( data ); \ -\ - void* a_next = bli_auxinfo_next_a( data ); \ - void* b_next = bli_auxinfo_next_b( data ); \ -\ - dim_t n_iter; \ - dim_t n_elem; \ -\ - inc_t incc, ldc; \ - inc_t incct, ldct; \ -\ - dim_t i, j; \ -\ -\ - /* SAFETY CHECK: The higher level implementation should never - allow an alpha with non-zero imaginary component to be passed - in, because it can't be applied properly using the 4mb method. - If alpha is not real, then something is very wrong. */ \ - if ( !PASTEMAC(chr,eq0)( *alpha_i ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ -\ - /* An optimization: Set local strides and loop bounds based on the - strides of c, so that (a) the micro-kernel accesses ct the same - way it would if it were updating c directly, and (b) c is updated - contiguously. For c with general stride, we access ct the same way - we would as if it were column-stored. */ \ - if ( bli_is_row_stored( rs_c, cs_c ) ) \ - { \ - rs_ct = n; n_iter = m; incc = cs_c; \ - cs_ct = 1; n_elem = n; ldc = rs_c; \ - } \ - else /* column-stored or general stride */ \ - { \ - rs_ct = 1; n_iter = n; incc = rs_c; \ - cs_ct = m; n_elem = m; ldc = cs_c; \ - } \ - incct = 1; \ - ldct = n_elem; \ -\ -\ -\ - if ( bli_is_ro_packed( schema_b ) ) \ - { \ - /* The following gemm micro-kernel calls implement the first half of - the 4mb method (which uses b_r): - - c = beta * c; - c_r += a_r * b_r; - c_i += a_i * b_r; - - NOTE: Scaling by alpha_r is not shown above, but is implemented - below. */ \ -\ - bli_auxinfo_set_next_ab( a_i, b_r, data ); \ -\ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_r, \ - b_r, \ - zero_r, \ - ct_r, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_next, b_next, data ); \ -\ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_i, \ - b_r, \ - zero_r, \ - ct_i, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ - } \ - else /* if ( bli_is_io_packed( schema_b ) ) */ \ - { \ - /* The following gemm micro-kernel calls implement the second half of - the 4mb method (which uses b_i): - - c_r += -a_i * b_i; - c_i += a_r * b_i; - - NOTE: Scaling by alpha_r is not shown above, but is implemented - below. */ \ -\ - bli_auxinfo_set_next_ab( a_i, b_i, data ); \ -\ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_r, \ - b_i, \ - zero_r, \ - ct_i, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_next, b_next, data ); \ -\ - rgemm_ukr \ - ( \ - k, \ - &m_alpha_r, \ - a_i, \ - b_i, \ - zero_r, \ - ct_r, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ - } \ -\ -\ -\ - /* How we accumulate the intermediate matrix product stored in ct_r - and ct_i depends on (a) the schema of B, and (b) the value of - beta. */ \ - if ( bli_is_ro_packed( schema_b ) ) \ - { \ - if ( !PASTEMAC(chr,eq0)( beta_i ) ) \ - { \ - /* c = beta * c + ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(ch,xpbyris)( gamma11t_r, \ - gamma11t_i, \ - beta_r, \ - beta_i, \ - *gamma11_r, \ - *gamma11_i ); \ - } \ - } \ - else if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + ct_r; */ \ - /* c_i = c_i + ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t_r, *gamma11_r ); \ - PASTEMAC(chr,adds)( gamma11t_i, *gamma11_i ); \ - } \ - } \ - else if ( !PASTEMAC(chr,eq0)( beta_r ) ) \ - { \ - /* c_r = beta_r * c_r + ct_r; */ \ - /* c_i = beta_r * c_i + ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,xpbys)( gamma11t_r, beta_r, *gamma11_r ); \ - PASTEMAC(chr,xpbys)( gamma11t_i, beta_r, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = ct_r; */ \ - /* c_i = ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( gamma11t_r, *gamma11_r ); \ - PASTEMAC(chr,copys)( gamma11t_i, *gamma11_i ); \ - } \ - } \ - } \ - else /* if ( bli_is_io_packed( schema_b ) ) */ \ - { \ - /* NOTE: If this branch executes, it means we are in the second - half of the 4mb computation in which we multiply the b_i - sub-panel by the entire block of A. Here, we know that beta - will either be equal to one (for interior cases within gemm - macro-kernel), or zero (for edge cases). */ \ -\ - if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + ct_r; */ \ - /* c_i = c_i + ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t_r, *gamma11_r ); \ - PASTEMAC(chr,adds)( gamma11t_i, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = ct_r; */ \ - /* c_i = ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( gamma11t_r, *gamma11_r ); \ - PASTEMAC(chr,copys)( gamma11t_i, *gamma11_i ); \ - } \ - } \ - } \ -\ -/*PASTEMAC(chr,fprintm)( stdout, "gemm4mb_ukr: b1_r", k, n, b_r, n, 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm4mb_ukr: b1_i", k, n, b_i, n, 1, "%4.1f", "" );*/ \ -/*PASTEMAC(chr,fprintm)( stdout, "gemm4mb_ukr: a1_r", m, k, a_r, 1, m, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm4mb_ukr: a1_i", m, k, a_i, 1, m, "%4.1f", "" );*/ \ -/*PASTEMAC(chr,fprintm)( stdout, "gemm4mb_ukr: ct_r", 8, 6, ct_r, rs_ct, cs_ct, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm4mb_ukr: ct_i", 8, 6, ct_i, rs_ct, cs_ct, "%4.1f", "" );*/ \ -} - -INSERT_GENTFUNCCO_BASIC2( gemm4mb, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - diff --git a/ref_kernels/ind/bli_gemm4mh_ref.c b/ref_kernels/ind/bli_gemm4mh_ref.c deleted file mode 100644 index afa76ce761..0000000000 --- a/ref_kernels/ind/bli_gemm4mh_ref.c +++ /dev/null @@ -1,286 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict beta, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - ctype_r ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - inc_t rs_ct; \ - inc_t cs_ct; \ -\ - ctype_r* restrict a_cast = ( ctype_r* )a; \ -\ - ctype_r* restrict b_cast = ( ctype_r* )b; \ -\ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ -\ - ctype_r* restrict alpha_r = &PASTEMAC(ch,real)( *alpha ); \ - ctype_r* restrict alpha_i = &PASTEMAC(ch,imag)( *alpha ); \ -\ - const ctype_r beta_r = PASTEMAC(ch,real)( *beta ); \ - const ctype_r beta_i = PASTEMAC(ch,imag)( *beta ); \ -\ - const pack_t schema_a = bli_auxinfo_schema_a( data ); \ - const pack_t schema_b = bli_auxinfo_schema_b( data ); \ -\ - dim_t n_iter; \ - dim_t n_elem; \ -\ - inc_t incc, ldc; \ - inc_t incct, ldct; \ -\ - dim_t i, j; \ -\ -\ - /* SAFETY CHECK: The higher level implementation should never - allow an alpha with non-zero imaginary component to be passed - in, because it can't be applied properly using the 4mh method. - If alpha is not real, then something is very wrong. */ \ - if ( !PASTEMAC(chr,eq0)( *alpha_i ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ -\ - /* An optimization: Set local strides and loop bounds based on the - strides of c, so that (a) the micro-kernel accesses ct the same - way it would if it were updating c directly, and (b) c is updated - contiguously. For c with general stride, we access ct the same way - we would as if it were column-stored. */ \ - if ( bli_is_row_stored( rs_c, cs_c ) ) \ - { \ - rs_ct = n; n_iter = m; incc = cs_c; \ - cs_ct = 1; n_elem = n; ldc = rs_c; \ - } \ - else /* column-stored or general stride */ \ - { \ - rs_ct = 1; n_iter = n; incc = rs_c; \ - cs_ct = m; n_elem = m; ldc = cs_c; \ - } \ - incct = 1; \ - ldct = n_elem; \ -\ -\ - /* The following gemm micro-kernel call implement one "phase" of the - 4m method: - - c = beta * c; - c_r += a_r * b_r - a_i * b_i; - c_i += a_r * b_i + a_i * b_r; - - NOTE: Scaling by alpha_r is not shown above, but is implemented - below. */ \ -\ -\ - /* ct = alpha_r * a * b; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_cast, \ - b_cast, \ - zero_r, \ - ct, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ -\ - /* How we accumulate the intermediate matrix product stored in ct - depends on (a) the schemas of A and B, and (b) the value of - beta. */ \ - if ( bli_is_ro_packed( schema_a ) && \ - bli_is_ro_packed( schema_b ) ) \ - { \ - if ( !PASTEMAC(chr,eq0)( beta_i ) ) \ - { \ - /* c = beta * c; - c_r = c_r + ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ -\ - PASTEMAC(ch,scals)( *beta, *gamma11 ); \ - PASTEMAC(chr,adds)( gamma11t, *gamma11_r ); \ - } \ - } \ - else if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + ct; - c_i = c_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t, *gamma11_r ); \ - } \ - } \ - else if ( !PASTEMAC(chr,eq0)( beta_r ) ) \ - { \ - /* c_r = beta_r * c_r + ct; - c_i = beta_r * c_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,xpbys)( gamma11t, beta_r, *gamma11_r ); \ - PASTEMAC(chr,scals)( beta_r, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = ct; - c_i = 0; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( gamma11t, *gamma11_r ); \ - PASTEMAC(chr,set0s)( *gamma11_i ); \ - } \ - } \ - } \ - else if ( ( bli_is_ro_packed( schema_a ) && \ - bli_is_io_packed( schema_b ) ) || \ - ( bli_is_io_packed( schema_a ) && \ - bli_is_ro_packed( schema_b ) ) \ - ) \ - { \ - if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + 0; - c_i = c_i + ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = 0; - c_i = ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,set0s)( *gamma11_r ); \ - PASTEMAC(chr,copys)( gamma11t, *gamma11_i ); \ - } \ - } \ - } \ - else /* if ( bli_is_io_packed( schema_a ) && \ - bli_is_io_packed( schema_b ) ) */ \ - { \ - if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r - ct; - c_i = c_i + 0; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ -\ - PASTEMAC(chr,subs)( gamma11t, *gamma11_r ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = -ct; - c_i = 0; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( -gamma11t, *gamma11_r ); \ - PASTEMAC(chr,set0s)( *gamma11_i ); \ - } \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC2( gemm4mh, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) diff --git a/ref_kernels/ind/bli_gemmtrsm1m_ref.c b/ref_kernels/ind/bli_gemmtrsm1m_ref.c index 4966d9196a..08823f0736 100644 --- a/ref_kernels/ind/bli_gemmtrsm1m_ref.c +++ b/ref_kernels/ind/bli_gemmtrsm1m_ref.c @@ -39,6 +39,8 @@ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a1x, \ @@ -59,7 +61,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ PASTECH(ch,trsm_ukr_ft) \ ctrsm_vir_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, trsmkerid, cntx ); \ \ - const bool_t col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ + const bool col_pref_r = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ \ const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ @@ -78,7 +80,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ const dim_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ \ - const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); \ + const pack_t schema_b = bli_auxinfo_schema_b( data ); \ \ const dim_t k2 = 2 * k; \ \ @@ -98,6 +100,28 @@ void PASTEMAC3(ch,opname,arch,suf) \ ctype_r* b_use; \ inc_t rs_b_use; \ inc_t cs_b_use; \ +\ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + /* FGVZ: Should we be querying the preference of BLIS_GEMMTRSM_?_UKR + instead? */ \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : nr ); \ + const inc_t cs_ct = ( col_pref ? mr : 1 ); \ +\ + const bool use_ct = ( m < mr || n < nr ); \ +\ + ctype* restrict c11_use = c11; \ + inc_t rs_c_use = rs_c; \ + inc_t cs_c_use = cs_c; \ +\ + if ( use_ct ) \ + { \ + c11_use = ct; \ + rs_c_use = rs_ct; \ + cs_c_use = cs_ct; \ + } \ \ \ /* Handle alphas with non-zero imaginary components. */ \ @@ -113,7 +137,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ { \ bli_abort(); \ \ -/* + /* ctype_r* restrict one_r = PASTEMAC(chr,1); \ \ const inc_t ld_b = rs_b; \ @@ -125,17 +149,17 @@ void PASTEMAC3(ch,opname,arch,suf) \ b11, rs_b, cs_b, ld_b ); \ \ alpha_r = *one_r; \ -*/ \ + */ \ } \ \ \ { \ /* Set the strides for the temporary bt matrix based on the native real domain micro-kernel storage preferences. */ \ - if ( col_pref ) { rs_bt = 1; cs_bt = mr; \ - rs_bt_r = 1; cs_bt_r = mr_r; } \ - else { rs_bt = nr; cs_bt = 1; \ - rs_bt_r = nr_r; cs_bt_r = 1; } \ + if ( col_pref_r ) { rs_bt = 1; cs_bt = mr; \ + rs_bt_r = 1; cs_bt_r = mr_r; } \ + else { rs_bt = nr; cs_bt = 1; \ + rs_bt_r = nr_r; cs_bt_r = 1; } \ \ b_use = ( ctype_r* )bt; \ rs_b_use = rs_bt_r; \ @@ -153,6 +177,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ upper: bt = -1.0 * a12 * b21; */ \ rgemm_ukr \ ( \ + mr_r, \ + nr_r, \ k2, \ minus_one_r, \ a1x_r, \ @@ -239,10 +265,20 @@ void PASTEMAC3(ch,opname,arch,suf) \ ( \ a11, \ b11, \ - c11, rs_c, cs_c, \ + c11_use, rs_c_use, cs_c_use, \ data, \ cntx \ ); \ +\ + if ( use_ct ) \ + { \ + PASTEMAC(ch,copys_mxn) \ + ( \ + m, n, \ + ct, rs_ct, cs_ct, \ + c11, rs_c, cs_c \ + ); \ + } \ } INSERT_GENTFUNCCO_BASIC3( gemmtrsm1m_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, BLIS_TRSM_L_UKR ) diff --git a/ref_kernels/ind/bli_gemmtrsm3m1_ref.c b/ref_kernels/ind/bli_gemmtrsm3m1_ref.c deleted file mode 100644 index 820a0ec2ba..0000000000 --- a/ref_kernels/ind/bli_gemmtrsm3m1_ref.c +++ /dev/null @@ -1,248 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf, trsmkerid ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a1x, \ - ctype* restrict a11, \ - ctype* restrict bx1, \ - ctype* restrict b11, \ - ctype* restrict c11, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - PASTECH(ch,trsm_ukr_ft) \ - ctrsm_vir_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, trsmkerid, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t packnr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - ctype_r ab_r[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - ctype_r ab_i[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const inc_t rs_ab = 1; \ - const inc_t cs_ab = mr; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a1x_r = ( ctype_r* )a1x; \ - ctype_r* restrict a1x_i = ( ctype_r* )a1x + is_a; \ - ctype_r* restrict a1x_ri = ( ctype_r* )a1x + 2*is_a; \ -\ - ctype_r* restrict bx1_r = ( ctype_r* )bx1; \ - ctype_r* restrict bx1_i = ( ctype_r* )bx1 + is_b; \ - ctype_r* restrict bx1_ri = ( ctype_r* )bx1 + 2*is_b; \ -\ - ctype_r* restrict b11_r = ( ctype_r* )b11; \ - ctype_r* restrict b11_i = ( ctype_r* )b11 + is_b; \ - ctype_r* restrict b11_ri = ( ctype_r* )b11 + 2*is_b; \ -\ - const inc_t rs_b = packnr; \ - const inc_t cs_b = 1; \ -\ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - ctype_r* restrict minus_one_r = PASTEMAC(chr,m1); \ -\ - ctype_r alpha_r = PASTEMAC(ch,real)( *alpha ); \ - ctype_r alpha_i = PASTEMAC(ch,imag)( *alpha ); \ -\ - void* a_next = bli_auxinfo_next_a( data ); \ - void* b_next = bli_auxinfo_next_b( data ); \ -\ - dim_t i, j; \ -\ -\ - /* Copy the contents of c to a temporary buffer ct. */ \ - if ( !PASTEMAC(chr,eq0)( alpha_i ) ) \ - { \ - /* We can handle a non-zero imaginary component on alpha, but to do - so we have to manually scale b and then use alpha == 1 for the - micro-kernel calls. */ \ - for ( i = 0; i < m; ++i ) \ - for ( j = 0; j < n; ++j ) \ - PASTEMAC(ch,scalris)( alpha_r, \ - alpha_i, \ - *(b11_r + i*rs_b + j*cs_b), \ - *(b11_i + i*rs_b + j*cs_b) ); \ -\ - /* Use alpha.r == 1.0. */ \ - alpha_r = *one_r; \ - } \ -\ -\ - /* lower: - b11.r = alpha.r * b11.r - ( + a10.r * b01.r - a10.i * b01.i ); - b11.i = alpha.r * b11.i - ( a10.ri * b01.ri - a10.r * b01.r - a10.i * b01.i ); - - upper: - b11.r = alpha.r * b11.r - ( + a12.r * b21.r - a12.i * b21.i ); - b11.i = alpha.r * b11.i - ( a12.ri * b21.ri - a12.r * b21.r - a12.i * b21.i ); */ \ -\ - bli_auxinfo_set_next_ab( a1x_i, bx1_i, data ); \ -\ - /* lower: ab.r = a10.r * b01.r; - upper: ab.r = a12.r * b21.r; */ \ - rgemm_ukr \ - ( \ - k, \ - one_r, \ - a1x_r, \ - bx1_r, \ - zero_r, \ - ab_r, rs_ab, cs_ab, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a1x_ri, bx1_ri, data ); \ -\ - /* lower: ab.i = a10.i * b01.i; - upper: ab.i = a12.i * b21.i; */ \ - rgemm_ukr \ - ( \ - k, \ - one_r, \ - a1x_i, \ - bx1_i, \ - zero_r, \ - ab_i, rs_ab, cs_ab, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_next, b_next, data ); \ -\ - /* lower: b11.i = alpha.r * b11.i - a12.ri * b21.ri; - upper: b11.i = alpha.r * b11.i - a12.ri * b21.ri; */ \ - rgemm_ukr \ - ( \ - k, \ - minus_one_r, \ - a1x_ri, \ - bx1_ri, \ - &alpha_r, \ - b11_i, rs_b, cs_b, \ - data, \ - cntx \ - ); \ -\ -\ - /* b11.r = alpha.r * b11.r - ab.r; - b11.r = b11.r + ab.i; - b11.i = b11.i + ab.r; - b11.i = b11.i + ab.i; */ \ - for ( i = 0; i < m; ++i ) \ - for ( j = 0; j < n; ++j ) \ - { \ - ctype_r alphabeta_r = *(ab_r + i*rs_ab + j*cs_ab); \ - ctype_r alphabeta_i = *(ab_i + i*rs_ab + j*cs_ab); \ - ctype_r beta11_r = *(b11_r + i*rs_b + j*cs_b); \ - ctype_r beta11_i = *(b11_i + i*rs_b + j*cs_b); \ -\ - PASTEMAC(chr,scals)( alpha_r, beta11_r ); \ -\ - PASTEMAC(chr,subs)( alphabeta_r, beta11_r ); \ - PASTEMAC(chr,adds)( alphabeta_i, beta11_r ); \ - PASTEMAC(chr,adds)( alphabeta_r, beta11_i ); \ - PASTEMAC(chr,adds)( alphabeta_i, beta11_i ); \ -\ - /* Store the local values back to b11. */ \ - PASTEMAC(ch,copyris)( beta11_r, \ - beta11_i, \ - *(b11_r + i*rs_b + j*cs_b), \ - *(b11_i + i*rs_b + j*cs_b) ); \ -\ - /* Update the ri part of b11. */ \ - PASTEMAC(chr,add3s)( beta11_r, \ - beta11_i, \ - *(b11_ri + i*rs_b + j*cs_b) ); \ - } \ -\ -\ - /* b11 = inv(a11) * b11; - c11 = b11; */ \ - ctrsm_vir_ukr \ - ( \ - a11, \ - b11, \ - c11, rs_c, cs_c, \ - data, \ - cntx \ - ); \ -\ -\ -/* -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm3m1_l_ukr: b11_r after", m, n, \ - b11_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm3m1_l_ukr: b11_i after", m, n, \ - b11_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -/* -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm3m1_l_ukr: b01_r", k, n, \ - b01_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm3m1_l_ukr: b01_i", k, n, \ - b01_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm3m1_l_ukr: b11_r", m, n, \ - b11_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm3m1_l_ukr: b11_i", m, n, \ - b11_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -} - -INSERT_GENTFUNCCO_BASIC3( gemmtrsm3m1_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, BLIS_TRSM_L_UKR ) -INSERT_GENTFUNCCO_BASIC3( gemmtrsm3m1_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, BLIS_TRSM_U_UKR ) diff --git a/ref_kernels/ind/bli_gemmtrsm4m1_ref.c b/ref_kernels/ind/bli_gemmtrsm4m1_ref.c deleted file mode 100644 index 1b2205c8d7..0000000000 --- a/ref_kernels/ind/bli_gemmtrsm4m1_ref.c +++ /dev/null @@ -1,222 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf, trsmkerid ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a1x, \ - ctype* restrict a11, \ - ctype* restrict bx1, \ - ctype* restrict b11, \ - ctype* restrict c11, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - PASTECH(ch,trsm_ukr_ft) \ - ctrsm_vir_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, trsmkerid, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t packnr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a1x_r = ( ctype_r* )a1x; \ - ctype_r* restrict a1x_i = ( ctype_r* )a1x + is_a; \ -\ - ctype_r* restrict bx1_r = ( ctype_r* )bx1; \ - ctype_r* restrict bx1_i = ( ctype_r* )bx1 + is_b; \ -\ - ctype_r* restrict b11_r = ( ctype_r* )b11; \ - ctype_r* restrict b11_i = ( ctype_r* )b11 + is_b; \ -\ - const inc_t rs_b = packnr; \ - const inc_t cs_b = 1; \ -\ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict minus_one_r = PASTEMAC(chr,m1); \ -\ - ctype_r alpha_r = PASTEMAC(ch,real)( *alpha ); \ - ctype_r alpha_i = PASTEMAC(ch,imag)( *alpha ); \ -\ - void* a_next = bli_auxinfo_next_a( data ); \ - void* b_next = bli_auxinfo_next_b( data ); \ -\ - dim_t i, j; \ -\ -/* -printf( "gemmtrsm4m1_l_ukr: is_a = %lu is_b = %lu\n", is_a, is_b ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: a1x11p_r", m, k+m, \ - a1x_r, 1, PASTEMAC(chr,packmr), "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: a1x11p_i", m, k+m, \ - a1x_i, 1, PASTEMAC(chr,packmr), "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: bx111p_r", k+m, n, \ - bx1_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: bx111p_i", k+m, n, \ - bx1_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -\ - /* Copy the contents of c to a temporary buffer ct. */ \ - if ( !PASTEMAC(chr,eq0)( alpha_i ) ) \ - { \ - /* We can handle a non-zero imaginary component on alpha, but to do - so we have to manually scale b and then use alpha == 1 for the - micro-kernel calls. */ \ - for ( i = 0; i < m; ++i ) \ - for ( j = 0; j < n; ++j ) \ - PASTEMAC(ch,scalris)( alpha_r, \ - alpha_i, \ - *(b11_r + i*rs_b + j*cs_b), \ - *(b11_i + i*rs_b + j*cs_b) ); \ -\ - /* Use alpha.r == 1.0. */ \ - alpha_r = *one_r; \ - } \ -\ -\ - /* lower: b11.r = alpha.r * b11.r - ( a10.r * b01.r - a10.i * b01.i ); - b11.i = alpha.r * b11.i - ( a10.r * b01.i + a10.i * b01.r ); - - upper: b11.r = alpha.r * b11.r - ( a12.r * b21.r - a12.i * b21.i ); - b11.i = alpha.r * b11.i - ( a12.r * b21.i + a12.i * b21.r ); */ \ -\ - bli_auxinfo_set_next_ab( a1x_r, bx1_i, data ); \ -\ - /* lower: b11.r = alpha.r * b11.r - a10.r * b01.r; - upper: b11.r = alpha.r * b11.r - a12.r * b21.r; */ \ - rgemm_ukr \ - ( \ - k, \ - minus_one_r, \ - a1x_r, \ - bx1_r, \ - &alpha_r, \ - b11_r, rs_b, cs_b, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a1x_i, bx1_r, data ); \ -\ - /* lower: b11.i = alpha.r * b11.i - a10.r * b01.i; - upper: b11.i = alpha.r * b11.i - a12.r * b21.i; */ \ - rgemm_ukr \ - ( \ - k, \ - minus_one_r, \ - a1x_r, \ - bx1_i, \ - &alpha_r, \ - b11_i, rs_b, cs_b, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a1x_i, bx1_i, data ); \ -\ - /* lower: b11.i = 1.0 * b11.i - a10.i * b01.r; - upper: b11.i = 1.0 * b11.i - a12.i * b21.r; */ \ - rgemm_ukr \ - ( \ - k, \ - minus_one_r, \ - a1x_i, \ - bx1_r, \ - one_r, \ - b11_i, rs_b, cs_b, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_next, b_next, data ); \ -\ - /* lower: b11.r = 1.0 * b11.r + a10.i * b01.i; - upper: b11.r = 1.0 * b11.r + a12.i * b21.i; */ \ - rgemm_ukr \ - ( \ - k, \ - one_r, \ - a1x_i, \ - bx1_i, \ - one_r, \ - b11_r, rs_b, cs_b, \ - data, \ - cntx \ - ); \ -/* -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: bx111p_r post-gemm", k+m, n, \ - bx1_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: bx111p_i post-gemm", k+m, n, \ - bx1_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -\ - /* b11 = inv(a11) * b11; - c11 = b11; */ \ - ctrsm_vir_ukr \ - ( \ - a11, \ - b11, \ - c11, rs_c, cs_c, \ - data, \ - cntx \ - ); \ -\ -/* -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: bx111p_r after", k+m, n, \ - bx1_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: bx111p_i after", k+m, n, \ - bx1_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -} - -INSERT_GENTFUNCCO_BASIC3( gemmtrsm4m1_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, BLIS_TRSM_L_UKR ) -INSERT_GENTFUNCCO_BASIC3( gemmtrsm4m1_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, BLIS_TRSM_U_UKR ) diff --git a/ref_kernels/ind/bli_trsm1m_ref.c b/ref_kernels/ind/bli_trsm1m_ref.c index c130415415..68717f7a6c 100644 --- a/ref_kernels/ind/bli_trsm1m_ref.c +++ b/ref_kernels/ind/bli_trsm1m_ref.c @@ -36,7 +36,7 @@ #undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ +#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf, diagop ) \ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ @@ -67,7 +67,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ const inc_t ld_a = cs_a; \ const inc_t ld_b = rs_b; \ \ - const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); \ + const pack_t schema_b = bli_auxinfo_schema_b( data ); \ \ dim_t iter, i, j, l; \ dim_t n_behind; \ @@ -134,14 +134,14 @@ void PASTEMAC3(ch,opname,arch,suf) \ beta11c_i ); \ \ /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scalris)( *alpha11_r, \ - *alpha11_i, \ - beta11c_r, \ - beta11c_i ); \ + /* NOTE: When preinversion is enabled, the INVERSE of alpha11 + (1.0/alpha11) is stored during packing instead alpha11 so we + can multiply rather than divide. When preinversion is disabled, + alpha11 is stored and division happens below explicitly. */ \ + PASTEMAC(ch,diagop)( *alpha11_r, \ + *alpha11_i, \ + beta11c_r, \ + beta11c_i ); \ \ /* Output final result to matrix c. */ \ PASTEMAC(ch,sets)( beta11c_r, beta11c_i, *gamma11 ); \ @@ -215,14 +215,14 @@ void PASTEMAC3(ch,opname,arch,suf) \ beta11c_i ); \ \ /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scalris)( *alpha11_r, \ - *alpha11_i, \ - beta11c_r, \ - beta11c_i ); \ + /* NOTE: When preinversion is enabled, the INVERSE of alpha11 + (1.0/alpha11) is stored during packing instead alpha11 so we + can multiply rather than divide. When preinversion is disabled, + alpha11 is stored and division happens below explicitly. */ \ + PASTEMAC(ch,diagop)( *alpha11_r, \ + *alpha11_i, \ + beta11c_r, \ + beta11c_i ); \ \ /* Output final result to matrix c. */ \ PASTEMAC(ch,sets)( beta11c_r, \ @@ -238,11 +238,15 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ } -INSERT_GENTFUNCCO_BASIC2( trsm1m_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION +INSERT_GENTFUNCCO_BASIC3( trsm1m_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, scalris ) +#else +INSERT_GENTFUNCCO_BASIC3( trsm1m_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, invscalris ) +#endif #undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ +#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf, diagop ) \ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ @@ -273,7 +277,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ const inc_t ld_a = cs_a; \ const inc_t ld_b = rs_b; \ \ - const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); \ + const pack_t schema_b = bli_auxinfo_schema_b( data ); \ \ dim_t iter, i, j, l; \ dim_t n_behind; \ @@ -340,14 +344,14 @@ void PASTEMAC3(ch,opname,arch,suf) \ beta11c_i ); \ \ /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scalris)( *alpha11_r, \ - *alpha11_i, \ - beta11c_r, \ - beta11c_i ); \ + /* NOTE: When preinversion is enabled, the INVERSE of alpha11 + (1.0/alpha11) is stored during packing instead alpha11 so we + can multiply rather than divide. When preinversion is disabled, + alpha11 is stored and division happens below explicitly. */ \ + PASTEMAC(ch,diagop)( *alpha11_r, \ + *alpha11_i, \ + beta11c_r, \ + beta11c_i ); \ \ /* Output final result to matrix c. */ \ PASTEMAC(ch,sets)( beta11c_r, beta11c_i, *gamma11 ); \ @@ -421,14 +425,14 @@ void PASTEMAC3(ch,opname,arch,suf) \ beta11c_i ); \ \ /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scalris)( *alpha11_r, \ - *alpha11_i, \ - beta11c_r, \ - beta11c_i ); \ + /* NOTE: When preinversion is enabled, the INVERSE of alpha11 + (1.0/alpha11) is stored during packing instead alpha11 so we + can multiply rather than divide. When preinversion is disabled, + alpha11 is stored and division happens below explicitly. */ \ + PASTEMAC(ch,diagop)( *alpha11_r, \ + *alpha11_i, \ + beta11c_r, \ + beta11c_i ); \ \ /* Output final result to matrix c. */ \ PASTEMAC(ch,sets)( beta11c_r, \ @@ -444,4 +448,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ } -INSERT_GENTFUNCCO_BASIC2( trsm1m_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION +INSERT_GENTFUNCCO_BASIC3( trsm1m_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, scalris ) +#else +INSERT_GENTFUNCCO_BASIC3( trsm1m_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, invscalris ) +#endif diff --git a/ref_kernels/ind/bli_trsm3m1_ref.c b/ref_kernels/ind/bli_trsm3m1_ref.c deleted file mode 100644 index c24c2f4e2a..0000000000 --- a/ref_kernels/ind/bli_trsm3m1_ref.c +++ /dev/null @@ -1,283 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const inc_t packmr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_MR, cntx ); \ - const inc_t packnr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ - ctype_r* restrict b_ri = ( ctype_r* )b + 2*is_b; \ -\ - const inc_t rs_a = 1; \ - const inc_t cs_a = packmr; \ -\ - const inc_t rs_b = packnr; \ - const inc_t cs_b = 1; \ -\ - dim_t iter, i, j, l; \ - dim_t n_behind; \ -\ -\ - for ( iter = 0; iter < m; ++iter ) \ - { \ - i = iter; \ - n_behind = i; \ -\ - ctype_r* restrict alpha11_r = a_r + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict alpha11_i = a_i + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict a10t_r = a_r + (i )*rs_a + (0 )*cs_a; \ - ctype_r* restrict a10t_i = a_i + (i )*rs_a + (0 )*cs_a; \ - ctype_r* restrict b1_r = b_r + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict b1_i = b_i + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict b1_ri = b_ri + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict B0_r = b_r + (0 )*rs_b + (0 )*cs_b; \ - ctype_r* restrict B0_i = b_i + (0 )*rs_b + (0 )*cs_b; \ -\ - /* b1 = b1 - a10t * B0; */ \ - /* b1 = b1 / alpha11; */ \ - for ( j = 0; j < n; ++j ) \ - { \ - ctype_r* restrict beta11_r = b1_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict beta11_i = b1_i + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict beta11_ri = b1_ri + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b01_r = B0_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b01_i = B0_i + (0 )*rs_b + (j )*cs_b; \ - ctype* restrict gamma11 = c + (i )*rs_c + (j )*cs_c; \ - ctype_r beta11c_r = *beta11_r; \ - ctype_r beta11c_i = *beta11_i; \ - ctype_r rho11_r; \ - ctype_r rho11_i; \ -\ - /* beta11 = beta11 - a10t * b01; */ \ - PASTEMAC(chr,set0s)( rho11_r ); \ - PASTEMAC(chr,set0s)( rho11_i ); \ - for ( l = 0; l < n_behind; ++l ) \ - { \ - ctype_r* restrict alpha10_r = a10t_r + (l )*cs_a; \ - ctype_r* restrict alpha10_i = a10t_i + (l )*cs_a; \ - ctype_r* restrict beta01_r = b01_r + (l )*rs_b; \ - ctype_r* restrict beta01_i = b01_i + (l )*rs_b; \ -\ - PASTEMAC(ch,axpyris)( *alpha10_r, \ - *alpha10_i, \ - *beta01_r, \ - *beta01_i, \ - rho11_r, \ - rho11_i ); \ - } \ - PASTEMAC(ch,subris)( rho11_r, \ - rho11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scalris)( *alpha11_r, \ - *alpha11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* Output final result to matrix c. */ \ - PASTEMAC(ch,sets)( beta11c_r, \ - beta11c_i, *gamma11 ); \ -\ - /* Store the local values back to b11. */ \ - PASTEMAC(chr,copys)( beta11c_r, *beta11_r ); \ - PASTEMAC(chr,copys)( beta11c_i, *beta11_i ); \ -\ - /* Update the ri part of the packed panel. */ \ - PASTEMAC(chr,add3s)( beta11c_r, \ - beta11c_i, \ - *beta11_ri ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC2( trsm3m1_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const inc_t packmr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_MR, cntx ); \ - const inc_t packnr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ - ctype_r* restrict b_ri = ( ctype_r* )b + 2*is_b; \ -\ - const inc_t rs_a = 1; \ - const inc_t cs_a = packmr; \ -\ - const inc_t rs_b = packnr; \ - const inc_t cs_b = 1; \ -\ - dim_t iter, i, j, l; \ - dim_t n_behind; \ -\ -\ - for ( iter = 0; iter < m; ++iter ) \ - { \ - i = m - iter - 1; \ - n_behind = iter; \ -\ - ctype_r* restrict alpha11_r = a_r + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict alpha11_i = a_i + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict a12t_r = a_r + (i )*rs_a + (i+1)*cs_a; \ - ctype_r* restrict a12t_i = a_i + (i )*rs_a + (i+1)*cs_a; \ - ctype_r* restrict b1_r = b_r + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict b1_i = b_i + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict b1_ri = b_ri + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict B2_r = b_r + (i+1)*rs_b + (0 )*cs_b; \ - ctype_r* restrict B2_i = b_i + (i+1)*rs_b + (0 )*cs_b; \ -\ - /* b1 = b1 - a12t * B2; */ \ - /* b1 = b1 / alpha11; */ \ - for ( j = 0; j < n; ++j ) \ - { \ - ctype_r* restrict beta11_r = b1_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict beta11_i = b1_i + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict beta11_ri = b1_ri + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b21_r = B2_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b21_i = B2_i + (0 )*rs_b + (j )*cs_b; \ - ctype* restrict gamma11 = c + (i )*rs_c + (j )*cs_c; \ - ctype_r beta11c_r = *beta11_r; \ - ctype_r beta11c_i = *beta11_i; \ - ctype_r rho11_r; \ - ctype_r rho11_i; \ -\ - /* beta11 = beta11 - a12t * b21; */ \ - PASTEMAC(chr,set0s)( rho11_r ); \ - PASTEMAC(chr,set0s)( rho11_i ); \ - for ( l = 0; l < n_behind; ++l ) \ - { \ - ctype_r* restrict alpha12_r = a12t_r + (l )*cs_a; \ - ctype_r* restrict alpha12_i = a12t_i + (l )*cs_a; \ - ctype_r* restrict beta21_r = b21_r + (l )*rs_b; \ - ctype_r* restrict beta21_i = b21_i + (l )*rs_b; \ -\ - PASTEMAC(ch,axpyris)( *alpha12_r, \ - *alpha12_i, \ - *beta21_r, \ - *beta21_i, \ - rho11_r, \ - rho11_i ); \ - } \ - PASTEMAC(ch,subris)( rho11_r, \ - rho11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scalris)( *alpha11_r, \ - *alpha11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* Output final result to matrix c. */ \ - PASTEMAC(ch,sets)( beta11c_r, \ - beta11c_i, *gamma11 ); \ -\ - /* Store the local values back to b11. */ \ - PASTEMAC(chr,copys)( beta11c_r, *beta11_r ); \ - PASTEMAC(chr,copys)( beta11c_i, *beta11_i ); \ -\ - /* Update the ri part of the packed panel. */ \ - PASTEMAC(chr,add3s)( beta11c_r, \ - beta11c_i, \ - *beta11_ri ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC2( trsm3m1_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) diff --git a/ref_kernels/ind/bli_trsm4m1_ref.c b/ref_kernels/ind/bli_trsm4m1_ref.c deleted file mode 100644 index 81d203e403..0000000000 --- a/ref_kernels/ind/bli_trsm4m1_ref.c +++ /dev/null @@ -1,284 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const inc_t packmr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_MR, cntx ); \ - const inc_t packnr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ -\ - const inc_t rs_a = 1; \ - const inc_t cs_a = packmr; \ -\ - const inc_t rs_b = packnr; \ - const inc_t cs_b = 1; \ -\ - dim_t iter, i, j, l; \ - dim_t n_behind; \ -\ -/* -PASTEMAC(chr,fprintm)( stdout, "trsm4m1_l_ukr: a11p_r", m, m, \ - a_r, 1, PASTEMAC(chr,packmr), "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "trsm4m1_l_ukr: a11p_i", m, m, \ - a_i, 1, PASTEMAC(chr,packmr), "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "trsm4m1_l_ukr: b11p_r", m, n, \ - b_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "trsm4m1_l_ukr: b11p_i", m, n, \ - b_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -\ - for ( iter = 0; iter < m; ++iter ) \ - { \ - i = iter; \ - n_behind = i; \ -\ - ctype_r* restrict alpha11_r = a_r + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict alpha11_i = a_i + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict a10t_r = a_r + (i )*rs_a + (0 )*cs_a; \ - ctype_r* restrict a10t_i = a_i + (i )*rs_a + (0 )*cs_a; \ - ctype_r* restrict b1_r = b_r + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict b1_i = b_i + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict B0_r = b_r + (0 )*rs_b + (0 )*cs_b; \ - ctype_r* restrict B0_i = b_i + (0 )*rs_b + (0 )*cs_b; \ -\ - /* b1 = b1 - a10t * B0; */ \ - /* b1 = b1 / alpha11; */ \ - for ( j = 0; j < n; ++j ) \ - { \ - ctype_r* restrict beta11_r = b1_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict beta11_i = b1_i + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b01_r = B0_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b01_i = B0_i + (0 )*rs_b + (j )*cs_b; \ - ctype* restrict gamma11 = c + (i )*rs_c + (j )*cs_c; \ - ctype_r beta11c_r = *beta11_r; \ - ctype_r beta11c_i = *beta11_i; \ - ctype_r rho11_r; \ - ctype_r rho11_i; \ -\ - /* beta11 = beta11 - a10t * b01; */ \ - PASTEMAC(chr,set0s)( rho11_r ); \ - PASTEMAC(chr,set0s)( rho11_i ); \ - for ( l = 0; l < n_behind; ++l ) \ - { \ - ctype_r* restrict alpha10_r = a10t_r + (l )*cs_a; \ - ctype_r* restrict alpha10_i = a10t_i + (l )*cs_a; \ - ctype_r* restrict beta01_r = b01_r + (l )*rs_b; \ - ctype_r* restrict beta01_i = b01_i + (l )*rs_b; \ -\ - PASTEMAC(ch,axpyris)( *alpha10_r, \ - *alpha10_i, \ - *beta01_r, \ - *beta01_i, \ - rho11_r, \ - rho11_i ); \ - } \ - PASTEMAC(ch,subris)( rho11_r, \ - rho11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scalris)( *alpha11_r, \ - *alpha11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* Output final result to matrix c. */ \ - PASTEMAC(ch,sets)( beta11c_r, \ - beta11c_i, *gamma11 ); \ -\ - /* Store the local values back to b11. */ \ - PASTEMAC(chr,copys)( beta11c_r, *beta11_r ); \ - PASTEMAC(chr,copys)( beta11c_i, *beta11_i ); \ - } \ - } \ -\ -/* -PASTEMAC(chr,fprintm)( stdout, "trsm4m1_l_ukr: b11p_r after", m, n, \ - b_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "trsm4m1_l_ukr: b11p_i after", m, n, \ - b_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -} - -INSERT_GENTFUNCCO_BASIC2( trsm4m1_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const inc_t packmr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_MR, cntx ); \ - const inc_t packnr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ -\ - const inc_t rs_a = 1; \ - const inc_t cs_a = packmr; \ -\ - const inc_t rs_b = packnr; \ - const inc_t cs_b = 1; \ -\ - dim_t iter, i, j, l; \ - dim_t n_behind; \ -\ -\ - for ( iter = 0; iter < m; ++iter ) \ - { \ - i = m - iter - 1; \ - n_behind = iter; \ -\ - ctype_r* restrict alpha11_r = a_r + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict alpha11_i = a_i + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict a12t_r = a_r + (i )*rs_a + (i+1)*cs_a; \ - ctype_r* restrict a12t_i = a_i + (i )*rs_a + (i+1)*cs_a; \ - ctype_r* restrict b1_r = b_r + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict b1_i = b_i + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict B2_r = b_r + (i+1)*rs_b + (0 )*cs_b; \ - ctype_r* restrict B2_i = b_i + (i+1)*rs_b + (0 )*cs_b; \ -\ - /* b1 = b1 - a12t * B2; */ \ - /* b1 = b1 / alpha11; */ \ - for ( j = 0; j < n; ++j ) \ - { \ - ctype_r* restrict beta11_r = b1_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict beta11_i = b1_i + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b21_r = B2_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b21_i = B2_i + (0 )*rs_b + (j )*cs_b; \ - ctype* restrict gamma11 = c + (i )*rs_c + (j )*cs_c; \ - ctype_r beta11c_r = *beta11_r; \ - ctype_r beta11c_i = *beta11_i; \ - ctype_r rho11_r; \ - ctype_r rho11_i; \ -\ - /* beta11 = beta11 - a12t * b21; */ \ - PASTEMAC(chr,set0s)( rho11_r ); \ - PASTEMAC(chr,set0s)( rho11_i ); \ - for ( l = 0; l < n_behind; ++l ) \ - { \ - ctype_r* restrict alpha12_r = a12t_r + (l )*cs_a; \ - ctype_r* restrict alpha12_i = a12t_i + (l )*cs_a; \ - ctype_r* restrict beta21_r = b21_r + (l )*rs_b; \ - ctype_r* restrict beta21_i = b21_i + (l )*rs_b; \ -\ - PASTEMAC(ch,axpyris)( *alpha12_r, \ - *alpha12_i, \ - *beta21_r, \ - *beta21_i, \ - rho11_r, \ - rho11_i ); \ - } \ - PASTEMAC(ch,subris)( rho11_r, \ - rho11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scalris)( *alpha11_r, \ - *alpha11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* Output final result to matrix c. */ \ - PASTEMAC(ch,sets)( beta11c_r, \ - beta11c_i, *gamma11 ); \ -\ - /* Store the local values back to b11. */ \ - PASTEMAC(chr,copys)( beta11c_r, *beta11_r ); \ - PASTEMAC(chr,copys)( beta11c_i, *beta11_i ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC2( trsm4m1_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) diff --git a/sandbox/gemmlike/attic/bls_gemm_bp_var2.c b/sandbox/gemmlike/attic/bls_gemm_bp_var2.c new file mode 100644 index 0000000000..957cd57944 --- /dev/null +++ b/sandbox/gemmlike/attic/bls_gemm_bp_var2.c @@ -0,0 +1,590 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemm_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- gemm-like block-panel algorithm (object interface) ----------------------- +// + +// Define a function pointer array named ftypes and initialize its contents with +// the addresses of the typed functions defined below, bls_?gemm_bp_var2(). +static FUNCPTR_T GENARRAY_PREF(ftypes,bls_,gemm_bp_var2); + +void bls_gemm_bp_var2 + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + const inc_t rs_a = bli_obj_row_stride( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t cs_b = bli_obj_col_stride( b ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + + // Index into the function pointer array to extract the correct + // typed function pointer based on the chosen datatype. + FUNCPTR_T f = ftypes[dt]; + + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread + ); +} + +// +// -- gemm-like block-panel algorithm (typed interface) ------------------------ +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + /* + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ + */ \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + /* + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ + */ \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of the scalars to prevent any unnecessary sharing of + cache lines between the cores' caches. */ \ + ctype alpha_local = *alpha_cast; \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ + /*ctype zero_local = *PASTEMAC(ch,0);*/ \ +\ + auxinfo_t aux; \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ \ + bszid_t bszids[8] = { BLIS_NC, /* 5th loop */ \ + BLIS_KC, /* 4th loop */ \ + BLIS_NO_PART, /* pack B */ \ + BLIS_MC, /* 3rd loop */ \ + BLIS_NO_PART, /* pack A */ \ + BLIS_NR, /* 2nd loop */ \ + BLIS_MR, /* 1st loop */ \ + BLIS_KR }; /* microkernel loop */ \ +\ + bszid_t* restrict bszids_jc = &bszids[0]; \ + bszid_t* restrict bszids_pc = &bszids[1]; \ + /*bszid_t* restrict bszids_pb = &bszids[2];*/ \ + bszid_t* restrict bszids_ic = &bszids[3]; \ + /*bszid_t* restrict bszids_pa = &bszids[4];*/ \ + bszid_t* restrict bszids_jr = &bszids[5]; \ + /*bszid_t* restrict bszids_ir = &bszids[6];*/ \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ + thrinfo_t* restrict thread_ir = NULL; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + const dim_t n_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n_local % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. */ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pb = bli_thrinfo_sub_node( thread_pc ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pb, thread_pb );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + B. Then call the packm implementation. */ \ + PASTECH2(bls_,ch,packm_b) \ + ( \ + conjb, \ + KC, NC, \ + kc_cur, nc_cur, NR, \ + &one_local, \ + b_pc, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_pc_use = b_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_ic = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + const dim_t m_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m_local % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pa = bli_thrinfo_sub_node( thread_ic ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pa, thread_pa );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + A. Then call the packm implementation. */ \ + PASTECH2(bls_,ch,packm_a) \ + ( \ + conja, \ + MC, KC, \ + mc_cur, kc_cur, MR, \ + &one_local, \ + a_ic, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. */ \ + ctype* restrict a_ic_use = a_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jr = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Query the number of threads and thread ids for the JR loop. + NOTE: These values are only needed when computing the next + micropanel of B. */ \ + const dim_t jr_nt = bli_thread_n_way( thread_jr ); \ + const dim_t jr_tid = bli_thread_work_id( thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur \ + = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Assume for now that our next panel of B to be the current panel + of B. */ \ + ctype* restrict b2 = b_jr; \ +\ + /* Identify the current thrinfo_t node. */ \ + thread_ir = bli_thrinfo_sub_node( thread_jr ); \ +\ + /* Query the number of threads and thread ids for the IR loop. + NOTE: These values are only needed when computing the next + micropanel of A. */ \ + const dim_t ir_nt = bli_thread_n_way( thread_ir ); \ + const dim_t ir_tid = bli_thread_work_id( thread_ir ); \ +\ + /* Compute number of primary and leftover components of the IR loop. */ \ + dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + dim_t ir_left = mc_cur % MR; \ +\ + /* Compute the IR loop thread range for the current thread. */ \ + dim_t ir_start, ir_end; \ + bli_thread_range_sub( thread_ir, ir_iter, 1, FALSE, &ir_start, &ir_end ); \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += 1 ) \ + { \ + const dim_t mr_cur \ + = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + ctype* restrict a2; \ +\ + /* Compute the addresses of the next micropanels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a_ir, ps_a_use, 1 ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_ic_use; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, ps_b_use, 1 ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_pc_use; \ + } \ +\ + /* Save the addresses of next micropanels of A and B to the + auxinfo_t object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Call a wrapper to the kernel (which handles edge cases). */ \ + PASTECH2(bls_,ch,gemm_kernel) \ + ( \ + MR, \ + NR, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + &alpha_local, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +\ + /* This barrier is needed to prevent threads from starting to pack + the next row panel of B before the current row panel is fully + computed upon. */ \ + bli_thread_barrier( thread_pb ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. */ \ + PASTECH2(bls_,ch,packm_finalize_mem_a) \ + ( \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTECH2(bls_,ch,packm_finalize_mem_b) \ + ( \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var2: a1_packed", mr_cur, kc_cur, a_ir, rs_a_use, cs_a_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var2: b1_packed", kc_cur, nr_cur, b_jr, rs_b_use, cs_b_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%5.2f", "" ); \ +*/ \ +} + +//INSERT_GENTFUNC_BASIC0( gemm_bp_var2 ) +GENTFUNC( float, s, gemm_bp_var2 ) +GENTFUNC( double, d, gemm_bp_var2 ) +GENTFUNC( scomplex, c, gemm_bp_var2 ) +GENTFUNC( dcomplex, z, gemm_bp_var2 ) + +// +// -- gemm-like microkernel wrapper -------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + const dim_t MR, \ + const dim_t NR, \ + dim_t mr_cur, \ + dim_t nr_cur, \ + dim_t kc_cur, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict aux, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* Infer the datatype from the ctype. */ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype zero = *PASTEMAC(ch,0); \ +\ + /* Handle interior and edge cases separately. */ \ + if ( mr_cur == MR && nr_cur == NR ) \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + alpha, \ + a, \ + b, \ + beta, \ + c, rs_c, cs_c, \ + aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + alpha, \ + a, \ + b, \ + &zero, \ + ct, rs_ct, cs_ct, \ + aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. */ \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + mr_cur, \ + nr_cur, \ + ct, rs_ct, cs_ct, \ + beta, \ + c, rs_c, cs_c \ + ); \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( gemm_kernel ) +GENTFUNC( float, s, gemm_kernel ) +GENTFUNC( double, d, gemm_kernel ) +GENTFUNC( scomplex, c, gemm_kernel ) +GENTFUNC( dcomplex, z, gemm_kernel ) + diff --git a/frame/3/gemm/ind/old/bli_gemm3m3_packa.c b/sandbox/gemmlike/bli_gemm_ex.c similarity index 51% rename from frame/3/gemm/ind/old/bli_gemm3m3_packa.c rename to sandbox/gemmlike/bli_gemm_ex.c index 24d575c814..96dae1a3a9 100644 --- a/frame/3/gemm/ind/old/bli_gemm3m3_packa.c +++ b/sandbox/gemmlike/bli_gemm_ex.c @@ -32,111 +32,60 @@ */ +// Given the current architecture of BLIS sandboxes, bli_gemm_ex() is the +// entry point to any sandbox implementation. + +// NOTE: This function is implemented functionally identically to the +// function that it overrides in frame/3/bli_l3_oapi_ex.c. This means that +// we are forgoing the option of customizing the implementations that +// underlie bli_gemm() and bli_?gemm() (which both call bli_gemm_ex()). +// Any new code defined in this sandbox directory, however, will be +// included in the BLIS. + #include "blis.h" -void bli_gemm3m3_packa +void bli_gemm_ex ( + obj_t* alpha, obj_t* a, obj_t* b, + obj_t* beta, obj_t* c, cntx_t* cntx, - cntl_t* cntl, - thrinfo_t* thread + rntm_t* rntm ) { - obj_t a_pack; - - // Make a copy of the context for each stage. - cntx_t cntx_ro = *cntx; - cntx_t cntx_io = *cntx; - cntx_t cntx_rpi = *cntx; - - // ----------------------------------------------------- - - // Initialize the context for the real-only stage. - bli_gemm3m3_cntx_stage( 0, &cntx_ro ); - - // Pack matrix the real-only part of A. - bli_l3_packm - ( - a, - &a_pack, - &cntx_ro, - cntl, - thread - ); - - // Proceed with execution using packed matrix A. - bli_gemm_int + bli_init_once(); + + // A switch to easily toggle whether we use the sandbox implementation + // of bls_gemm() as the implementation for bli_gemm(). (This allows for + // easy testing of bls_gemm() via the testsuite.) Changing the conditional + // to "0" will cause bli_gemm()/bli_gemm_ex() to *not* call the local + // sandbox implementation, though that implementation may still be called + // directly. + if ( 1 ) + { + bls_gemm_ex( alpha, a, b, beta, c, cntx, rntm ); + return; + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Obtain a valid (native) context from the gks if necessary. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_gemm_check( alpha, a, b, beta, c, cntx ); + + // Invoke the operation's front end. + bli_gemm_front ( - &BLIS_ONE, - &a_pack, - b, - &BLIS_ONE, - c, - cntx, - bli_cntl_sub_node( cntl ), - bli_thrinfo_sub_node( thread ) + alpha, a, b, beta, c, cntx, rntm, NULL ); - - // Only apply beta within the first of three subproblems. - bli_obj_scalar_reset( c ); - - // ----------------------------------------------------- - - // Initialize the context for the imag-only stage. - bli_gemm3m3_cntx_stage( 1, &cntx_io ); - - // Pack matrix the imag-only part of A. - bli_l3_packm - ( - a, - &a_pack, - &cntx_io, - cntl, - thread - ); - - // Proceed with execution using packed matrix A. - bli_gemm_int - ( - &BLIS_ONE, - &a_pack, - b, - &BLIS_ONE, - c, - cntx, - bli_cntl_sub_node( cntl ), - bli_thrinfo_sub_node( thread ) - ); - - // ----------------------------------------------------- - - // Initialize the context for the real+imag stage. - bli_gemm3m3_cntx_stage( 2, &cntx_rpi ); - - // Pack matrix the real+imag part of A. - bli_l3_packm - ( - a, - &a_pack, - &cntx_rpi, - cntl, - thread - ); - - // Proceed with execution using packed matrix A. - bli_gemm_int - ( - &BLIS_ONE, - &a_pack, - b, - &BLIS_ONE, - c, - cntx, - bli_cntl_sub_node( cntl ), - bli_thrinfo_sub_node( thread ) - ); - } diff --git a/sandbox/gemmlike/bli_sandbox.h b/sandbox/gemmlike/bli_sandbox.h new file mode 100644 index 0000000000..f3782b3dbc --- /dev/null +++ b/sandbox/gemmlike/bli_sandbox.h @@ -0,0 +1,59 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of copyright holder(s) nor the names + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SANDBOX_H +#define BLIS_SANDBOX_H + +// NOTE: This header is the only header required to be present in the sandbox +// implementation directory. + +// This header should contain (or #include) any definitions that must be +// folded into blis.h. Typically, it will remain empty since any header +// definitions specific to the sandbox implementation will not need to be +// made available to applications (or the framework) during compilation. + +#include "bls_gemm.h" +#include "bls_gemm_check.h" +#include "bls_gemm_var.h" + +#include "bls_l3_packm_a.h" +#include "bls_l3_packm_b.h" +#include "bls_l3_packm_var.h" + +#include "bls_packm_cxk.h" + +#include "bls_l3_decor.h" + + +#endif diff --git a/sandbox/gemmlike/bls_gemm.c b/sandbox/gemmlike/bls_gemm.c new file mode 100644 index 0000000000..f2f8b7e257 --- /dev/null +++ b/sandbox/gemmlike/bls_gemm.c @@ -0,0 +1,278 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// -- Define the gemm-like operation's object API ------------------------------ +// + +void bls_gemm + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c + ) +{ + bls_gemm_ex + ( + alpha, + a, + b, + beta, + c, + NULL, + NULL + ); +} + +void bls_gemm_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Obtain a valid (native) context from the gks if necessary. + // NOTE: This must be done before calling the _check() function, since + // that function assumes the context pointer is valid. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bls_gemm_check( alpha, a, b, beta, c, cntx ); + + // -- bli_gemm_front() ----------------------------------------------------- + + obj_t a_local; + obj_t b_local; + obj_t c_local; + + // If C has a zero dimension, return early. + if ( bli_obj_has_zero_dim( c ) ) + { + return; + } + + // If alpha is zero, or if A or B has a zero dimension, scale C by beta + // and return early. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || + bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) + { + bli_scalm( beta, c ); + return; + } + + // Alias A, B, and C in case we need to apply transformations. + bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); + bli_obj_alias_to( c, &c_local ); + + // Induce a transposition of A if it has its transposition property set. + // Then clear the transposition bit in the object. + if ( bli_obj_has_trans( &a_local ) ) + { + bli_obj_induce_trans( &a_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &a_local ); + } + + // Induce a transposition of B if it has its transposition property set. + // Then clear the transposition bit in the object. + if ( bli_obj_has_trans( &b_local ) ) + { + bli_obj_induce_trans( &b_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &b_local ); + } + + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop, and then make any + // additional modifications necessary for the current operation. + bli_rntm_set_ways_for_op + ( + BLIS_GEMM, + BLIS_LEFT, // ignored for gemm/hemm/symm + bli_obj_length( &c_local ), + bli_obj_width( &c_local ), + bli_obj_width( &a_local ), + rntm + ); + + // Spawn threads (if applicable), where bls_gemm_int() is the thread entry + // point function for each thread. This also begins the process of creating + // the thrinfo_t tree, which contains thread communicators. + bls_l3_thread_decorator + ( + bls_gemm_int, + BLIS_GEMM, // operation family id + alpha, + &a_local, + &b_local, + beta, + &c_local, + cntx, + rntm + ); +} + +// +// -- Define the gemm-like operation's thread entry point ---------------------- +// + +void bls_gemm_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + // In this function, we choose the gemm implementation that is executed + // on each thread. + + // Call the block-panel algorithm. + bls_gemm_bp_var1 + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm, + thread + ); +} + +// +// -- Define the gemm-like operation's typed API ------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ) \ +{ \ + bli_init_once(); \ +\ + /* Determine the datatype (e.g. BLIS_FLOAT, BLIS_DOUBLE, etc.) based on + the macro parameter 'ch' (e.g. s, d, etc). */ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao, ao, bo, betao, co; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + /* Adjust the dimensions of matrices A and B according to the transa and + transb parameters. */ \ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, k, n, &m_b, &n_b ); \ +\ + /* Create bufferless scalar objects and attach the provided scalar pointers + to those scalar objects. */ \ + bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ + bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ +\ + /* Create bufferless matrix objects and attach the provided matrix pointers + to those matrix objects. */ \ + bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ +\ + /* Set the transposition/conjugation properties of the objects for matrices + A and B. */ \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + /* Call the object interface. */ \ + PASTECH(bls_,opname) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co \ + ); \ +} + +//INSERT_GENTFUNC_BASIC0( gemm ) +GENTFUNC( float, s, gemm ) +GENTFUNC( double, d, gemm ) +GENTFUNC( scomplex, c, gemm ) +GENTFUNC( dcomplex, z, gemm ) + diff --git a/frame/3/gemm/bli_gemm_packab.c b/sandbox/gemmlike/bls_gemm.h similarity index 62% rename from frame/3/gemm/bli_gemm_packab.c rename to sandbox/gemmlike/bls_gemm.h index a15192994e..b296ac1c0f 100644 --- a/frame/3/gemm/bli_gemm_packab.c +++ b/sandbox/gemmlike/bls_gemm.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Texas at Austin Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,85 +32,70 @@ */ -#include "blis.h" +// +// -- Prototype the gemm-like operation's object API --------------------------- +// -void bli_gemm_packa +void bls_gemm ( + obj_t* alpha, obj_t* a, obj_t* b, + obj_t* beta, + obj_t* c + ); + +void bls_gemm_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, obj_t* c, cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - obj_t a_pack; - - // Pack matrix A according to the control tree node. - bli_l3_packm - ( - a, - &a_pack, - cntx, - rntm, - cntl, - thread - ); + rntm_t* rntm + ); - // Proceed with execution using packed matrix A. - bli_gemm_int - ( - &BLIS_ONE, - &a_pack, - b, - &BLIS_ONE, - c, - cntx, - rntm, - bli_cntl_sub_node( cntl ), - bli_thrinfo_sub_node( thread ) - ); -} +// +// -- Prototype the gemm-like operation's thread entry point ------------------- +// -// ----------------------------------------------------------------------------- - -void bli_gemm_packb +void bls_gemm_int ( + obj_t* alpha, obj_t* a, obj_t* b, + obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm, - cntl_t* cntl, thrinfo_t* thread - ) -{ - obj_t b_pack; + ); + +// +// -- Prototype the gemm-like operation's typed API ---------------------------- +// - // Pack matrix B according to the control tree node. - bli_l3_packm - ( - b, - &b_pack, - cntx, - rntm, - cntl, - thread - ); +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ); - // Proceed with execution using packed matrix B. - bli_gemm_int - ( - &BLIS_ONE, - a, - &b_pack, - &BLIS_ONE, - c, - cntx, - rntm, - bli_cntl_sub_node( cntl ), - bli_thrinfo_sub_node( thread ) - ); -} +//INSERT_GENTPROT_BASIC0( gemm ) +GENTPROT( float, s, gemm ) +GENTPROT( double, d, gemm ) +GENTPROT( scomplex, c, gemm ) +GENTPROT( dcomplex, z, gemm ) diff --git a/sandbox/gemmlike/bls_gemm_bp_var1.c b/sandbox/gemmlike/bls_gemm_bp_var1.c new file mode 100644 index 0000000000..62dc462d51 --- /dev/null +++ b/sandbox/gemmlike/bls_gemm_bp_var1.c @@ -0,0 +1,479 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemm_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- gemm-like block-panel algorithm (object interface) ----------------------- +// + +// Define a function pointer array named ftypes and initialize its contents with +// the addresses of the typed functions defined below, bls_?gemm_bp_var1(). +static FUNCPTR_T GENARRAY_PREF(ftypes,bls_,gemm_bp_var1); + +void bls_gemm_bp_var1 + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + const inc_t rs_a = bli_obj_row_stride( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t cs_b = bli_obj_col_stride( b ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + + // Index into the function pointer array to extract the correct + // typed function pointer based on the chosen datatype. + FUNCPTR_T f = ftypes[dt]; + + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread + ); +} + +// +// -- gemm-like block-panel algorithm (typed interface) ------------------------ +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of the scalars to prevent any unnecessary sharing of + cache lines between the cores' caches. */ \ + ctype alpha_local = *alpha_cast; \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ +\ + auxinfo_t aux; \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ \ + bszid_t bszids[8] = { BLIS_NC, /* 5th loop */ \ + BLIS_KC, /* 4th loop */ \ + BLIS_NO_PART, /* pack B */ \ + BLIS_MC, /* 3rd loop */ \ + BLIS_NO_PART, /* pack A */ \ + BLIS_NR, /* 2nd loop */ \ + BLIS_MR, /* 1st loop */ \ + BLIS_KR }; /* microkernel loop */ \ +\ + bszid_t* restrict bszids_jc = &bszids[0]; \ + bszid_t* restrict bszids_pc = &bszids[1]; \ + /*bszid_t* restrict bszids_pb = &bszids[2];*/ \ + bszid_t* restrict bszids_ic = &bszids[3]; \ + /*bszid_t* restrict bszids_pa = &bszids[4];*/ \ + bszid_t* restrict bszids_jr = &bszids[5]; \ + /*bszid_t* restrict bszids_ir = &bszids[6];*/ \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ + thrinfo_t* restrict thread_ir = NULL; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + const dim_t n_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n_local % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. */ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pb = bli_thrinfo_sub_node( thread_pc ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pb, thread_pb );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + B. Then call the packm implementation. */ \ + PASTECH2(bls_,ch,packm_b) \ + ( \ + conjb, \ + KC, NC, \ + kc_cur, nc_cur, NR, \ + &one_local, \ + b_pc, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_pc_use = b_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_ic = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + const dim_t m_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m_local % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pa = bli_thrinfo_sub_node( thread_ic ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pa, thread_pa );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + A. Then call the packm implementation. */ \ + PASTECH2(bls_,ch,packm_a) \ + ( \ + conja, \ + MC, KC, \ + mc_cur, kc_cur, MR, \ + &one_local, \ + a_ic, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. */ \ + ctype* restrict a_ic_use = a_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jr = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Query the number of threads and thread ids for the JR loop. + NOTE: These values are only needed when computing the next + micropanel of B. */ \ + const dim_t jr_nt = bli_thread_n_way( thread_jr ); \ + const dim_t jr_tid = bli_thread_work_id( thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur \ + = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Assume for now that our next panel of B to be the current panel + of B. */ \ + ctype* restrict b2 = b_jr; \ +\ + /* Identify the current thrinfo_t node. */ \ + thread_ir = bli_thrinfo_sub_node( thread_jr ); \ +\ + /* Query the number of threads and thread ids for the IR loop. + NOTE: These values are only needed when computing the next + micropanel of A. */ \ + const dim_t ir_nt = bli_thread_n_way( thread_ir ); \ + const dim_t ir_tid = bli_thread_work_id( thread_ir ); \ +\ + /* Compute number of primary and leftover components of the IR loop. */ \ + dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + dim_t ir_left = mc_cur % MR; \ +\ + /* Compute the IR loop thread range for the current thread. */ \ + dim_t ir_start, ir_end; \ + bli_thread_range_sub( thread_ir, ir_iter, 1, FALSE, &ir_start, &ir_end ); \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += 1 ) \ + { \ + const dim_t mr_cur \ + = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + ctype* restrict a2; \ +\ + /* Compute the addresses of the next micropanels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a_ir, ps_a_use, 1 ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_ic_use; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, ps_b_use, 1 ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_pc_use; \ + } \ +\ + /* Save the addresses of next micropanels of A and B to the + auxinfo_t object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + &alpha_local, \ + a_ir, \ + b_jr, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +\ + /* This barrier is needed to prevent threads from starting to pack + the next row panel of B before the current row panel is fully + computed upon. */ \ + bli_thread_barrier( thread_pb ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. */ \ + PASTECH2(bls_,ch,packm_finalize_mem_a) \ + ( \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTECH2(bls_,ch,packm_finalize_mem_b) \ + ( \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var1: a1_packed", mr_cur, kc_cur, a_ir, rs_a_use, cs_a_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var1: b1_packed", kc_cur, nr_cur, b_jr, rs_b_use, cs_b_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var1: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%5.2f", "" ); \ +*/ \ +} + +//INSERT_GENTFUNC_BASIC0( gemm_bp_var1 ) +GENTFUNC( float, s, gemm_bp_var1 ) +GENTFUNC( double, d, gemm_bp_var1 ) +GENTFUNC( scomplex, c, gemm_bp_var1 ) +GENTFUNC( dcomplex, z, gemm_bp_var1 ) + diff --git a/sandbox/gemmlike/bls_gemm_check.c b/sandbox/gemmlike/bls_gemm_check.c new file mode 100644 index 0000000000..3690173387 --- /dev/null +++ b/sandbox/gemmlike/bls_gemm_check.c @@ -0,0 +1,117 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bls_gemm_check + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx + ) +{ + //bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); + + err_t e_val; + + // Check object datatypes. + + e_val = bli_check_noninteger_object( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_noninteger_object( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( c ); + bli_check_error_code( e_val ); + + // Check scalar/vector/matrix type. + + e_val = bli_check_scalar_object( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_scalar_object( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( c ); + bli_check_error_code( e_val ); + + // Check object buffers (for non-NULLness). + + e_val = bli_check_object_buffer( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( c ); + bli_check_error_code( e_val ); + + // Check object dimensions. + + e_val = bli_check_level3_dims( a, b, c ); + bli_check_error_code( e_val ); + + // Check for consistent datatypes. + // NOTE: We only perform these tests when mixed datatype support is + // disabled. + + e_val = bli_check_consistent_object_datatypes( c, a ); + bli_check_error_code( e_val ); + + e_val = bli_check_consistent_object_datatypes( c, b ); + bli_check_error_code( e_val ); +} + diff --git a/frame/3/syr2k/bli_syr2k_front.h b/sandbox/gemmlike/bls_gemm_check.h similarity index 94% rename from frame/3/syr2k/bli_syr2k_front.h rename to sandbox/gemmlike/bls_gemm_check.h index 767bb6ee11..8b97069911 100644 --- a/frame/3/syr2k/bli_syr2k_front.h +++ b/sandbox/gemmlike/bls_gemm_check.h @@ -32,14 +32,18 @@ */ -void bli_syr2k_front + +// +// Prototype object-based check functions. +// + +void bls_gemm_check ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl - ); + cntx_t* cntx + ); + diff --git a/frame/1m/packm/bli_packm_var.h b/sandbox/gemmlike/bls_gemm_var.h similarity index 51% rename from frame/1m/packm/bli_packm_var.h rename to sandbox/gemmlike/bls_gemm_var.h index 6c11b19abc..7c515f8c39 100644 --- a/frame/1m/packm/bli_packm_var.h +++ b/sandbox/gemmlike/bls_gemm_var.h @@ -4,8 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2021, The University of Texas at Austin Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,78 +32,86 @@ */ + // -// Prototype object-based interfaces. +// Prototype the object-based variant interfaces. // #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0(opname) \ +void PASTECH(bls_,opname) \ ( \ - obj_t* c, \ - obj_t* p, \ - cntx_t* cntx, \ - cntl_t* cntl, \ - thrinfo_t* t \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ ); -GENPROT( packm_unb_var1 ) -GENPROT( packm_blk_var1 ) +GENPROT( gemm_bp_var1 ) + // -// Prototype BLAS-like interfaces with void pointer operands. +// Prototype the typed variant interfaces. // #undef GENTPROT #define GENTPROT( ctype, ch, varname ) \ \ -void PASTEMAC(ch,varname) \ +void PASTECH2(bls_,ch,varname) \ ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - trans_t transc, \ - dim_t m, \ - dim_t n, \ - dim_t m_max, \ - dim_t n_max, \ - void* kappa, \ - void* c, inc_t rs_c, inc_t cs_c, \ - void* p, inc_t rs_p, inc_t cs_p, \ - cntx_t* cntx \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ ); -INSERT_GENTPROT_BASIC0( packm_unb_var1 ) +//INSERT_GENTPROT_BASIC0( gemm_bp_var1 ) +GENTPROT( float, s, gemm_bp_var1 ) +GENTPROT( double, d, gemm_bp_var1 ) +GENTPROT( scomplex, c, gemm_bp_var1 ) +GENTPROT( dcomplex, z, gemm_bp_var1 ) + + +// +// Prototype the typed kernel interfaces. +// #undef GENTPROT #define GENTPROT( ctype, ch, varname ) \ \ -void PASTEMAC(ch,varname) \ +void PASTECH2(bls_,ch,varname) \ ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - trans_t transc, \ - pack_t schema, \ - bool_t invdiag, \ - bool_t revifup, \ - bool_t reviflo, \ - dim_t m, \ - dim_t n, \ - dim_t m_max, \ - dim_t n_max, \ - void* kappa, \ - void* c, inc_t rs_c, inc_t cs_c, \ - void* p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - dim_t pd_p, inc_t ps_p, \ - void* packm_ker, \ - cntx_t* cntx, \ - thrinfo_t* thread \ + const dim_t MR, \ + const dim_t NR, \ + dim_t mr_cur, \ + dim_t nr_cur, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict aux, \ + cntx_t* restrict cntx \ ); -INSERT_GENTPROT_BASIC0( packm_blk_var1 ) +//INSERT_GENTPROT_BASIC0( gemm_kernel ) +GENTPROT( float, s, gemm_kernel ) +GENTPROT( double, d, gemm_kernel ) +GENTPROT( scomplex, c, gemm_kernel ) +GENTPROT( dcomplex, z, gemm_kernel ) diff --git a/sandbox/gemmlike/bls_l3_packm_a.c b/sandbox/gemmlike/bls_l3_packm_a.c new file mode 100644 index 0000000000..0dcc531fdb --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_a.c @@ -0,0 +1,328 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Set the pack buffer type so that we are obtaining memory blocks from + the pool dedicated to blocks of A. */ \ + const packbuf_t pack_buf_type = BLIS_BUFFER_FOR_A_BLOCK; \ +\ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + const dim_t m_pack = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + const dim_t k_pack = k; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin the + packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. */ \ + siz_t size_needed = sizeof( ctype ) * m_pack * k_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the packed block allocator. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was passed in. + It needs to be that mem_t struct, and not a local (temporary) + mem_t, since there is no barrier until after packing is finished, + which could allow a race condition whereby the chief thread exits + the current function before the other threads have a chance to + copy from it. (A barrier would fix that race condition, but then + again, I prefer to keep barriers to a minimum.) */ \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t to all + threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the packed + block allocator and cached by the caller. */ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. */ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_mem_a ) +GENTFUNC( float, s, packm_init_mem_a ) +GENTFUNC( double, d, packm_init_mem_a ) +GENTFUNC( scomplex, c, packm_init_mem_a ) +GENTFUNC( dcomplex, z, packm_init_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. */ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_finalize_mem_a ) +GENTFUNC( float, s, packm_finalize_mem_a ) +GENTFUNC( double, d, packm_finalize_mem_a ) +GENTFUNC( scomplex, c, packm_finalize_mem_a ) +GENTFUNC( dcomplex, z, packm_finalize_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ) \ +{ \ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + *m_max = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + *k_max = k; \ +\ + /* Determine the dimensions and strides for the packed matrix A. */ \ + { \ + /* Pack A to column-stored row-panels. */ \ + *rs_p = 1; \ + *cs_p = mr; \ +\ + *pd_p = mr; \ + *ps_p = mr * k; \ +\ + /* Set the schema to "packed row panels" to indicate packing to + conventional column-stored row panels. */ \ + *schema = BLIS_PACKED_ROW_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the memory + associated with the mem_t entry acquired from the memory pool. */ \ + *p = bli_mem_buffer( mem ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_a ) +GENTFUNC( float, s, packm_init_a ) +GENTFUNC( double, d, packm_init_a ) +GENTFUNC( scomplex, c, packm_init_a ) +GENTFUNC( dcomplex, z, packm_init_a ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t m_max; \ + dim_t k_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. */ \ + PASTECH2(bls_,ch,packm_init_mem_a) \ + ( \ + m_alloc, k_alloc, mr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix A. */ \ + PASTECH2(bls_,ch,packm_init_a) \ + ( \ + &schema, \ + m, k, mr, \ + &m_max, &k_max, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + mem \ + ); \ +\ + /* Pack matrix A to the destination buffer chosen above. Here, the packed + matrix is stored to column-stored MR x k micropanels. */ \ + PASTECH2(bls_,ch,packm_var1) \ + ( \ + conj, \ + schema, \ + m, \ + k, \ + m_max, \ + k_max, \ + kappa, \ + a, rs_a, cs_a, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ +\ + /* Barrier so that packing is done before computation. */ \ + bli_thread_barrier( thread ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_a ) +GENTFUNC( float, s, packm_a ) +GENTFUNC( double, d, packm_a ) +GENTFUNC( scomplex, c, packm_a ) +GENTFUNC( dcomplex, z, packm_a ) + diff --git a/sandbox/gemmlike/bls_l3_packm_a.h b/sandbox/gemmlike/bls_l3_packm_a.h new file mode 100644 index 0000000000..201a24efae --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_a.h @@ -0,0 +1,122 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_mem_a ) +GENTPROT( float, s, packm_init_mem_a ) +GENTPROT( double, d, packm_init_mem_a ) +GENTPROT( scomplex, c, packm_init_mem_a ) +GENTPROT( dcomplex, z, packm_init_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_finalize_mem_a ) +GENTPROT( float, s, packm_finalize_mem_a ) +GENTPROT( double, d, packm_finalize_mem_a ) +GENTPROT( scomplex, c, packm_finalize_mem_a ) +GENTPROT( dcomplex, z, packm_finalize_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_a ) +GENTPROT( float, s, packm_init_a ) +GENTPROT( double, d, packm_init_a ) +GENTPROT( scomplex, c, packm_init_a ) +GENTPROT( dcomplex, z, packm_init_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_a ) +GENTPROT( float, s, packm_a ) +GENTPROT( double, d, packm_a ) +GENTPROT( scomplex, c, packm_a ) +GENTPROT( dcomplex, z, packm_a ) + diff --git a/sandbox/gemmlike/bls_l3_packm_b.c b/sandbox/gemmlike/bls_l3_packm_b.c new file mode 100644 index 0000000000..9d563109a6 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_b.c @@ -0,0 +1,328 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Set the pack buffer type so that we are obtaining memory blocks from + the pool dedicated to panels of B. */ \ + const packbuf_t pack_buf_type = BLIS_BUFFER_FOR_B_PANEL; \ +\ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + const dim_t k_pack = k; \ + const dim_t n_pack = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin the + packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. */ \ + siz_t size_needed = sizeof( ctype ) * k_pack * n_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the packed block allocator. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was passed in. + It needs to be that mem_t struct, and not a local (temporary) + mem_t, since there is no barrier until after packing is finished, + which could allow a race condition whereby the chief thread exits + the current function before the other threads have a chance to + copy from it. (A barrier would fix that race condition, but then + again, I prefer to keep barriers to a minimum.) */ \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t to all + threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the packed + block allocator and cached by the caller. */ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. */ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_mem_b ) +GENTFUNC( float, s, packm_init_mem_b ) +GENTFUNC( double, d, packm_init_mem_b ) +GENTFUNC( scomplex, c, packm_init_mem_b ) +GENTFUNC( dcomplex, z, packm_init_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. */ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_finalize_mem_b ) +GENTFUNC( float, s, packm_finalize_mem_b ) +GENTFUNC( double, d, packm_finalize_mem_b ) +GENTFUNC( scomplex, c, packm_finalize_mem_b ) +GENTFUNC( dcomplex, z, packm_finalize_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ) \ +{ \ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + *k_max = k; \ + *n_max = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Determine the dimensions and strides for the packed matrix B. */ \ + { \ + /* Pack B to row-stored column-panels. */ \ + *rs_p = nr; \ + *cs_p = 1; \ +\ + *pd_p = nr; \ + *ps_p = k * nr; \ +\ + /* Set the schema to "packed column panels" to indicate packing to + conventional row-stored column panels. */ \ + *schema = BLIS_PACKED_COL_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the memory + associated with the mem_t entry acquired from the memory pool. */ \ + *p = bli_mem_buffer( mem ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_b ) +GENTFUNC( float, s, packm_init_b ) +GENTFUNC( double, d, packm_init_b ) +GENTFUNC( scomplex, c, packm_init_b ) +GENTFUNC( dcomplex, z, packm_init_b ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t k_max; \ + dim_t n_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. */ \ + PASTECH2(bls_,ch,packm_init_mem_b) \ + ( \ + k_alloc, n_alloc, nr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix B. */ \ + PASTECH2(bls_,ch,packm_init_b) \ + ( \ + &schema, \ + k, n, nr, \ + &k_max, &n_max, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + mem \ + ); \ +\ + /* Pack matrix B to the destination buffer chosen above. Here, the packed + matrix is stored to row-stored k x NR micropanels. */ \ + PASTECH2(bls_,ch,packm_var1) \ + ( \ + conj, \ + schema, \ + k, \ + n, \ + k_max, \ + n_max, \ + kappa, \ + b, rs_b, cs_b, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ +\ + /* Barrier so that packing is done before computation. */ \ + bli_thread_barrier( thread ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_b ) +GENTFUNC( float, s, packm_b ) +GENTFUNC( double, d, packm_b ) +GENTFUNC( scomplex, c, packm_b ) +GENTFUNC( dcomplex, z, packm_b ) + diff --git a/sandbox/gemmlike/bls_l3_packm_b.h b/sandbox/gemmlike/bls_l3_packm_b.h new file mode 100644 index 0000000000..728d21aed5 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_b.h @@ -0,0 +1,122 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_mem_b ) +GENTPROT( float, s, packm_init_mem_b ) +GENTPROT( double, d, packm_init_mem_b ) +GENTPROT( scomplex, c, packm_init_mem_b ) +GENTPROT( dcomplex, z, packm_init_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_finalize_mem_b ) +GENTPROT( float, s, packm_finalize_mem_b ) +GENTPROT( double, d, packm_finalize_mem_b ) +GENTPROT( scomplex, c, packm_finalize_mem_b ) +GENTPROT( dcomplex, z, packm_finalize_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_b ) +GENTPROT( float, s, packm_init_b ) +GENTPROT( double, d, packm_init_b ) +GENTPROT( scomplex, c, packm_init_b ) +GENTPROT( dcomplex, z, packm_init_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_b ) +GENTPROT( float, s, packm_b ) +GENTPROT( double, d, packm_b ) +GENTPROT( scomplex, c, packm_b ) +GENTPROT( dcomplex, z, packm_b ) + diff --git a/sandbox/gemmlike/bls_l3_packm_var.h b/sandbox/gemmlike/bls_l3_packm_var.h new file mode 100644 index 0000000000..98300536bc --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_var.h @@ -0,0 +1,74 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// +// Prototype BLAS-like interfaces to the variants. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ); + +//INSERT_GENTPROT_BASIC0( packm_var1 ) +GENTPROT( float, s, packm_var1 ) +GENTPROT( double, d, packm_var1 ) +GENTPROT( scomplex, c, packm_var1 ) +GENTPROT( dcomplex, z, packm_var1 ) + +//INSERT_GENTPROT_BASIC0( packm_var2 ) +GENTPROT( float, s, packm_var2 ) +GENTPROT( double, d, packm_var2 ) +GENTPROT( scomplex, c, packm_var2 ) +GENTPROT( dcomplex, z, packm_var2 ) + +//INSERT_GENTPROT_BASIC0( packm_var3 ) +GENTPROT( float, s, packm_var3 ) +GENTPROT( double, d, packm_var3 ) +GENTPROT( scomplex, c, packm_var3 ) +GENTPROT( dcomplex, z, packm_var3 ) diff --git a/sandbox/gemmlike/bls_l3_packm_var1.c b/sandbox/gemmlike/bls_l3_packm_var1.c new file mode 100644 index 0000000000..c0649a9ec4 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_var1.c @@ -0,0 +1,193 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Variant 1 provides basic support for packing by calling packm_cxk(). +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len; \ + dim_t panel_len_max; \ + dim_t panel_dim; \ + dim_t panel_dim_max; \ + inc_t incc; \ + inc_t ldc; \ + inc_t ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + incc = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. */ \ + iter_dim = m; \ + panel_len = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + incc = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + ctype* restrict p_begin = p_cast; \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t it_start, it_end, it_inc; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_begin = c_cast + (ic )*incc; \ +\ + ctype* restrict c_use = c_begin; \ + ctype* restrict p_use = p_begin; \ +\ + /* The definition of bli_packm_my_iter() will depend on whether slab + or round-robin partitioning was requested at configure-time. (The + default is slab.) */ \ + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + { \ + PASTECH2(bls_,ch,packm_cxk) \ + ( \ + conjc, \ + schema, \ + panel_dim, \ + panel_dim_max, \ + panel_len, \ + panel_len_max, \ + kappa_cast, \ + c_use, incc, ldc, \ + p_use, ldp, \ + cntx \ + ); \ + } \ +\ +/* +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ +\ + p_begin += ps_p; \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_var1 ) +GENTFUNC( float, s, packm_var1 ) +GENTFUNC( double, d, packm_var1 ) +GENTFUNC( scomplex, c, packm_var1 ) +GENTFUNC( dcomplex, z, packm_var1 ) + diff --git a/sandbox/gemmlike/bls_l3_packm_var2.c b/sandbox/gemmlike/bls_l3_packm_var2.c new file mode 100644 index 0000000000..8d2b90cac1 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_var2.c @@ -0,0 +1,244 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Variant 2 is similar to variant 1, but inlines the contents of packm_cxk(). +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len; \ + dim_t panel_len_max; \ + dim_t panel_dim; \ + dim_t panel_dim_max; \ + inc_t incc; \ + inc_t ldc; \ + inc_t ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + incc = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. */ \ + iter_dim = m; \ + panel_len = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + incc = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + ctype* restrict p_begin = p_cast; \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t it_start, it_end, it_inc; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_begin = c_cast + (ic )*incc; \ +\ + ctype* restrict c_use = c_begin; \ + ctype* restrict p_use = p_begin; \ +\ + /* The definition of bli_packm_my_iter() will depend on whether slab + or round-robin partitioning was requested at configure-time. (The + default is slab.) */ \ + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + { \ + /* NOTE: We assume here that kappa = 1 and therefore ignore it. If + we're wrong, this will get someone's attention. */ \ + if ( !PASTEMAC(ch,eq1)( *kappa_cast ) ) \ + bli_abort(); \ +\ + /* Perform the packing, taking conjc into account. */ \ + if ( bli_is_conj( conjc ) ) \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* cli = c_use + (l )*ldc + (i )*incc; \ + ctype* pli = p_use + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copyjs)( *cli, *pli ); \ + } \ + } \ + } \ + else \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* cli = c_use + (l )*ldc + (i )*incc; \ + ctype* pli = p_use + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copys)( *cli, *pli ); \ + } \ + } \ + } \ +\ + /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ + if ( panel_dim < panel_dim_max ) \ + { \ + const dim_t i = panel_dim; \ + const dim_t m_edge = panel_dim_max - panel_dim; \ + const dim_t n_edge = panel_len_max; \ + ctype* restrict p_edge = p_use + (i )*1; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ +\ + /* If panel_len < panel_len_max, then we zero those unused columns. */ \ + if ( panel_len < panel_len_max ) \ + { \ + const dim_t j = panel_len; \ + const dim_t m_edge = panel_dim_max; \ + const dim_t n_edge = panel_len_max - panel_len; \ + ctype* restrict p_edge = p_use + (j )*ldp; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ + } \ +\ +/* +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ +\ + p_begin += ps_p; \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_var1 ) +GENTFUNC( float, s, packm_var2 ) +GENTFUNC( double, d, packm_var2 ) +GENTFUNC( scomplex, c, packm_var2 ) +GENTFUNC( dcomplex, z, packm_var2 ) + diff --git a/sandbox/gemmlike/bls_l3_packm_var3.c b/sandbox/gemmlike/bls_l3_packm_var3.c new file mode 100644 index 0000000000..5ea80ff424 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_var3.c @@ -0,0 +1,200 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Variant 3 is similar to variant 1, except that it parallelizes packing +// along the k dimension. (Our current hypothesis is that this method of +// parallelizing the operation may perform better on some NUMA systems.) +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len; \ + dim_t panel_len_max; \ + dim_t panel_dim; \ + dim_t panel_dim_max; \ + inc_t incc; \ + inc_t ldc; \ + inc_t ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + incc = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. */ \ + iter_dim = m; \ + panel_len = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + incc = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t pr_start, pr_end; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. */ \ + bli_thread_range_sub( thread, panel_len, 1, FALSE, &pr_start, &pr_end ); \ +\ + /* Define instances of panel_len and panel_len_max that are specific to + the local thread. */ \ + dim_t panel_len_loc = pr_end - pr_start; \ + dim_t panel_len_max_loc = panel_len_loc; \ +\ + /* If panel_len_max > panel_len, then there are some columns in p that + need to be zeroed. Of course, only the last thread will be responsible + for this edge region. */ \ + dim_t panel_len_zero = panel_len_max - panel_len; \ + if ( tid == nt - 1 ) panel_len_max_loc += panel_len_zero; \ +\ + /* Shift the pointer for c and p to the appropriate locations within the + first micropanel. */ \ + dim_t off_loc = pr_start; \ + ctype* restrict c_begin_loc = c_cast + off_loc * ldc; \ + ctype* restrict p_begin_loc = p_cast + off_loc * ldp; \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_use = c_begin_loc + (ic )*incc; \ + ctype* restrict p_use = p_begin_loc + (it )*ps_p; \ +\ + { \ + PASTECH2(bls_,ch,packm_cxk) \ + ( \ + conjc, \ + schema, \ + panel_dim, \ + panel_dim_max, \ + panel_len_loc, \ + panel_len_max_loc, \ + kappa_cast, \ + c_use, incc, ldc, \ + p_use, ldp, \ + cntx \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_var3 ) +GENTFUNC( float, s, packm_var3 ) +GENTFUNC( double, d, packm_var3 ) +GENTFUNC( scomplex, c, packm_var3 ) +GENTFUNC( dcomplex, z, packm_var3 ) + +/* +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var3: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_var3: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ + diff --git a/frame/1m/packm/bli_packm_cxk_rih.c b/sandbox/gemmlike/bls_packm_cxk.c similarity index 63% rename from frame/1m/packm/bli_packm_cxk_rih.c rename to sandbox/gemmlike/bls_packm_cxk.c index 1f2c9f240a..ca11c207c0 100644 --- a/frame/1m/packm/bli_packm_cxk_rih.c +++ b/sandbox/gemmlike/bls_packm_cxk.c @@ -34,10 +34,10 @@ #include "blis.h" -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname ) \ +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +void PASTECH2(bls_,ch,opname) \ ( \ conj_t conja, \ pack_t schema, \ @@ -66,7 +66,10 @@ void PASTEMAC(ch,opname) \ \ /* If there exists a kernel implementation for the micro-panel dimension provided, we invoke the implementation. Otherwise, we use scal2m. */ \ - if ( 0 && f != NULL ) \ + /* NOTE: We've disabled calling packm micro-kernels from the context for + this implementation. To re-enable, change FALSE to TRUE in the + conditional below. */ \ + if ( f != NULL && FALSE ) \ { \ f \ ( \ @@ -83,69 +86,76 @@ void PASTEMAC(ch,opname) \ } \ else \ { \ - /* Treat the micro-panel as panel_dim x panel_len and column-stored - (unit row stride). */ \ + /* NOTE: We assume here that kappa = 1 and therefore ignore it. If + we're wrong, this will get someone's attention. */ \ + if ( !PASTEMAC(ch,eq1)( *kappa ) ) \ + bli_abort(); \ \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - panel_dim, \ - panel_len, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ + /* Perform the packing, taking conja into account. */ \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copyjs)( *ali, *pli ); \ + } \ + } \ + } \ + else \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copys)( *ali, *pli ); \ + } \ + } \ + } \ \ /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ - if ( panel_dim != panel_dim_max ) \ + if ( panel_dim < panel_dim_max ) \ { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = panel_dim; \ - const dim_t m_edge = panel_dim_max - i; \ - const dim_t n_edge = panel_len_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ + const dim_t i = panel_dim; \ + const dim_t m_edge = panel_dim_max - panel_dim; \ + const dim_t n_edge = panel_len_max; \ + ctype* restrict p_edge = p + (i )*1; \ \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ + PASTEMAC(ch,set0s_mxn) \ ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ m_edge, \ n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ + p_edge, 1, ldp \ ); \ } \ \ /* If panel_len < panel_len_max, then we zero those unused columns. */ \ - if ( panel_len != panel_len_max ) \ + if ( panel_len < panel_len_max ) \ { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = panel_len; \ - const dim_t m_edge = panel_dim_max; \ - const dim_t n_edge = panel_len_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ + const dim_t j = panel_len; \ + const dim_t m_edge = panel_dim_max; \ + const dim_t n_edge = panel_len_max - panel_len; \ + ctype* restrict p_edge = p + (j )*ldp; \ \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ + PASTEMAC(ch,set0s_mxn) \ ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ m_edge, \ n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ + p_edge, 1, ldp \ ); \ } \ } \ } -INSERT_GENTFUNCCO_BASIC0( packm_cxk_rih ) +//INSERT_GENTFUNC_BASIC0( packm_cxk ) +GENTFUNC( float, s, packm_cxk ) +GENTFUNC( double, d, packm_cxk ) +GENTFUNC( scomplex, c, packm_cxk ) +GENTFUNC( dcomplex, z, packm_cxk ) diff --git a/frame/1m/packm/bli_packm_cxk_rih.h b/sandbox/gemmlike/bls_packm_cxk.h similarity index 88% rename from frame/1m/packm/bli_packm_cxk_rih.h rename to sandbox/gemmlike/bls_packm_cxk.h index c1d2ba9fe3..f6582d64a7 100644 --- a/frame/1m/packm/bli_packm_cxk_rih.h +++ b/sandbox/gemmlike/bls_packm_cxk.h @@ -33,10 +33,10 @@ */ -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ \ -void PASTEMAC(ch,varname) \ +void PASTECH2(bls_,ch,varname) \ ( \ conj_t conja, \ pack_t schema, \ @@ -50,5 +50,9 @@ void PASTEMAC(ch,varname) \ cntx_t* cntx \ ); -INSERT_GENTPROTCO_BASIC0( packm_cxk_rih ) +//INSERT_GENTPROT_BASIC0( packm_cxk ) +GENTPROT( float, s, packm_cxk ) +GENTPROT( double, d, packm_cxk ) +GENTPROT( scomplex, c, packm_cxk ) +GENTPROT( dcomplex, z, packm_cxk ) diff --git a/sandbox/gemmlike/thread/bls_l3_decor.h b/sandbox/gemmlike/thread/bls_l3_decor.h new file mode 100644 index 0000000000..bb8a95bb46 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor.h @@ -0,0 +1,73 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_H +#define BLIS_SBX_L3_DECOR_H + +// -- sup definitions ---------------------------------------------------------- + +// Level-3 sup internal function type. +typedef void (*l3sbxint_t) + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +// Level-3 sup thread decorator prototype. +void bls_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +// Include definitions specific to the method of multithreading. +#include "bls_l3_decor_single.h" +#include "bls_l3_decor_openmp.h" +#include "bls_l3_decor_pthreads.h" + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_openmp.c b/sandbox/gemmlike/thread/bls_l3_decor_openmp.c new file mode 100644 index 0000000000..bf0d4d8bcd --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_openmp.c @@ -0,0 +1,138 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_OPENMP + +// Define a dummy thread entry function, which is needed in the pthreads +// version, so that when building Windows DLLs (with OpenMP enabled or with +// no multithreading) we don't risk having an unresolved symbol. +void* bls_l3_thread_entry( void* data_void ) { return NULL; } + +//#define PRINT_THRINFO + +void bls_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // Query the total number of threads from the rntm_t object. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_pba_rntm_set_pba( rntm ); + + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Query the thread's id from OpenMP. + const dim_t tid = omp_get_thread_num(); + + // Check for a somewhat obscure OpenMP thread-mistmatch issue. + // NOTE: This calls the same function used for the conventional/large + // code path. + bli_l3_thread_decorator_thread_check( n_threads, tid, gl_comm, rntm_p ); + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + thrinfo_t* thread = NULL; + + // Create the root node of the thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); + + func + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_openmp.h b/sandbox/gemmlike/thread/bls_l3_decor_openmp.h new file mode 100644 index 0000000000..9c956d7c36 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_openmp.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_OPENMP_H +#define BLIS_SBX_L3_DECOR_OPENMP_H + +// Definitions specific to situations when OpenMP multithreading is enabled. +#ifdef BLIS_ENABLE_OPENMP + +#endif + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_pthreads.c b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.c new file mode 100644 index 0000000000..ff723a4ce4 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.c @@ -0,0 +1,215 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_PTHREADS + +// A data structure to assist in passing operands to additional threads. +typedef struct thread_data +{ + l3sbxint_t func; + opid_t family; + obj_t* alpha; + obj_t* a; + obj_t* b; + obj_t* beta; + obj_t* c; + cntx_t* cntx; + rntm_t* rntm; + dim_t tid; + thrcomm_t* gl_comm; + array_t* array; +} thread_data_t; + +// Entry point function for additional threads. +void* bls_l3_thread_entry( void* data_void ) +{ + thread_data_t* data = data_void; + + l3sbxint_t func = data->func; + opid_t family = data->family; + obj_t* alpha = data->alpha; + obj_t* a = data->a; + obj_t* b = data->b; + obj_t* beta = data->beta; + obj_t* c = data->c; + cntx_t* cntx = data->cntx; + rntm_t* rntm = data->rntm; + dim_t tid = data->tid; + array_t* array = data->array; + thrcomm_t* gl_comm = data->gl_comm; + + ( void )family; + + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + thrinfo_t* thread = NULL; + + // Create the root node of the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); + + func + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); + + return NULL; +} + +void bls_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t r_val; + + // Query the total number of threads from the context. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_pba_rntm_set_pba( rntm ); + + // Allocate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + // Allocate an array of pthread objects and auxiliary data structs to pass + // to the thread entry functions. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_pthread_t* pthreads = bli_malloc_intl( sizeof( bli_pthread_t ) * n_threads, &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads, &r_val ); + + // NOTE: We must iterate backwards so that the chief thread (thread id 0) + // can spawn all other threads before proceeding with its own computation. + for ( dim_t tid = n_threads - 1; 0 <= tid; tid-- ) + { + // Set up thread data for additional threads (beyond thread 0). + datas[tid].func = func; + datas[tid].family = family; + datas[tid].alpha = alpha; + datas[tid].a = a; + datas[tid].b = b; + datas[tid].beta = beta; + datas[tid].c = c; + datas[tid].cntx = cntx; + datas[tid].rntm = rntm; + datas[tid].tid = tid; + datas[tid].gl_comm = gl_comm; + datas[tid].array = array; + + // Spawn additional threads for ids greater than 1. + if ( tid != 0 ) + bli_pthread_create( &pthreads[tid], NULL, &bls_l3_thread_entry, &datas[tid] ); + else + bls_l3_thread_entry( ( void* )(&datas[0]) ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Thread 0 waits for additional threads to finish. + for ( dim_t tid = 1; tid < n_threads; tid++ ) + { + bli_pthread_join( pthreads[tid], NULL ); + } + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( pthreads ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( datas ); +} + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_pthreads.h b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.h new file mode 100644 index 0000000000..ef5c3bad45 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.h @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_PTHREADS_H +#define BLIS_SBX_L3_DECOR_PTHREADS_H + +// Definitions specific to situations when POSIX multithreading is enabled. +#ifdef BLIS_ENABLE_PTHREADS + +// Thread entry point prototype. +void* bls_l3_thread_entry( void* data_void ); + +#endif + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_single.c b/sandbox/gemmlike/thread/bls_l3_decor_single.c new file mode 100644 index 0000000000..8bb04817fb --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_single.c @@ -0,0 +1,141 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifndef BLIS_ENABLE_MULTITHREADING + +#define SKIP_THRINFO_TREE + +void bls_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + //pack_t schema_a, + //pack_t schema_b, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // For sequential execution, we use only one thread. + const dim_t n_threads = 1; + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. + bli_pba_rntm_set_pba( rntm ); + +#ifndef SKIP_THRINFO_TREE + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); +#endif + + + { + // NOTE: We don't need to create another copy of the rntm_t since + // it was already copied in one of the high-level oapi functions. + rntm_t* restrict rntm_p = rntm; + + // There is only one thread id (for the thief thread). + const dim_t tid = 0; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + // NOTE: This is commented out because, in the single-threaded case, + // this is redundant since it's already been done above. + //bli_sba_rntm_set_pool( tid, array, rntm_p ); + +#ifndef SKIP_THRINFO_TREE + thrinfo_t* thread = NULL; + + // Create the root node of the thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); +#else + // This optimization allows us to use one of the global thrinfo_t + // objects for single-threaded execution rather than grow one from + // scratch. The key is that bli_thrinfo_sup_grow(), which is called + // from within the variants, will immediately return if it detects + // that the thrinfo_t* passed into it is either + // &BLIS_GEMM_SINGLE_THREADED or &BLIS_PACKM_SINGLE_THREADED. + thrinfo_t* thread = &BLIS_GEMM_SINGLE_THREADED; + + ( void )tid; +#endif + + func + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + +#ifndef SKIP_THRINFO_TREE + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); +#endif + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called above). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_single.h b/sandbox/gemmlike/thread/bls_l3_decor_single.h new file mode 100644 index 0000000000..211a43a894 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_single.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_SINGLE_H +#define BLIS_SBX_L3_DECOR_SINGLE_H + +// Definitions specific to situations when multithreading is disabled. +#ifndef BLIS_ENABLE_MULTITHREADING + +#endif + +#endif + diff --git a/sandbox/ref99/oapi/bli_gemmnat.c b/sandbox/old/ref99/bli_gemmnat.c similarity index 78% rename from sandbox/ref99/oapi/bli_gemmnat.c rename to sandbox/old/ref99/bli_gemmnat.c index 865c7cff42..399f31e216 100644 --- a/sandbox/ref99/oapi/bli_gemmnat.c +++ b/sandbox/old/ref99/bli_gemmnat.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -56,14 +56,19 @@ void bli_gemmnat { bli_init_once(); - // Obtain a valid native context from the gks if necessary. + // Obtain a valid (native) context from the gks if necessary. if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - // Initialize a local runtime object if necessary. + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_l; - if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } // Invoke the operation's front end. - blx_gemm_front( alpha, a, b, beta, c, cntx, rntm, NULL ); + //blx_gemm_front( alpha, a, b, beta, c, cntx, rntm, NULL ); + blx_gemm_ref_var2( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + BLIS_XXX, cntx, rntm, NULL ); } diff --git a/sandbox/ref99/include/bli_sandbox.h b/sandbox/old/ref99/bli_sandbox.h similarity index 100% rename from sandbox/ref99/include/bli_sandbox.h rename to sandbox/old/ref99/bli_sandbox.h diff --git a/sandbox/ref99/include/blix.h b/sandbox/old/ref99/blix.h similarity index 98% rename from sandbox/ref99/include/blix.h rename to sandbox/old/ref99/blix.h index 29f357cfd5..44e231d96d 100644 --- a/sandbox/ref99/include/blix.h +++ b/sandbox/old/ref99/blix.h @@ -39,7 +39,7 @@ // we #include any headers that would define prototypes or types that are // needed by the ref99 sandbox source code. -#include "blx_gemm.h" +#include "blx_gemm_ref_var2.h" #endif diff --git a/sandbox/old/ref99/blx_gemm_ref_var2.c b/sandbox/old/ref99/blx_gemm_ref_var2.c new file mode 100644 index 0000000000..b45d076356 --- /dev/null +++ b/sandbox/old/ref99/blx_gemm_ref_var2.c @@ -0,0 +1,361 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "blix.h" + +#define FUNCPTR_T gemmsup_fp + +typedef void (*FUNCPTR_T) + ( + bool packa, + bool packb, + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + stor3_t eff_id, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- var2 --------------------------------------------------------------------- +// + +static FUNCPTR_T GENARRAY(ftypes_var2,gemm_ref_var2); + +void blx_gemm_ref_var2 + ( + trans_t trans, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + stor3_t eff_id, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const num_t dt = bli_obj_dt( c ); + + const bool packa = bli_rntm_pack_a( rntm ); + const bool packb = bli_rntm_pack_b( rntm ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + dim_t k; + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + inc_t rs_a; + inc_t cs_a; + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + inc_t rs_b; + inc_t cs_b; + + if ( bli_obj_has_notrans( a ) ) + { + k = bli_obj_width( a ); + + rs_a = bli_obj_row_stride( a ); + cs_a = bli_obj_col_stride( a ); + } + else // if ( bli_obj_has_trans( a ) ) + { + // Assign the variables with an implicit transposition. + k = bli_obj_length( a ); + + rs_a = bli_obj_col_stride( a ); + cs_a = bli_obj_row_stride( a ); + } + + if ( bli_obj_has_notrans( b ) ) + { + rs_b = bli_obj_row_stride( b ); + cs_b = bli_obj_col_stride( b ); + } + else // if ( bli_obj_has_trans( b ) ) + { + // Assign the variables with an implicit transposition. + rs_b = bli_obj_col_stride( b ); + cs_b = bli_obj_row_stride( b ); + } + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + + // Index into the type combination array to extract the correct + // function pointer. + FUNCPTR_T f = ftypes_var2[dt]; + + if ( bli_is_notrans( trans ) ) + { + // Invoke the function. + f + ( + packa, + packb, + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + eff_id, + cntx, + rntm, + thread + ); + } + else + { + bli_abort(); + } +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + bool packa, \ + bool packb, \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t stor_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* If m or n is zero, return immediately. */ \ + if ( bli_zero_dim2( m, n ) ) return; \ +\ + /* If k < 1 or alpha is zero, scale by beta and return. */ \ + if ( k < 1 || PASTEMAC(ch,eq0)( *(( ctype* )alpha) ) ) \ + { \ + PASTEMAC(ch,scalm) \ + ( \ + BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m, n, \ + beta, \ + c, rs_c, cs_c \ + ); \ + return; \ + } \ +\ + /* Query the context for various blocksizes. NOTE: We query the + regular blocksizes since the sup blocksizes are not guaranteed + to have default values. */ \ + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ + const inc_t jrstep_b = cs_b * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ + const inc_t irstep_a = rs_a * MR; \ +\ + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemmsup_ker_ft) \ + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of beta and one scalars to prevent any unnecessary + sharing of cache lines between the cores' caches. */ \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ +\ + auxinfo_t aux; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = 0; jj < n; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= n - jj ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = 0; pp < k; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= k - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + /*bli_auxinfo_set_ps_b( ps_b_use, &aux );*/ \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = 0; ii < m; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= m - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + /*bli_auxinfo_set_ps_a( ps_a_use, &aux );*/ \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = 0; j < jr_iter; j += 1 ) \ + { \ + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc + j * jrstep_b; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Compute number of primary and leftover components of the IR loop. */ \ + const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + const dim_t ir_left = mc_cur % MR; \ +\ + /* Loop over the m dimension (MR columns at a time). */ \ + for ( dim_t i = 0; i < ir_iter; i += 1 ) \ + { \ + const dim_t mr_cur = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic + i * irstep_a; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + /* + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ + */ \ +\ + /* Invoke the kernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a, cs_a, \ + b_jr, rs_b, cs_b, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ + } \ + } \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemm_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); \ +*/ \ +} + +INSERT_GENTFUNC_BASIC0( gemm_ref_var2 ) + diff --git a/sandbox/old/ref99/blx_gemm_ref_var2.h b/sandbox/old/ref99/blx_gemm_ref_var2.h new file mode 100644 index 0000000000..e188f8b4b4 --- /dev/null +++ b/sandbox/old/ref99/blx_gemm_ref_var2.h @@ -0,0 +1,73 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +void blx_gemm_ref_var2 + ( + trans_t trans, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + stor3_t eff_id, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + bool packa, \ + bool packb, \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t stor_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ); + +INSERT_GENTPROT_BASIC0( gemm_ref_var2 ) + diff --git a/sandbox/ref99/base/blx_blksz.c b/sandbox/old/ref99/old/base/blx_blksz.c similarity index 100% rename from sandbox/ref99/base/blx_blksz.c rename to sandbox/old/ref99/old/base/blx_blksz.c diff --git a/sandbox/ref99/base/blx_blksz.h b/sandbox/old/ref99/old/base/blx_blksz.h similarity index 100% rename from sandbox/ref99/base/blx_blksz.h rename to sandbox/old/ref99/old/base/blx_blksz.h diff --git a/sandbox/ref99/blx_gemm.h b/sandbox/old/ref99/old/blx_gemm.h similarity index 100% rename from sandbox/ref99/blx_gemm.h rename to sandbox/old/ref99/old/blx_gemm.h diff --git a/sandbox/ref99/blx_gemm_front.c b/sandbox/old/ref99/old/blx_gemm_front.c similarity index 98% rename from sandbox/ref99/blx_gemm_front.c rename to sandbox/old/ref99/old/blx_gemm_front.c index bb6ba4a8d1..399f750a5c 100644 --- a/sandbox/ref99/blx_gemm_front.c +++ b/sandbox/old/ref99/old/blx_gemm_front.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/sandbox/ref99/blx_gemm_front.h b/sandbox/old/ref99/old/blx_gemm_front.h similarity index 100% rename from sandbox/ref99/blx_gemm_front.h rename to sandbox/old/ref99/old/blx_gemm_front.h diff --git a/sandbox/ref99/blx_gemm_int.c b/sandbox/old/ref99/old/blx_gemm_int.c similarity index 97% rename from sandbox/ref99/blx_gemm_int.c rename to sandbox/old/ref99/old/blx_gemm_int.c index c807fc76e7..525f72d5d9 100644 --- a/sandbox/ref99/blx_gemm_int.c +++ b/sandbox/old/ref99/old/blx_gemm_int.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/sandbox/ref99/blx_gemm_int.h b/sandbox/old/ref99/old/blx_gemm_int.h similarity index 100% rename from sandbox/ref99/blx_gemm_int.h rename to sandbox/old/ref99/old/blx_gemm_int.h diff --git a/sandbox/ref99/cntl/blx_gemm_cntl.c b/sandbox/old/ref99/old/cntl/blx_gemm_cntl.c similarity index 85% rename from sandbox/ref99/cntl/blx_gemm_cntl.c rename to sandbox/old/ref99/old/cntl/blx_gemm_cntl.c index c40410c6fc..d7b4c69495 100644 --- a/sandbox/ref99/cntl/blx_gemm_cntl.c +++ b/sandbox/old/ref99/old/cntl/blx_gemm_cntl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,26 +38,28 @@ cntl_t* blx_gemm_cntl_create ( - opid_t family, - pack_t schema_a, - pack_t schema_b + rntm_t* rntm, + opid_t family, + pack_t schema_a, + pack_t schema_b ) { - return blx_gemmbp_cntl_create( family, schema_a, schema_b ); + return blx_gemmbp_cntl_create( rntm, family, schema_a, schema_b ); } // ----------------------------------------------------------------------------- cntl_t* blx_gemmbp_cntl_create ( - opid_t family, - pack_t schema_a, - pack_t schema_b + rntm_t* rntm, + opid_t family, + pack_t schema_a, + pack_t schema_b ) { - void* macro_kernel_fp; - void* packa_fp; - void* packb_fp; + void_fp macro_kernel_fp; + void_fp packa_fp; + void_fp packb_fp; macro_kernel_fp = blx_gemm_ker_var2; @@ -67,6 +69,7 @@ cntl_t* blx_gemmbp_cntl_create // Create two nodes for the macro-kernel. cntl_t* gemm_cntl_bu_ke = blx_gemm_cntl_create_node ( + rntm, // the thread's runtime structure family, // the operation family BLIS_MR, // needed for bli_thrinfo_rgrow() NULL, // variant function pointer not used @@ -75,6 +78,7 @@ cntl_t* blx_gemmbp_cntl_create cntl_t* gemm_cntl_bp_bu = blx_gemm_cntl_create_node ( + rntm, // the thread's runtime structure family, BLIS_NR, // not used by macro-kernel, but needed for bli_thrinfo_rgrow() macro_kernel_fp, @@ -84,6 +88,7 @@ cntl_t* blx_gemmbp_cntl_create // Create a node for packing matrix A. cntl_t* gemm_cntl_packa = blx_packm_cntl_create_node ( + rntm, blx_gemm_packa, // pack the left-hand operand packa_fp, BLIS_MR, @@ -99,6 +104,7 @@ cntl_t* blx_gemmbp_cntl_create // Create a node for partitioning the m dimension by MC. cntl_t* gemm_cntl_op_bp = blx_gemm_cntl_create_node ( + rntm, family, BLIS_MC, blx_gemm_blk_var1, @@ -108,6 +114,7 @@ cntl_t* blx_gemmbp_cntl_create // Create a node for packing matrix B. cntl_t* gemm_cntl_packb = blx_packm_cntl_create_node ( + rntm, blx_gemm_packb, // pack the right-hand operand packb_fp, BLIS_KR, @@ -123,6 +130,7 @@ cntl_t* blx_gemmbp_cntl_create // Create a node for partitioning the k dimension by KC. cntl_t* gemm_cntl_mm_op = blx_gemm_cntl_create_node ( + rntm, family, BLIS_KC, blx_gemm_blk_var3, @@ -132,6 +140,7 @@ cntl_t* blx_gemmbp_cntl_create // Create a node for partitioning the n dimension by NC. cntl_t* gemm_cntl_vl_mm = blx_gemm_cntl_create_node ( + rntm, family, BLIS_NC, blx_gemm_blk_var2, @@ -145,23 +154,25 @@ cntl_t* blx_gemmbp_cntl_create void blx_gemm_cntl_free ( - cntl_t* cntl, + rntm_t* rntm, + cntl_t* cntl, thrinfo_t* thread ) { - bli_cntl_free( cntl, thread ); + bli_cntl_free( rntm, cntl, thread ); } // ----------------------------------------------------------------------------- cntl_t* blx_gemm_cntl_create_node ( + rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, cntl_t* sub_node ) { - return bli_cntl_create_node( family, bszid, var_func, NULL, sub_node ); + return bli_cntl_create_node( rntm, family, bszid, var_func, NULL, sub_node ); } diff --git a/sandbox/ref99/cntl/blx_gemm_cntl.h b/sandbox/old/ref99/old/cntl/blx_gemm_cntl.h similarity index 88% rename from sandbox/ref99/cntl/blx_gemm_cntl.h rename to sandbox/old/ref99/old/cntl/blx_gemm_cntl.h index 80c26b8ac2..94d872ae5a 100644 --- a/sandbox/ref99/cntl/blx_gemm_cntl.h +++ b/sandbox/old/ref99/old/cntl/blx_gemm_cntl.h @@ -34,25 +34,28 @@ cntl_t* blx_gemm_cntl_create ( - opid_t family, - pack_t schema_a, - pack_t schema_b + rntm_t* rntm, + opid_t family, + pack_t schema_a, + pack_t schema_b ); // ----------------------------------------------------------------------------- cntl_t* blx_gemmbp_cntl_create ( - opid_t family, - pack_t schema_a, - pack_t schema_b + rntm_t* rntm, + opid_t family, + pack_t schema_a, + pack_t schema_b ); // ----------------------------------------------------------------------------- void blx_gemm_cntl_free ( - cntl_t* cntl, + rntm_t* rntm, + cntl_t* cntl, thrinfo_t* thread ); @@ -60,9 +63,10 @@ void blx_gemm_cntl_free cntl_t* blx_gemm_cntl_create_node ( + rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, cntl_t* sub_node ); diff --git a/sandbox/ref99/cntl/blx_l3_cntl_if.c b/sandbox/old/ref99/old/cntl/blx_l3_cntl_if.c similarity index 69% rename from sandbox/ref99/cntl/blx_l3_cntl_if.c rename to sandbox/old/ref99/old/cntl/blx_l3_cntl_if.c index e961f088eb..cd2c1dab14 100644 --- a/sandbox/ref99/cntl/blx_l3_cntl_if.c +++ b/sandbox/old/ref99/old/cntl/blx_l3_cntl_if.c @@ -39,30 +39,16 @@ void blx_l3_cntl_create_if ( opid_t family, + pack_t schema_a, + pack_t schema_b, obj_t* a, obj_t* b, obj_t* c, + rntm_t* rntm, cntl_t* cntl_orig, cntl_t** cntl_use ) { - // This is part of a hack to support mixed domain in bli_gemm_front(). - // Sometimes we need to specify a non-standard schema for A and B, and - // we decided to transmit them via the schema field in the obj_t's - // rather than pass them in as function parameters. Once the values - // have been read, we immediately reset them back to their expected - // values for unpacked objects. Notice that we do this even if the - // caller passed in a custom control tree; that's because we still need - // to reset the pack schema of a and b, which were modified by the - // operation's _front() function. However, in order for this to work, - // the level-3 thread entry function (or omp parallel region) must - // alias thread-local copies of objects a and b. - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - bli_obj_set_pack_schema( BLIS_NOT_PACKED, a ); - bli_obj_set_pack_schema( BLIS_NOT_PACKED, b ); - // If the control tree pointer is NULL, we construct a default // tree as a function of the operation family. if ( cntl_orig == NULL ) @@ -74,7 +60,7 @@ void blx_l3_cntl_create_if // If the user provided a control tree, create a copy and use it // instead (so that threads can use its local tree as a place to // cache things like pack mem_t entries). - *cntl_use = bli_cntl_copy( cntl_orig ); + *cntl_use = bli_cntl_copy( rntm, cntl_orig ); // Recursively set the family fields of the newly copied control tree // nodes. @@ -82,13 +68,10 @@ void blx_l3_cntl_create_if } } -void blx_l3_cntl_free_if +void blx_l3_cntl_free ( - obj_t* a, - obj_t* b, - obj_t* c, - cntl_t* cntl_orig, - cntl_t* cntl_use, + rntm_t rntm, + cntl_t* cntl_use, thrinfo_t* thread ) { @@ -96,13 +79,13 @@ void blx_l3_cntl_free_if // been created, so we now must free it. if ( cntl_orig == NULL ) { - blx_gemm_cntl_free( cntl_use, thread ); + blx_gemm_cntl_free( rntm, cntl_use, thread ); } else { // If the user provided a control tree, free the copy of it that // was created. - bli_cntl_free( cntl_use, thread ); + bli_cntl_free( rntm, cntl_use ); } } diff --git a/sandbox/ref99/cntl/blx_l3_cntl_if.h b/sandbox/old/ref99/old/cntl/blx_l3_cntl_if.h similarity index 92% rename from sandbox/ref99/cntl/blx_l3_cntl_if.h rename to sandbox/old/ref99/old/cntl/blx_l3_cntl_if.h index 87c29e3029..7d878574b4 100644 --- a/sandbox/ref99/cntl/blx_l3_cntl_if.h +++ b/sandbox/old/ref99/old/cntl/blx_l3_cntl_if.h @@ -35,20 +35,19 @@ void blx_l3_cntl_create_if ( opid_t family, + pack_t schema_a, + pack_t schema_b, obj_t* a, obj_t* b, obj_t* c, + rntm_t* rntm, cntl_t* cntl_orig, cntl_t** cntl_use ); -void blx_l3_cntl_free_if +void blx_l3_cntl_free ( - obj_t* a, - obj_t* b, - obj_t* c, - cntl_t* cntl_orig, - cntl_t* cntl_use, + rntm_t rntm, + cntl_t* cntl_use, thrinfo_t* thread ); - diff --git a/sandbox/ref99/cntl/blx_packm_cntl.c b/sandbox/old/ref99/old/cntl/blx_packm_cntl.c similarity index 97% rename from sandbox/ref99/cntl/blx_packm_cntl.c rename to sandbox/old/ref99/old/cntl/blx_packm_cntl.c index 85a7c85781..bf7f003a38 100644 --- a/sandbox/ref99/cntl/blx_packm_cntl.c +++ b/sandbox/old/ref99/old/cntl/blx_packm_cntl.c @@ -36,8 +36,9 @@ cntl_t* blx_packm_cntl_create_node ( - void* var_func, - void* packm_var_func, + rntm_t* rntm, + void_fp var_func, + void_fp packm_var_func, bszid_t bmid_m, bszid_t bmid_n, bool_t does_invert_diag, diff --git a/sandbox/ref99/cntl/blx_packm_cntl.h b/sandbox/old/ref99/old/cntl/blx_packm_cntl.h similarity index 96% rename from sandbox/ref99/cntl/blx_packm_cntl.h rename to sandbox/old/ref99/old/cntl/blx_packm_cntl.h index fbba97e1c4..22eb0497cc 100644 --- a/sandbox/ref99/cntl/blx_packm_cntl.h +++ b/sandbox/old/ref99/old/cntl/blx_packm_cntl.h @@ -34,8 +34,9 @@ cntl_t* blx_packm_cntl_create_node ( - void* var_func, - void* packm_var_func, + rntm_t* rntm, + void_fp var_func, + void_fp packm_var_func, bszid_t bmid_m, bszid_t bmid_n, bool_t does_invert_diag, diff --git a/sandbox/ref99/packm/blx_l3_packm.c b/sandbox/old/ref99/old/packm/blx_l3_packm.c similarity index 93% rename from sandbox/ref99/packm/blx_l3_packm.c rename to sandbox/old/ref99/old/packm/blx_l3_packm.c index 16df18c3ca..982e2d9631 100644 --- a/sandbox/ref99/packm/blx_l3_packm.c +++ b/sandbox/old/ref99/old/packm/blx_l3_packm.c @@ -45,13 +45,13 @@ void blx_l3_packm thrinfo_t* thread ) { - membrk_t* membrk; + pba_t* pba; packbuf_t pack_buf_type; mem_t* cntl_mem_p; siz_t size_needed; // FGVZ: Not sure why we need this barrier, but we do. - bli_thread_obarrier( thread ); + bli_thread_barrier( thread ); // Every thread initializes x_pack and determines the size of memory // block needed (which gets embedded into the otherwise "blank" mem_t @@ -71,7 +71,7 @@ void blx_l3_packm if ( size_needed == 0 ) return; // Query the memory broker from the context. - membrk = bli_cntx_get_membrk( cntx ); + pba = bli_cntx_get_pba( cntx ); // Query the pack buffer type from the control tree node. pack_buf_type = bli_cntl_packm_params_pack_buf_type( cntl ); @@ -91,9 +91,9 @@ void blx_l3_packm { // The chief thread acquires a block from the memory broker // and saves the associated mem_t entry to local_mem_s. - bli_membrk_acquire_m + bli_pba_acquire_m ( - membrk, + pba, size_needed, pack_buf_type, &local_mem_s @@ -102,7 +102,7 @@ void blx_l3_packm // Broadcast the address of the chief thread's local mem_t entry to // all threads. - local_mem_p = bli_thread_obroadcast( thread, &local_mem_s ); + local_mem_p = bli_thread_broadcast( thread, &local_mem_s ); // Save the contents of the chief thread's local mem_t entry to the // mem_t field in this thread's control tree node. @@ -130,10 +130,10 @@ void blx_l3_packm // The chief thread releases the existing block associated with // the mem_t entry in the control tree, and then re-acquires a // new block, saving the associated mem_t entry to local_mem_s. - bli_membrk_release( cntl_mem_p ); - bli_membrk_acquire_m + bli_pba_release( cntl_mem_p ); + bli_pba_acquire_m ( - membrk, + pba, size_needed, pack_buf_type, &local_mem_s @@ -142,7 +142,7 @@ void blx_l3_packm // Broadcast the address of the chief thread's local mem_t entry to // all threads. - local_mem_p = bli_thread_obroadcast( thread, &local_mem_s ); + local_mem_p = bli_thread_broadcast( thread, &local_mem_s ); // Save the chief thread's local mem_t entry to the mem_t field in // this thread's control tree node. @@ -155,7 +155,7 @@ void blx_l3_packm // will already have the cached values in their local control // trees' mem_t entries, currently pointed to by cntl_mem_p. - bli_thread_obarrier( thread ); + bli_thread_barrier( thread ); } } @@ -178,6 +178,6 @@ void blx_l3_packm ); // Barrier so that packing is done before computation. - bli_thread_obarrier( thread ); + bli_thread_barrier( thread ); } diff --git a/sandbox/ref99/packm/blx_l3_packm.h b/sandbox/old/ref99/old/packm/blx_l3_packm.h similarity index 100% rename from sandbox/ref99/packm/blx_l3_packm.h rename to sandbox/old/ref99/old/packm/blx_l3_packm.h diff --git a/sandbox/ref99/thread/blx_gemm_thread.c b/sandbox/old/ref99/old/thread/blx_gemm_thread.c similarity index 54% rename from sandbox/ref99/thread/blx_gemm_thread.c rename to sandbox/old/ref99/old/thread/blx_gemm_thread.c index 97123c0ec4..b5657aa4f2 100644 --- a/sandbox/ref99/thread/blx_gemm_thread.c +++ b/sandbox/old/ref99/old/thread/blx_gemm_thread.c @@ -38,6 +38,7 @@ // This code is enabled only when multithreading is enabled via OpenMP. #ifdef BLIS_ENABLE_OPENMP +#if 0 void blx_gemm_thread ( gemmint_t func, @@ -101,6 +102,129 @@ void blx_gemm_thread // by the global communicator's chief thread in bli_l3_thrinfo_free() // (called above). } +#endif +void blx_gemm_thread + ( + gemmint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl + ) +{ + // This is part of a hack to support mixed domain in bli_gemm_front(). + // Sometimes we need to specify a non-standard schema for A and B, and + // we decided to transmit them via the schema field in the obj_t's + // rather than pass them in as function parameters. Once the values + // have been read, we immediately reset them back to their expected + // values for unpacked objects. + pack_t schema_a = bli_obj_pack_schema( a ); + pack_t schema_b = bli_obj_pack_schema( b ); + bli_obj_set_pack_schema( BLIS_NOT_PACKED, a ); + bli_obj_set_pack_schema( BLIS_NOT_PACKED, b ); + + // Query the total number of threads from the rntm_t object. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_pba_rntm_set_pba( rntm ); + + // Allocate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Query the thread's id from OpenMP. + const dim_t tid = omp_get_thread_num(); + + // Check for a somewhat obscure OpenMP thread-mistmatch issue. + //bli_l3_thread_decorator_thread_check( n_threads, tid, gl_comm, rntm_p ); + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + + obj_t a_t, b_t, c_t; + cntl_t* cntl_use; + thrinfo_t* thread; + + // Alias thread-local copies of A, B, and C. These will be the objects + // we pass down the algorithmic function stack. Making thread-local + // aliases is highly recommended in case a thread needs to change any + // of the properties of an object without affecting other threads' + // objects. + bli_obj_alias_to( a, &a_t ); + bli_obj_alias_to( b, &b_t ); + bli_obj_alias_to( c, &c_t ); + + // Create a default control tree for the operation, if needed. + blx_l3_cntl_create_if( family, schema_a, schema_b, + &a_t, &b_t, &c_t, rntm_p, cntl, &cntl_use ); + + // Create the root node of the current thread's thrinfo_t structure. + blx_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread ); + + func + ( + alpha, + &a_t, + &b_t, + beta, + &c_t, + cntx, + rntm_p, + cntl_use, + thread + ); + + // Free the thread's local control tree. + blx_l3_cntl_free( rntm_p, cntl_use, thread ); + + // Free the current thread's thrinfo_t structure. + bli_l3_thrinfo_free( rntm_p, thread ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called above). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + + #endif diff --git a/sandbox/ref99/thread/blx_gemm_thread.h b/sandbox/old/ref99/old/thread/blx_gemm_thread.h similarity index 97% rename from sandbox/ref99/thread/blx_gemm_thread.h rename to sandbox/old/ref99/old/thread/blx_gemm_thread.h index 718735f4b2..0667626dfe 100644 --- a/sandbox/ref99/thread/blx_gemm_thread.h +++ b/sandbox/old/ref99/old/thread/blx_gemm_thread.h @@ -35,8 +35,10 @@ // gemm internal function type typedef void (*gemmint_t) ( + obj_t* alpha, obj_t* a, obj_t* b, + obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm, diff --git a/sandbox/ref99/vars/blx_gemm_blk_var1.c b/sandbox/old/ref99/old/vars/blx_gemm_blk_var1.c similarity index 97% rename from sandbox/ref99/vars/blx_gemm_blk_var1.c rename to sandbox/old/ref99/old/vars/blx_gemm_blk_var1.c index ef8c07b1d5..dc41b97fff 100644 --- a/sandbox/ref99/vars/blx_gemm_blk_var1.c +++ b/sandbox/old/ref99/old/vars/blx_gemm_blk_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/sandbox/ref99/vars/blx_gemm_blk_var2.c b/sandbox/old/ref99/old/vars/blx_gemm_blk_var2.c similarity index 97% rename from sandbox/ref99/vars/blx_gemm_blk_var2.c rename to sandbox/old/ref99/old/vars/blx_gemm_blk_var2.c index f272952b01..d7d128c358 100644 --- a/sandbox/ref99/vars/blx_gemm_blk_var2.c +++ b/sandbox/old/ref99/old/vars/blx_gemm_blk_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/sandbox/ref99/vars/blx_gemm_blk_var3.c b/sandbox/old/ref99/old/vars/blx_gemm_blk_var3.c similarity index 98% rename from sandbox/ref99/vars/blx_gemm_blk_var3.c rename to sandbox/old/ref99/old/vars/blx_gemm_blk_var3.c index 7eace4af8e..6e87862682 100644 --- a/sandbox/ref99/vars/blx_gemm_blk_var3.c +++ b/sandbox/old/ref99/old/vars/blx_gemm_blk_var3.c @@ -73,7 +73,7 @@ void blx_gemm_blk_var3 bli_thrinfo_sub_node( thread ) ); - bli_thread_obarrier( bli_thrinfo_sub_node( thread ) ); + bli_thread_barrier( bli_thrinfo_sub_node( thread ) ); // This variant executes multiple rank-k updates. Therefore, if the // internal beta scalar on matrix C is non-zero, we must use it diff --git a/sandbox/ref99/vars/blx_gemm_ker_var2.c b/sandbox/old/ref99/old/vars/blx_gemm_ker_var2.c similarity index 99% rename from sandbox/ref99/vars/blx_gemm_ker_var2.c rename to sandbox/old/ref99/old/vars/blx_gemm_ker_var2.c index 61842411ab..10c6b81ada 100644 --- a/sandbox/ref99/vars/blx_gemm_ker_var2.c +++ b/sandbox/old/ref99/old/vars/blx_gemm_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/sandbox/ref99/vars/blx_gemm_packab.c b/sandbox/old/ref99/old/vars/blx_gemm_packab.c similarity index 100% rename from sandbox/ref99/vars/blx_gemm_packab.c rename to sandbox/old/ref99/old/vars/blx_gemm_packab.c diff --git a/sandbox/ref99/vars/blx_gemm_var.h b/sandbox/old/ref99/old/vars/blx_gemm_var.h similarity index 97% rename from sandbox/ref99/vars/blx_gemm_var.h rename to sandbox/old/ref99/old/vars/blx_gemm_var.h index 32b975d1a3..a2a3de9bbf 100644 --- a/sandbox/ref99/vars/blx_gemm_var.h +++ b/sandbox/old/ref99/old/vars/blx_gemm_var.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/sandbox/ref99/vars/other/blx_gemm_ker_var2rr.c b/sandbox/old/ref99/old/vars/other/blx_gemm_ker_var2rr.c similarity index 99% rename from sandbox/ref99/vars/other/blx_gemm_ker_var2rr.c rename to sandbox/old/ref99/old/vars/other/blx_gemm_ker_var2rr.c index 8a5c1d1564..7cbd402e05 100644 --- a/sandbox/ref99/vars/other/blx_gemm_ker_var2rr.c +++ b/sandbox/old/ref99/old/vars/other/blx_gemm_ker_var2rr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/sandbox/ref99/vars/other/blx_gemm_ker_var2sl.c b/sandbox/old/ref99/old/vars/other/blx_gemm_ker_var2sl.c similarity index 99% rename from sandbox/ref99/vars/other/blx_gemm_ker_var2sl.c rename to sandbox/old/ref99/old/vars/other/blx_gemm_ker_var2sl.c index 4b0523e37c..2d46886b76 100644 --- a/sandbox/ref99/vars/other/blx_gemm_ker_var2sl.c +++ b/sandbox/old/ref99/old/vars/other/blx_gemm_ker_var2sl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/sandbox/power10/POWER10.md b/sandbox/power10/POWER10.md new file mode 100644 index 0000000000..cdfb09e7d2 --- /dev/null +++ b/sandbox/power10/POWER10.md @@ -0,0 +1,67 @@ +### Low Precision POWER10 Kernels + +This is a special BLIS Sandbox that allows users to call POWER10 reduced precision/integer `GEMM` kernels. + +Supported kernels: `IEEE float16 (bli_shgemm), bfloat16 (bli_sbgemm), int16 (bli_i16gemm), int8 (bli_i8gemm), int4 (bli_i4gemm)`. + +#### Introduction + +This document describes how the low precision POWER10 `gemm` kernels are implemented and explains how to call the POWER10 `GEMM` kernels. + +**Important: These kernels does not have the full functionality of BLIS. The kernels can only perform single threaded, no transpose, GEMM.** + +#### Implementation + +The kernels are implemented in `gemm.c`. They are instantiated with macro templates. The main template is called `GENERIC_GEMM`. This template is used to create the 5-loop `gemm` function. + +#### Reduced precision/integer Types + +| BLIS type | BLIS char | Type definition | Used to represent... | +|:-----------|:----------|:---------------------------------------|:-------------------------------------| +| `float16` | `h` | `typedef union { uint16_t v; struct { uint16_t m:10; uint16_t e:5; uint16_t s:1} bits; }` | IEEE half-precision real numbers | +| `bfloat16` | `b` | `typedef union { uint16_t v; struct { uint16_t m:7; uint16_t e:8; uint16_t s:1; } bits; }` | Google's half-precision real numbers | +| `int16` | `i16` | `int16_t` | 16 bit integers | +| `int8` | `i8` | `int8_t` | 8 bit integers | +| `int4` | `i4` | `typedef union{ uint8_t v; struct { uint8_t nib1:4; uint8_t nib2:4; } bits; }` | 4 bit integers | + +#### Reduced Precision/Integer API + +The API that is used for the reduced precision/integer POWER10 `GEMM` kernels is similar to the existing [BLIS basic typed API](https://github.com/flame/blis/blob/master/docs/BLISTypedAPI.md). The main difference is the POWER10 kernels expect two types: `ctype_in` and `ctype_out`. + +Thus the new `gemm` call looks like the following: + +``` +void bli_??gemm + ( + trans_t transa, + trans_t transb, + dim_t m, + dim_t n, + dim_t k, + ctype_out* alpha, + ctype_in* a, inc_t rsa, inc_t csa, + ctype_in* b, inc_t rsb, inc_t csb, + ctype_out* beta, + ctype_out* c, inc_t rsc, inc_t csc + ); +``` + +`??` is meant to replaced with the kernel prefix. + +#### How To Build The Sandbox + +Add the following flags when running the configure script to build BLIS correctly. + +`CFLAGS="-fPIC -std=c99 -D_ISOC11_SOURCE -D_POSIX_C_SOURCE=200112L" -s power10` + +Ensure that you have GCC 10.2 or greater. + + +#### P10 Testsuite + +In `p10_testsuite`, there are performance gathering and correctness checking programs for the POWER10 reduced precision/integer `GEMM` kernels. By default, the performance gathering and correctness checking is done over square matrices ranging from 80 to 4000 in increments of 80. Performance is measured in GFLOPs, and correctness is measured using the BLIS method (detailed in `blis/testsuite/test_gemm.c`). + +#### References + +* [bfloat16 wiki](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) +* [IEEE float16 wiki](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) \ No newline at end of file diff --git a/sandbox/power10/bli_gemm_ex.c b/sandbox/power10/bli_gemm_ex.c new file mode 100644 index 0000000000..3334dc4a53 --- /dev/null +++ b/sandbox/power10/bli_gemm_ex.c @@ -0,0 +1,79 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// Given the current architecture of BLIS sandboxes, bli_gemm_ex() is the +// entry point to any sandbox implementation. + +// NOTE: This function is implemented functionally identically to the +// function that it overrides in frame/3/bli_l3_oapi_ex.c. This means that +// we are forgoing the option of customizing the implementations that +// underlie bli_gemm() and bli_?gemm() (which both call bli_gemm_ex()). +// Any new code defined in this sandbox directory, however, will be +// included in the BLIS. + +#include "blis.h" + +void bli_gemm_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Obtain a valid (native) context from the gks if necessary. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_gemm_check( alpha, a, b, beta, c, cntx ); + + // Invoke the operation's front end. + bli_gemm_front + ( + alpha, a, b, beta, c, cntx, rntm, NULL + ); +} + diff --git a/sandbox/power10/bli_sandbox.h b/sandbox/power10/bli_sandbox.h new file mode 100644 index 0000000000..22d293d130 --- /dev/null +++ b/sandbox/power10/bli_sandbox.h @@ -0,0 +1,102 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of copyright holder(s) nor the names + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SANDBOX_H +#define BLIS_SANDBOX_H + +#include "blis.h" +#include "gemm_prototypes.h" + +// NOTE: This header is the only header required to be present in the sandbox +// implementation directory. + +// int4 +typedef union +{ + uint8_t v; + struct + { + uint8_t nib1:4; + uint8_t nib2:4; + } bits; +} nibbles; + +// brain float16 +typedef union +{ + uint16_t v; + struct + { + uint16_t m:7; + uint16_t e:8; + uint16_t s:1; + } bits; +} bfloat16; + +// ieee float16 +typedef union +{ + uint16_t v; + struct + { + uint16_t m:10; + uint16_t e:5; + uint16_t s:1; + } bits; +} float16; + +#define P10_PG_SIZE 4096 + +// microkernel prototypes +GEMM_UKR_PROT2( bfloat16, float, sb, gemm_power10_mma_8x16 ) +GEMM_UKR_PROT2( float16, float, sh, gemm_power10_mma_8x16 ) +GEMM_UKR_PROT2( int16_t, int32_t, i16, gemm_power10_mma_8x16 ) +GEMM_UKR_PROT2( int8_t, int32_t, i8, gemm_power10_mma_8x16 ) +GEMM_UKR_PROT2( nibbles, int32_t, i4, gemm_power10_mma_8x16 ) + +// gemm kernel prototypes +GEMM_FUNC_PROT( float16, float, sh); +GEMM_FUNC_PROT( bfloat16, float, sb); +GEMM_FUNC_PROT( int16_t, int32_t, i16); +GEMM_FUNC_PROT( int8_t, int32_t, i8); +GEMM_FUNC_PROT( nibbles, int32_t, i4); + +// pack kernel prototypes +PACK_MACRO_PROTO(sb, bfloat16) +PACK_MACRO_PROTO(sh, float16) +PACK_MACRO_PROTO(i16, int16_t) +PACK_MACRO_PROTO(i8, int8_t) +PACK_MACRO_PROTO(i4, nibbles) + +#endif diff --git a/sandbox/power10/gemm.c b/sandbox/power10/gemm.c new file mode 100644 index 0000000000..7b5983ef91 --- /dev/null +++ b/sandbox/power10/gemm.c @@ -0,0 +1,128 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "gemm_template.h" +#include "bli_sandbox.h" + + +GENERIC_GEMM( + sb, // kernel name prefix + bfloat16, // input type + float, // output type + (pb/2 + pb%2), // innermost loop iterations + sb_pack_a, + sb_pack_b, // pack kernel for B + bli_sbgemm_power10_mma_8x16, // microkernel function name + 2, // K_MMA + 8, // MR + 16, // NR + 384, // MC + 3328, // KC + 4096, // NC + 0, // A_ALIGN + 0 // B_ALIGN +); + +GENERIC_GEMM( + sh, // kernel name prefix + float16, // input type + float, // output type + (pb/2 + pb%2), // innermost loop iterations + sh_pack_a, // pack kernel for A + sh_pack_b, // pack kernel for B + bli_shgemm_power10_mma_8x16, // microkernel function name + 2, // K_MMA + 8, // MR + 16, // NR + 384, // MC + 3328, // KC + 4096, // NC + 0, // A_ALIGN + 0 // B_ALIGN +); + +GENERIC_GEMM( + i16, // kernel name prefix + int16_t, // input type + int, // output type + (pb/2 + pb%2), // innermost loop iterations + i16_pack_a, // pack kernel for A + i16_pack_b, // pack kernel for B + bli_i16gemm_power10_mma_8x16, // microkernel function name + 2, // K_MMA + 8, // MR + 16, // NR + 384, // MC + 3328, // KC + 4096, // NC + 0, // A_ALIGN + 0 // B_ALIGN +); + +GENERIC_GEMM( + i8, // kernel name prefix + int8_t, // input type + int, // output type + (pb/4 + (pb%4>0)), // innermost loop iterations + i8_pack_a, // pack kernel for A + i8_pack_b, // pack kernel for B + bli_i8gemm_power10_mma_8x16, // microkernel function name + 4, // K_MMA + 8, // MR + 16, // NR + 384, // MC + 6656, // KC + 4096, // NC + 0, // A_ALIGN + 0 // B_ALIGN +); + +GENERIC_GEMM( + i4, // kernel name prefix + nibbles, // input type + int, // output type + (pb/8 + (pb%8>0)), // innermost loop iterations + i4_pack_a, // pack kernel for A + i4_pack_b, // pack kernel for B + bli_i4gemm_power10_mma_8x16, // microkernel function name + 8, // K_MMA + 8, // MR + 16, // NR + 384, // MC + 6656, // KC + 4096, // NC + 0, // A_ALIGN + 0 // B_ALIGN +); + diff --git a/sandbox/power10/gemm_prototypes.h b/sandbox/power10/gemm_prototypes.h new file mode 100644 index 0000000000..f1bdac57b6 --- /dev/null +++ b/sandbox/power10/gemm_prototypes.h @@ -0,0 +1,76 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// BLIS GEMM function naming scheme +#define GEMM_FUNC_NAME_(ch) bli_ ## ch ## gemm +#define GEMM_FUNC_NAME(ch) GEMM_FUNC_NAME_(ch) + +// BLIS GEMM function prototype macro +#define GEMM_FUNC_PROT(DTYPE_IN, DTYPE_OUT, ch) \ + void GEMM_FUNC_NAME(ch) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + DTYPE_OUT* alpha, \ + DTYPE_IN* a, inc_t rsa, inc_t csa, \ + DTYPE_IN* b, inc_t rsb, inc_t csb, \ + DTYPE_OUT* beta, \ + DTYPE_OUT* c, inc_t rsc, inc_t csc \ + ) + +// Pack routine naming scheme +#define PACK_FUNC_NAME_(ch, mat) ch ## _pack_ ## mat +#define PACK_FUNC_NAME(ch, mat) PACK_FUNC_NAME_(ch, mat) + +// Pack routine prototype +#define PACK_MACRO_PROTO(ch, DTYPE_IN) \ +\ +void PACK_FUNC_NAME(ch, a) \ + ( \ + dim_t MR, \ + int m, int k, \ + DTYPE_IN* ap, int rs_a, int cs_a, \ + DTYPE_IN* apack \ + ); \ +\ +void PACK_FUNC_NAME(ch, b) \ + ( \ + dim_t NR, \ + int k, int n, \ + DTYPE_IN* bp, int rs_b, int cs_b, \ + DTYPE_IN* bpack \ + ); diff --git a/sandbox/power10/gemm_template.h b/sandbox/power10/gemm_template.h new file mode 100644 index 0000000000..eb0ef24bbb --- /dev/null +++ b/sandbox/power10/gemm_template.h @@ -0,0 +1,166 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +/* + Macro function template for creating BLIS GEMM kernels using the Goto method. + + This GEMM template assumes that the matrices are both not transposed. + + ch - kernel name prefix + DTYPE_IN, DTYPE_OUT - datatypes of the input and output operands respectively + NEW_PB - number of iterations of the innermost loop + PACK_A, PACK_B - pack kernels names + MICROKERNEL - microkernel function name + K_MMA - number of outer products performed by an instruction + MR, NR, MC, KC, NC - Cache blocking parameters + B_ALIGN, A_ALIGN - Extra byte alignment for the pack matrix buffers +*/ +#define GENERIC_GEMM( \ + ch, \ + DTYPE_IN, \ + DTYPE_OUT, \ + NEW_PB, \ + PACK_A, \ + PACK_B, \ + MICROKERNEL, \ + K_MMA, \ + MR, \ + NR, \ + MC, \ + KC, \ + NC, \ + B_ALIGN, \ + A_ALIGN \ +) \ +\ +void GEMM_FUNC_NAME(ch) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + DTYPE_OUT* alpha, \ + DTYPE_IN* a, inc_t rsa, inc_t csa, \ + DTYPE_IN* b, inc_t rsb, inc_t csb, \ + DTYPE_OUT* beta, \ + DTYPE_OUT* c, inc_t rsc, inc_t csc \ + ) \ +{ \ + DTYPE_IN * restrict btilde_sys = ( DTYPE_IN *) aligned_alloc( P10_PG_SIZE, B_ALIGN + KC * NC * sizeof( DTYPE_IN ) ); \ + DTYPE_IN * restrict atilde_sys = ( DTYPE_IN *) aligned_alloc( P10_PG_SIZE, A_ALIGN + MC * KC * sizeof( DTYPE_IN ) ); \ + \ + DTYPE_IN * restrict btilde_usr = ( DTYPE_IN *)((char *)btilde_sys + B_ALIGN); \ + DTYPE_IN * restrict atilde_usr = ( DTYPE_IN *)((char *)atilde_sys + A_ALIGN); \ + \ + const int rstep_c = MC * rsc; \ + const int cstep_c = NC * csc; \ + \ + const int rstep_a = MC * rsa; \ + const int cstep_a = KC * csa; \ + \ + const int rstep_b = KC * rsb; \ + const int cstep_b = NC * csb; \ + \ + const int rstep_mt_c = MR * rsc; \ + const int cstep_mt_c = NR * csc; \ + \ + DTYPE_OUT * restrict cblock = c; \ + DTYPE_IN * restrict bblock = b; \ + \ + for ( int jc=0; jcv = 0; \ + dest++; + +// zero out 4 nibbles struct +#define zero_out_dest(dest) \ + memset(dest, 0, 4); + + +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////// Col Major Order Macros //////////////////////////// +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/* + + The following macros handle the case when there is a full size panel + (ib/jb == MR/NR) and no edge case (k%8 == 0). + +*/ + +#define col_m_order_1(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (p_idx+0)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (p_idx+1)*cs].bits.nib1; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (p_idx+2)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (p_idx+3)*cs].bits.nib1; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (p_idx+4)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (p_idx+5)*cs].bits.nib1; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (p_idx+6)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (p_idx+7)*cs].bits.nib1; \ + dest++; + +#define col_m_order_2(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (p_idx+0)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (p_idx+1)*cs].bits.nib2; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (p_idx+2)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (p_idx+3)*cs].bits.nib2; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (p_idx+4)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (p_idx+5)*cs].bits.nib2; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (p_idx+6)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (p_idx+7)*cs].bits.nib2; \ + dest++; + +/* + + The following macros handle the case when there is a full size panel + (ib/jb == MR/NR) and there is an edge case (k%8 != 0). + +*/ + +#define col_m_order_1_kleft7(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-7)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-6)*cs].bits.nib1; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-5)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-4)*cs].bits.nib1; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-3)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-2)*cs].bits.nib1; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-1)*cs].bits.nib1; \ + dest->bits.nib2 = 0; \ + dest++; + +#define col_m_order_2_kleft7(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-7)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-6)*cs].bits.nib2; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-5)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-4)*cs].bits.nib2; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-3)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-2)*cs].bits.nib2; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-1)*cs].bits.nib2; \ + dest->bits.nib2 = 0; \ + dest++; + +#define col_m_order_1_kleft6(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-6)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-5)*cs].bits.nib1; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-4)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-3)*cs].bits.nib1; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-2)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-1)*cs].bits.nib1; \ + dest++; \ + zero_out_full(dest); + +#define col_m_order_2_kleft6(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-6)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-5)*cs].bits.nib2; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-4)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-3)*cs].bits.nib2; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-2)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-1)*cs].bits.nib2; \ + dest++; \ + zero_out_full(dest); + +#define col_m_order_1_kleft5(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-5)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-4)*cs].bits.nib1; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-3)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-2)*cs].bits.nib1; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-1)*cs].bits.nib1; \ + dest->bits.nib2 = 0; \ + dest++; \ + zero_out_full(dest); + +#define col_m_order_2_kleft5(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-5)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-4)*cs].bits.nib2; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-3)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-2)*cs].bits.nib2; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-1)*cs].bits.nib2; \ + dest->bits.nib2 = 0; \ + dest++; \ + zero_out_full(dest); + +#define col_m_order_1_kleft4(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-4)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-3)*cs].bits.nib1; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-2)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-1)*cs].bits.nib1; \ + dest++; \ + zero_out_full(dest); \ + zero_out_full(dest); + +#define col_m_order_2_kleft4(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-4)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-3)*cs].bits.nib2; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-2)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-1)*cs].bits.nib2; \ + dest++; \ + zero_out_full(dest); \ + zero_out_full(dest); + +#define col_m_order_1_kleft3(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-3)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-2)*cs].bits.nib1; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-1)*cs].bits.nib1; \ + dest->bits.nib2 = 0; \ + dest++; \ + zero_out_full(dest); \ + zero_out_full(dest); + +#define col_m_order_2_kleft3(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-3)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-2)*cs].bits.nib2; \ + dest++; \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-1)*cs].bits.nib2; \ + dest->bits.nib2 = 0; \ + dest++; \ + zero_out_full(dest); \ + zero_out_full(dest); + +#define col_m_order_1_kleft2(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-2)*cs].bits.nib1; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-1)*cs].bits.nib1; \ + dest++; \ + zero_out_full(dest); \ + zero_out_full(dest); \ + zero_out_full(dest); + +#define col_m_order_2_kleft2(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-2)*cs].bits.nib2; \ + dest->bits.nib2 = matrix[rs_mul*rs + (k-1)*cs].bits.nib2; \ + dest++; \ + zero_out_full(dest); \ + zero_out_full(dest); \ + zero_out_full(dest); + +#define col_m_order_1_kleft1(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-1)*cs].bits.nib1; \ + dest->bits.nib2 = 0; \ + dest++; \ + zero_out_full(dest); \ + zero_out_full(dest); \ + zero_out_full(dest); + +#define col_m_order_2_kleft1(dest, matrix, rs_mul, rs, cs) \ + dest->bits.nib1 = matrix[rs_mul*rs + (k-1)*cs].bits.nib2; \ + dest->bits.nib2 = 0; \ + dest++; \ + zero_out_full(dest); \ + zero_out_full(dest); \ + zero_out_full(dest); + +/* + + + The following macros are used when we have a full panel (ib == MR) + and we need to handle an edge case (k%8 != 0). + + The MR loop is unrolled resulting in the stream of macros. + +*/ + +#define apad_col_kleft7(dest, matrix, rs, cs) \ + col_m_order_1_kleft7(dest, matrix, (i ), rs, cs); \ + col_m_order_2_kleft7(dest, matrix, (i ), rs, cs); \ + col_m_order_1_kleft7(dest, matrix, (i+1), rs, cs); \ + col_m_order_2_kleft7(dest, matrix, (i+1), rs, cs); \ + col_m_order_1_kleft7(dest, matrix, (i+2), rs, cs); \ + col_m_order_2_kleft7(dest, matrix, (i+2), rs, cs); \ + col_m_order_1_kleft7(dest, matrix, (i+3), rs, cs); \ + col_m_order_2_kleft7(dest, matrix, (i+3), rs, cs); + +#define apad_col_kleft6(dest, matrix, rs, cs) \ + col_m_order_1_kleft6(dest, matrix, (i ), rs, cs); \ + col_m_order_2_kleft6(dest, matrix, (i ), rs, cs); \ + col_m_order_1_kleft6(dest, matrix, (i+1), rs, cs); \ + col_m_order_2_kleft6(dest, matrix, (i+1), rs, cs); \ + col_m_order_1_kleft6(dest, matrix, (i+2), rs, cs); \ + col_m_order_2_kleft6(dest, matrix, (i+2), rs, cs); \ + col_m_order_1_kleft6(dest, matrix, (i+3), rs, cs); \ + col_m_order_2_kleft6(dest, matrix, (i+3), rs, cs); + +#define apad_col_kleft5(dest, matrix, rs, cs) \ + col_m_order_1_kleft5(dest, matrix, (i ), rs, cs); \ + col_m_order_2_kleft5(dest, matrix, (i ), rs, cs); \ + col_m_order_1_kleft5(dest, matrix, (i+1), rs, cs); \ + col_m_order_2_kleft5(dest, matrix, (i+1), rs, cs); \ + col_m_order_1_kleft5(dest, matrix, (i+2), rs, cs); \ + col_m_order_2_kleft5(dest, matrix, (i+2), rs, cs); \ + col_m_order_1_kleft5(dest, matrix, (i+3), rs, cs); \ + col_m_order_2_kleft5(dest, matrix, (i+3), rs, cs); + +#define apad_col_kleft4(dest, matrix, rs, cs) \ + col_m_order_1_kleft4(dest, matrix, (i ), rs, cs); \ + col_m_order_2_kleft4(dest, matrix, (i ), rs, cs); \ + col_m_order_1_kleft4(dest, matrix, (i+1), rs, cs); \ + col_m_order_2_kleft4(dest, matrix, (i+1), rs, cs); \ + col_m_order_1_kleft4(dest, matrix, (i+2), rs, cs); \ + col_m_order_2_kleft4(dest, matrix, (i+2), rs, cs); \ + col_m_order_1_kleft4(dest, matrix, (i+3), rs, cs); \ + col_m_order_2_kleft4(dest, matrix, (i+3), rs, cs); + +#define apad_col_kleft3(dest, matrix, rs, cs) \ + col_m_order_1_kleft3(dest, matrix, (i ), rs, cs); \ + col_m_order_2_kleft3(dest, matrix, (i ), rs, cs); \ + col_m_order_1_kleft3(dest, matrix, (i+1), rs, cs); \ + col_m_order_2_kleft3(dest, matrix, (i+1), rs, cs); \ + col_m_order_1_kleft3(dest, matrix, (i+2), rs, cs); \ + col_m_order_2_kleft3(dest, matrix, (i+2), rs, cs); \ + col_m_order_1_kleft3(dest, matrix, (i+3), rs, cs); \ + col_m_order_2_kleft3(dest, matrix, (i+3), rs, cs); + +#define apad_col_kleft2(dest, matrix, rs, cs) \ + col_m_order_1_kleft2(dest, matrix, (i ), rs, cs); \ + col_m_order_2_kleft2(dest, matrix, (i ), rs, cs); \ + col_m_order_1_kleft2(dest, matrix, (i+1), rs, cs); \ + col_m_order_2_kleft2(dest, matrix, (i+1), rs, cs); \ + col_m_order_1_kleft2(dest, matrix, (i+2), rs, cs); \ + col_m_order_2_kleft2(dest, matrix, (i+2), rs, cs); \ + col_m_order_1_kleft2(dest, matrix, (i+3), rs, cs); \ + col_m_order_2_kleft2(dest, matrix, (i+3), rs, cs); + +#define apad_col_kleft1(dest, matrix, rs, cs) \ + col_m_order_1_kleft1(dest, matrix, (i ), rs, cs); \ + col_m_order_2_kleft1(dest, matrix, (i ), rs, cs); \ + col_m_order_1_kleft1(dest, matrix, (i+1), rs, cs); \ + col_m_order_2_kleft1(dest, matrix, (i+1), rs, cs); \ + col_m_order_1_kleft1(dest, matrix, (i+2), rs, cs); \ + col_m_order_2_kleft1(dest, matrix, (i+2), rs, cs); \ + col_m_order_1_kleft1(dest, matrix, (i+3), rs, cs); \ + col_m_order_2_kleft1(dest, matrix, (i+3), rs, cs); + +/* + + The following macros are used when we have a full panel (jb == NR) + and we need to handle an edge case (k%8 != 0). + + The NR loop is unrolled resulting in the stream of macros. + +*/ + +#define bpad_col_kleft7(dest, matrix, rs, cs) \ + col_m_order_1_kleft7(dest, matrix, (j ), rs, cs); \ + col_m_order_2_kleft7(dest, matrix, (j ), rs, cs); \ + col_m_order_1_kleft7(dest, matrix, (j+1), rs, cs); \ + col_m_order_2_kleft7(dest, matrix, (j+1), rs, cs); \ + col_m_order_1_kleft7(dest, matrix, (j+2), rs, cs); \ + col_m_order_2_kleft7(dest, matrix, (j+2), rs, cs); \ + col_m_order_1_kleft7(dest, matrix, (j+3), rs, cs); \ + col_m_order_2_kleft7(dest, matrix, (j+3), rs, cs); \ + col_m_order_1_kleft7(dest, matrix, (j+4), rs, cs); \ + col_m_order_2_kleft7(dest, matrix, (j+4), rs, cs); \ + col_m_order_1_kleft7(dest, matrix, (j+5), rs, cs); \ + col_m_order_2_kleft7(dest, matrix, (j+5), rs, cs); \ + col_m_order_1_kleft7(dest, matrix, (j+6), rs, cs); \ + col_m_order_2_kleft7(dest, matrix, (j+6), rs, cs); \ + col_m_order_1_kleft7(dest, matrix, (j+7), rs, cs); \ + col_m_order_2_kleft7(dest, matrix, (j+7), rs, cs); + +#define bpad_col_kleft6(dest, matrix, rs, cs) \ + col_m_order_1_kleft6(dest, matrix, (j ), rs, cs); \ + col_m_order_2_kleft6(dest, matrix, (j ), rs, cs); \ + col_m_order_1_kleft6(dest, matrix, (j+1), rs, cs); \ + col_m_order_2_kleft6(dest, matrix, (j+1), rs, cs); \ + col_m_order_1_kleft6(dest, matrix, (j+2), rs, cs); \ + col_m_order_2_kleft6(dest, matrix, (j+2), rs, cs); \ + col_m_order_1_kleft6(dest, matrix, (j+3), rs, cs); \ + col_m_order_2_kleft6(dest, matrix, (j+3), rs, cs); \ + col_m_order_1_kleft6(dest, matrix, (j+4), rs, cs); \ + col_m_order_2_kleft6(dest, matrix, (j+4), rs, cs); \ + col_m_order_1_kleft6(dest, matrix, (j+5), rs, cs); \ + col_m_order_2_kleft6(dest, matrix, (j+5), rs, cs); \ + col_m_order_1_kleft6(dest, matrix, (j+6), rs, cs); \ + col_m_order_2_kleft6(dest, matrix, (j+6), rs, cs); \ + col_m_order_1_kleft6(dest, matrix, (j+7), rs, cs); \ + col_m_order_2_kleft6(dest, matrix, (j+7), rs, cs); + +#define bpad_col_kleft5(dest, matrix, rs, cs) \ + col_m_order_1_kleft5(dest, matrix, (j ), rs, cs); \ + col_m_order_2_kleft5(dest, matrix, (j ), rs, cs); \ + col_m_order_1_kleft5(dest, matrix, (j+1), rs, cs); \ + col_m_order_2_kleft5(dest, matrix, (j+1), rs, cs); \ + col_m_order_1_kleft5(dest, matrix, (j+2), rs, cs); \ + col_m_order_2_kleft5(dest, matrix, (j+2), rs, cs); \ + col_m_order_1_kleft5(dest, matrix, (j+3), rs, cs); \ + col_m_order_2_kleft5(dest, matrix, (j+3), rs, cs); \ + col_m_order_1_kleft5(dest, matrix, (j+4), rs, cs); \ + col_m_order_2_kleft5(dest, matrix, (j+4), rs, cs); \ + col_m_order_1_kleft5(dest, matrix, (j+5), rs, cs); \ + col_m_order_2_kleft5(dest, matrix, (j+5), rs, cs); \ + col_m_order_1_kleft5(dest, matrix, (j+6), rs, cs); \ + col_m_order_2_kleft5(dest, matrix, (j+6), rs, cs); \ + col_m_order_1_kleft5(dest, matrix, (j+7), rs, cs); \ + col_m_order_2_kleft5(dest, matrix, (j+7), rs, cs); + +#define bpad_col_kleft4(dest, matrix, rs, cs) \ + col_m_order_1_kleft4(dest, matrix, (j ), rs, cs); \ + col_m_order_2_kleft4(dest, matrix, (j ), rs, cs); \ + col_m_order_1_kleft4(dest, matrix, (j+1), rs, cs); \ + col_m_order_2_kleft4(dest, matrix, (j+1), rs, cs); \ + col_m_order_1_kleft4(dest, matrix, (j+2), rs, cs); \ + col_m_order_2_kleft4(dest, matrix, (j+2), rs, cs); \ + col_m_order_1_kleft4(dest, matrix, (j+3), rs, cs); \ + col_m_order_2_kleft4(dest, matrix, (j+3), rs, cs); \ + col_m_order_1_kleft4(dest, matrix, (j+4), rs, cs); \ + col_m_order_2_kleft4(dest, matrix, (j+4), rs, cs); \ + col_m_order_1_kleft4(dest, matrix, (j+5), rs, cs); \ + col_m_order_2_kleft4(dest, matrix, (j+5), rs, cs); \ + col_m_order_1_kleft4(dest, matrix, (j+6), rs, cs); \ + col_m_order_2_kleft4(dest, matrix, (j+6), rs, cs); \ + col_m_order_1_kleft4(dest, matrix, (j+7), rs, cs); \ + col_m_order_2_kleft4(dest, matrix, (j+7), rs, cs); + +#define bpad_col_kleft3(dest, matrix, rs, cs) \ + col_m_order_1_kleft3(dest, matrix, (j ), rs, cs); \ + col_m_order_2_kleft3(dest, matrix, (j ), rs, cs); \ + col_m_order_1_kleft3(dest, matrix, (j+1), rs, cs); \ + col_m_order_2_kleft3(dest, matrix, (j+1), rs, cs); \ + col_m_order_1_kleft3(dest, matrix, (j+2), rs, cs); \ + col_m_order_2_kleft3(dest, matrix, (j+2), rs, cs); \ + col_m_order_1_kleft3(dest, matrix, (j+3), rs, cs); \ + col_m_order_2_kleft3(dest, matrix, (j+3), rs, cs); \ + col_m_order_1_kleft3(dest, matrix, (j+4), rs, cs); \ + col_m_order_2_kleft3(dest, matrix, (j+4), rs, cs); \ + col_m_order_1_kleft3(dest, matrix, (j+5), rs, cs); \ + col_m_order_2_kleft3(dest, matrix, (j+5), rs, cs); \ + col_m_order_1_kleft3(dest, matrix, (j+6), rs, cs); \ + col_m_order_2_kleft3(dest, matrix, (j+6), rs, cs); \ + col_m_order_1_kleft3(dest, matrix, (j+7), rs, cs); \ + col_m_order_2_kleft3(dest, matrix, (j+7), rs, cs); + +#define bpad_col_kleft2(dest, matrix, rs, cs) \ + col_m_order_1_kleft2(dest, matrix, (j ), rs, cs); \ + col_m_order_2_kleft2(dest, matrix, (j ), rs, cs); \ + col_m_order_1_kleft2(dest, matrix, (j+1), rs, cs); \ + col_m_order_2_kleft2(dest, matrix, (j+1), rs, cs); \ + col_m_order_1_kleft2(dest, matrix, (j+2), rs, cs); \ + col_m_order_2_kleft2(dest, matrix, (j+2), rs, cs); \ + col_m_order_1_kleft2(dest, matrix, (j+3), rs, cs); \ + col_m_order_2_kleft2(dest, matrix, (j+3), rs, cs); \ + col_m_order_1_kleft2(dest, matrix, (j+4), rs, cs); \ + col_m_order_2_kleft2(dest, matrix, (j+4), rs, cs); \ + col_m_order_1_kleft2(dest, matrix, (j+5), rs, cs); \ + col_m_order_2_kleft2(dest, matrix, (j+5), rs, cs); \ + col_m_order_1_kleft2(dest, matrix, (j+6), rs, cs); \ + col_m_order_2_kleft2(dest, matrix, (j+6), rs, cs); \ + col_m_order_1_kleft2(dest, matrix, (j+7), rs, cs); \ + col_m_order_2_kleft2(dest, matrix, (j+7), rs, cs); + +#define bpad_col_kleft1(dest, matrix, rs, cs) \ + col_m_order_1_kleft1(dest, matrix, (j ), rs, cs); \ + col_m_order_2_kleft1(dest, matrix, (j ), rs, cs); \ + col_m_order_1_kleft1(dest, matrix, (j+1), rs, cs); \ + col_m_order_2_kleft1(dest, matrix, (j+1), rs, cs); \ + col_m_order_1_kleft1(dest, matrix, (j+2), rs, cs); \ + col_m_order_2_kleft1(dest, matrix, (j+2), rs, cs); \ + col_m_order_1_kleft1(dest, matrix, (j+3), rs, cs); \ + col_m_order_2_kleft1(dest, matrix, (j+3), rs, cs); \ + col_m_order_1_kleft1(dest, matrix, (j+4), rs, cs); \ + col_m_order_2_kleft1(dest, matrix, (j+4), rs, cs); \ + col_m_order_1_kleft1(dest, matrix, (j+5), rs, cs); \ + col_m_order_2_kleft1(dest, matrix, (j+5), rs, cs); \ + col_m_order_1_kleft1(dest, matrix, (j+6), rs, cs); \ + col_m_order_2_kleft1(dest, matrix, (j+6), rs, cs); \ + col_m_order_1_kleft1(dest, matrix, (j+7), rs, cs); \ + col_m_order_2_kleft1(dest, matrix, (j+7), rs, cs); + + +/* + + The following macros handle non full size panels (ib/jb != MR/NR) and + edge cases (k%8 != 0). + +*/ + +#define edge(edgefun, dest, matrix, panel, left, rs, cs) \ + for (int ir=0; ir>= 13; // Align mantissa on MSB + t2 >>= 16; // Shift sign bit into position + + t1 -= 0x1c000; // Adjust bias + + t1 = (t3 < 0x38800000) ? 0 : t1; + t1 = (t3 > 0x47000000) ? 0x7bff : t1; + t1 = (t3 == 0 ? 0 : t1); // Denormals-as-zero + + t1 |= t2; // Re-insert sign bit + + f16_out.v = t1; + return f16_out; +} + + +// cast float to bfloat16 +bfloat16 cast_f32_to_bf16 (float val) +{ + bfloat16 bf16; + float32_s f32; + f32.v = val; + bf16.bits.s = f32.bits.s; + bf16.bits.e = f32.bits.e; + bf16.bits.m = f32.bits.m >> 16; + return bf16; +} + +// cast bfloat16 to float +float cast_bf16_to_f32(bfloat16 val) +{ + float32_s f32; + f32.bits.s = val.bits.s; + f32.bits.e = val.bits.e; + f32.bits.m = val.bits.m << 16; + return f32.v; +} + +// cast a nibbles struct to a float array +void cast_i4_to_f32(float *fvals, nibbles vals) +{ + int8_t val0 = vals.bits.nib1; + int8_t val1 = vals.bits.nib2; + + val0 = (val0 >= 8 ? val0 - 16 : val0); + val1 = (val1 >= 8 ? val1 - 16 : val1); + + fvals[0] = (float) val0; + fvals[1] = (float) val1; +} + +// condense two float vals to a nibbles struct +nibbles cast_f32_to_i4(float val0, float val1) +{ + nibbles vals; + + int8_t val0_ = ((int8_t)val0) & 0xf0; + int8_t val1_ = ((int8_t)val1) & 0xf0; + + vals.bits.nib1 = val0_; + vals.bits.nib2 = val1_; + + return vals; +} + +// cast float matrix to float nibbles +void cast_f32_to_i4m(float *a_float, nibbles *a, int num_elems) +{ + int j=0; + for(int i=0; i +// print kernel name +const char* get_kernel_name(int kernel_id) +{ + switch (kernel_id) + { + case FLOAT16 : return "bli_shgemm"; + case BFLOAT16: return "bli_sbgemm"; + case INT16 : return "bli_i16gemm"; + case INT8 : return "bli_i8gemm"; + case INT4 : return "bli_i4gemm"; + default: printf("INCORRECT KERNEL ID\n"); exit(-1); + } +} + +// normalize the vector using the forbenious norm +void normalize_vec(float *t, int n) +{ + // normalize t + float norm_factor; + bli_snormfv(n, t, 1, &norm_factor); + // round up to closest power of 2 + norm_factor = 1 / (pow( 2.0, ceil( log2( norm_factor ) ) )); + bli_sscalv(BLIS_NO_CONJUGATE, n, &norm_factor, t, 1); +} + + // Pre-conditions: + // - a is randomized. + // - b is randomized. + // - c_orig is randomized. + // Note: + // - alpha and beta should have non-zero imaginary components in the + // complex cases in order to more fully exercise the implementation. + // + // Under these conditions, we assume that the implementation for + // + // C := beta * C_orig + alpha * transa(A) * transb(B) + // + // is functioning correctly if + // + // normfv( v - z ) + // + // is negligible, where + // + // v = C * t + // z = ( beta * C_orig + alpha * transa(A) * transb(B) ) * t + // = beta * C_orig * t + alpha * transa(A) * transb(B) * t + // = beta * C_orig * t + alpha * transa(A) * w + // = beta * C_orig * t + z +float get_resid( + int m, int n, int k, + float *a, int rsa, int csa, + float *b, int rsb, int csb, + float *c, int rsc, int csc, + float *c_orig, + float *alpha, float *beta +) +{ + + float t[n], v[m], w[k], z[m]; + float one = 1.0, zero = 0.0; + + bli_srandv(n, t, 1); + + // normalize so that the values are at the same precision of the input values + normalize_vec(t, n); + + // v = C * t + bli_sgemv( + BLIS_NO_TRANSPOSE, + BLIS_NO_CONJUGATE, + m, + n, + &one, + c, rsc, csc, + t, 1, + &zero, + v, 1 + ); + + // w = B * t + bli_sgemv( + BLIS_NO_TRANSPOSE, + BLIS_NO_CONJUGATE, + k, + n, + &one, + b, rsb, csb, + t, 1, + &zero, + w, 1 + ); + + // z = alpha * A * w + bli_sgemv( + BLIS_NO_TRANSPOSE, + BLIS_NO_CONJUGATE, + m, + k, + alpha, + a, rsa, csa, + w, 1, + &zero, + z, 1 + ); + + // z += beta * C_orig * t + bli_sgemv( + BLIS_NO_TRANSPOSE, + BLIS_NO_CONJUGATE, + m, + n, + beta, + c_orig, rsc, csc, + t, 1, + &one, + z, 1 + ); + + // v = v - z + bli_ssubv ( + BLIS_NO_CONJUGATE, + m, + z, 1, + v, 1 + ); + + // norm = normfv(v) + float norm; + bli_snormfv ( + m, + v, 1, + &norm + ); + + return norm; +} + + +// test to see if the result from a BLIS GEMM kernel is correct for a given m x n x k mat-mul +// assumes the matrices are of type float +// assumes the matrices were randomized and normalized +void correctness_checker( + int m, int n, int k, + float *a, int rsa, int csa, + float *b, int rsb, int csb, + float *c_orig, int rsc, int csc, + float *c_ans, + float alpha, float beta +) +{ + double start, end; + + start = bli_clock(); + float resid = get_resid ( + m, n, k, + a, rsa, csa, + b, rsb, csb, + c_ans, rsc, csc, + c_orig, + &alpha, &beta + ); + end = bli_clock(); + + printf("%d, %d, %d, %8.4le\n", m,n,k, resid); +} + + +// create all the correctness checking functions for each kernel +GEN_FP_COR_KERNEL(sb, bli_sbgemm, bfloat16, cast_f32_to_bf16m, cast_bf16_to_f32m); +GEN_FP_COR_KERNEL(sh, bli_shgemm, float16, cast_f32_to_f16m, cast_f16_to_f32m); +GEN_I_COR_KERNEL(i16, bli_i16gemm, int16_t, cast_f32_to_i16m, cast_i16_to_f32m); +GEN_I_COR_KERNEL(i8, bli_i8gemm, int8_t, cast_f32_to_i8m, cast_i8_to_f32m); + +// correctness template for int types +void i4correctness_kernel (int m, int n, int k) +{ + if(n%2 != 0) + { + printf("int4 can't handle odd sizes in the data-order dimension"); + exit(-1); + } + + int rsa = k, csa = 1, + rsb = n, csb = 1, + rsc = n, csc = 1; + + nibbles *a, *b; + + int32_t *c_ans, *c_orig, alpha, beta; + + float *a_float, *b_float, + *c_ans_float, *c_orig_float; + + /* buffers that will be passed into the kernel */ + // int4 buffers only need half the space to store all the elements + a = (nibbles *) malloc (m * (k/2) * sizeof(nibbles)); + b = (nibbles *) malloc (k * (n/2) * sizeof(nibbles)); + + c_ans = (int32_t *) malloc (m * n * sizeof(int32_t)); + c_orig = (int32_t *) malloc (m * n * sizeof(int32_t)); + + /* std format buffers that will be used by the correctness checker */ + a_float = (float *) malloc (m * k * sizeof(float)); + b_float = (float *) malloc (k * n * sizeof(float)); + c_ans_float = (float *) malloc (m * n * sizeof(float)); + c_orig_float = (float *) malloc (m * n * sizeof(float)); + + /* randomize matrices with float vals */ + bli_srandv(m*k, a_float, 1); + bli_srandv(k*n, b_float, 1); + bli_srandv(m*n, c_orig_float, 1); + + /* normalize the matrices */ + normalize_vec(a_float, m*k); + normalize_vec(b_float, k*n); + normalize_vec(c_orig_float, m*n); + + /* cast the float buffers into the buffers for the kernel */ + cast_f32_to_i4m (a_float, a, m*k); + cast_f32_to_i4m (b_float, b, k*n); + + /* cast float buffers to support int values */ + cast_f32_to_i32m(c_orig_float, c_orig, m*n); + cast_i32_to_f32m(c_orig, c_orig_float, m*n); + + /* cast the kernel buffers into the float buffers to ensure that the values match */ + cast_i4_to_f32m (a, a_float, m*k); + cast_i4_to_f32m (b, b_float, k*n); + + /* init alpha and beta */ + alpha = 1; + beta = 1; + + /* run kernel to get result in c_ans */ + // strides need to be adjusted since 1 element stores 2 values + memcpy(c_ans, c_orig, m * n * sizeof(int)); + bli_i4gemm( + BLIS_NO_TRANSPOSE, + BLIS_NO_TRANSPOSE, + m, + n, + k, + &alpha, + a, rsa/2, csa, + b, rsb/2, csb, + &beta, + c_ans, rsc, csc + ); + + /* cast integer result into float buffer since float is our std format for correctness checking */ + cast_i32_to_f32m(c_ans, c_ans_float, m*n); + + /* using the BLIS GEMM correctness check method, get the resid */ + correctness_checker( + m, n, k, + a_float, rsa, csa, + b_float, rsb, csb, + c_orig_float, rsc, csc, + c_ans_float, + (float) alpha, (float) beta + ); + + free(a); + free(b); + free(c_ans); + free(c_orig); + free(a_float); + free(b_float); + free(c_ans_float); + free(c_orig_float); +} + +// using the DATATYPE enum, gather test the correctness of the respective GEMM kernel +void run_correctness_kernel(int kernel_id, int m, int n, int k) +{ + switch (kernel_id) + { + case FLOAT16 : shcorrectness_kernel(m, n, k); break; + case BFLOAT16: sbcorrectness_kernel(m, n, k); break; + case INT16 : i16correctness_kernel(m, n, k); break; + case INT8 : i8correctness_kernel(m, n, k); break; + case INT4 : i4correctness_kernel(m, n, k); break; + default: break; + } +} + +void test_correctness(int kernel_id, int start, int end, int inc) +{ + printf("%s correctness test\n", get_kernel_name(kernel_id)); + printf("m, n, k, resid\n"); + int m,n,k; + for (int p=start; p<=end; p+=inc) + { + m=n=k=p; + run_correctness_kernel(kernel_id, m, n, k); + } +} + +// correctness test for bfloat16 gemm +int main(int argc, char *argv[]) +{ + + test_correctness(FLOAT16, 80, 4000, 80); + test_correctness(BFLOAT16, 80, 4000, 80); + test_correctness(INT16, 80, 4000, 80); + test_correctness(INT8, 80, 4000, 80); + test_correctness(INT4, 80, 4000, 80); +} diff --git a/sandbox/power10/p10_testsuite/correctness.h b/sandbox/power10/p10_testsuite/correctness.h new file mode 100644 index 0000000000..aea647848a --- /dev/null +++ b/sandbox/power10/p10_testsuite/correctness.h @@ -0,0 +1,176 @@ +// templates for generating correctness checking functions that check the correctness of GEMM kernels +// using the BLIS GEMM correctness method + +#define COR_KERNEL_NAME_(ch) ch ## correctness_kernel +#define COR_KERNEL_NAME(ch) COR_KERNEL_NAME_(ch) + + +// correctness template for float types +#define GEN_FP_COR_KERNEL(ch, kernel, input_t, DOWN_CAST, UP_CAST) \ +void COR_KERNEL_NAME(ch) (int m, int n, int k) \ +{ \ + int rsa = k, csa = 1, \ + rsb = n, csb = 1, \ + rsc = n, csc = 1; \ +\ + input_t *a, *b; \ +\ + float *a_float, *b_float, \ + *c_ans_float, *c_orig_float, \ + alpha, beta; \ +\ + /* buffers that will be passed into the kernel */ \ + a = (input_t *) malloc (m * k * sizeof(input_t)); \ + b = (input_t *) malloc (k * n * sizeof(input_t)); \ +\ + /* std format buffers that will be used by the correctness checker */ \ + a_float = (float *) malloc (m * k * sizeof(float)); \ + b_float = (float *) malloc (k * n * sizeof(float)); \ + c_ans_float = (float *) malloc (m * n * sizeof(float)); \ + c_orig_float = (float *) malloc (m * n * sizeof(float)); \ +\ + /* randomize matrices with float vals */ \ + bli_srandv(m*k, a_float, 1); \ + bli_srandv(k*n, b_float, 1); \ + bli_srandv(m*n, c_orig_float, 1); \ +\ + /* normalize the matrices */ \ + normalize_vec(a_float, m*k); \ + normalize_vec(b_float, k*n); \ + normalize_vec(c_orig_float, m*n); \ +\ + /* cast the float buffers into the buffers for the kernel */ \ + DOWN_CAST (a_float, a, m*k); \ + DOWN_CAST (b_float, b, k*n); \ +\ + /* cast the kernel buffers into the float buffers to ensure that the values match */ \ + UP_CAST (a, a_float, m*k); \ + UP_CAST (b, b_float, k*n); \ +\ + /* init alpha and beta */ \ + alpha = 1; \ + beta = 1; \ +\ + memcpy(c_ans_float, c_orig_float, m * n * sizeof(float)); \ + kernel( \ + BLIS_NO_TRANSPOSE, \ + BLIS_NO_TRANSPOSE, \ + m, \ + n, \ + k, \ + &alpha, \ + a, rsa, csa, \ + b, rsb, csb, \ + &beta, \ + c_ans_float, rsc, csc \ + ); \ +\ + correctness_checker( \ + m, n, k, \ + a_float, rsa, csa, \ + b_float, rsb, csb, \ + c_orig_float, rsc, csc, \ + c_ans_float, \ + alpha, beta \ + ); \ +\ + free(a); \ + free(b); \ + free(a_float); \ + free(b_float); \ + free(c_ans_float); \ + free(c_orig_float); \ +\ +} + +// correctness template for int types +#define GEN_I_COR_KERNEL(ch, kernel, input_t, DOWN_CAST, UP_CAST) \ +void COR_KERNEL_NAME(ch) (int m, int n, int k) \ +{ \ + int rsa = k, csa = 1, \ + rsb = n, csb = 1, \ + rsc = n, csc = 1; \ +\ + input_t *a, *b; \ +\ + int32_t *c_ans, *c_orig, alpha, beta; \ +\ + float *a_float, *b_float, \ + *c_ans_float, *c_orig_float; \ +\ + /* buffers that will be passed into the kernel */ \ + a = (input_t *) malloc (m * k * sizeof(input_t)); \ + b = (input_t *) malloc (k * n * sizeof(input_t)); \ + c_ans = (int32_t *) malloc (m * n * sizeof(int32_t)); \ + c_orig = (int32_t *) malloc (m * n * sizeof(int32_t)); \ +\ + /* std format buffers that will be used by the correctness checker */ \ + a_float = (float *) malloc (m * k * sizeof(float)); \ + b_float = (float *) malloc (k * n * sizeof(float)); \ + c_ans_float = (float *) malloc (m * n * sizeof(float)); \ + c_orig_float = (float *) malloc (m * n * sizeof(float)); \ +\ + /* randomize matrices with float vals */ \ + bli_srandv(m*k, a_float, 1); \ + bli_srandv(k*n, b_float, 1); \ + bli_srandv(m*n, c_orig_float, 1); \ +\ + /* normalize the matrices */ \ + normalize_vec(a_float, m*k); \ + normalize_vec(b_float, k*n); \ + normalize_vec(c_orig_float, m*n); \ +\ + /* cast the float buffers into the buffers for the kernel */ \ + DOWN_CAST (a_float, a, m*k); \ + DOWN_CAST (b_float, b, k*n); \ +\ + /* cast float buffers to support int values */ \ + cast_f32_to_i32m(c_orig_float, c_orig, m*n); \ + cast_i32_to_f32m(c_orig, c_orig_float, m*n); \ +\ + /* cast the kernel buffers into the float buffers to ensure that the values match */ \ + UP_CAST (a, a_float, m*k); \ + UP_CAST (b, b_float, k*n); \ +\ + /* init alpha and beta */ \ + alpha = 1; \ + beta = 1; \ +\ + /* run kernel to get result in c_ans */ \ + memcpy(c_ans, c_orig, m * n * sizeof(int)); \ + kernel( \ + BLIS_NO_TRANSPOSE, \ + BLIS_NO_TRANSPOSE, \ + m, \ + n, \ + k, \ + &alpha, \ + a, rsa, csa, \ + b, rsb, csb, \ + &beta, \ + c_ans, rsc, csc \ + ); \ +\ + /* cast integer result into float buffer since float is our std format for correctness checking */ \ + cast_i32_to_f32m(c_ans, c_ans_float, m*n); \ +\ + /* using the BLIS GEMM correctness check method, get the resid */ \ + correctness_checker( \ + m, n, k, \ + a_float, rsa, csa, \ + b_float, rsb, csb, \ + c_orig_float, rsc, csc, \ + c_ans_float, \ + (float) alpha, (float) beta \ + ); \ +\ + free(a); \ + free(b); \ + free(c_ans); \ + free(c_orig); \ + free(a_float); \ + free(b_float); \ + free(c_ans_float); \ + free(c_orig_float); \ +\ +} diff --git a/sandbox/power10/p10_testsuite/performance.c b/sandbox/power10/p10_testsuite/performance.c new file mode 100644 index 0000000000..25f1c3ff2a --- /dev/null +++ b/sandbox/power10/p10_testsuite/performance.c @@ -0,0 +1,103 @@ +/* + + This program is designed to gather the performance data of the POWER10 + GEMM kernels in `blis/sandbox/power10`. + + By default, the performance of the kernels is gather over a set of square + matrices. The perfromance results are reported in GFLOPS, and outputted in + CSV format. + +*/ + +#include "performance.h" +#include "blis.h" +#include "../bli_sandbox.h" +#include "common.h" + +#include +// print kernel name +const char* get_kernel_name(int kernel_id) +{ + switch (kernel_id) + { + case FLOAT16 : return "bli_shgemm"; + case BFLOAT16: return "bli_sbgemm"; + case INT16 : return "bli_i16gemm"; + case INT8 : return "bli_i8gemm"; + case INT4 : return "bli_i4gemm"; + default: printf("INCORRECT KERNEL ID\n"); exit(-1); + } +} + +// create all the performance gathering functions for each kernel +GET_PERF_API_TEMP(sb, bli_sbgemm, bfloat16, float); +GET_PERF_API_TEMP(sh, bli_shgemm, float16, float); +GET_PERF_API_TEMP(i16, bli_i16gemm, int16_t, int); +GET_PERF_API_TEMP(i8, bli_i8gemm, int8_t, int); +GET_PERF_API_TEMP(i4, bli_i4gemm, nibbles, int); + + +// using the DATATYPE enum, gather the performance of the respective GEMM kernel +double run_kernel(int kernel_id, int nreps, int m, int n, int k) +{ + switch (kernel_id) + { + case FLOAT16 : return test_shapi(nreps, m, n, k); + case BFLOAT16: return test_sbapi(nreps, m, n, k); + case INT16 : return test_i16api(nreps, m, n, k); + case INT8 : return test_i8api(nreps, m, n, k); + case INT4 : return test_i4api(nreps, m, n, k); + default: return -1.0; + } +} + +// print the performance data in CSV format +// performance is measured in terms of GFLOPs +void print_perf_data(int m, int n, int k, double best_time) +{ + double GFLOPS = (2.0 * m * n * k) / (1e9 * best_time); + printf("%d, %d, %d, %.2f\n", m, n, k, GFLOPS); +} + +// get performance data +void get_perf(int kernel_id, int nreps, int start, int end, int inc) +{ + // csv header + printf("%s performance\n", get_kernel_name(kernel_id)); + printf("m, n, k, GFLOPS\n"); + + int m,n,k; + + // run over all problem sizes + for (int p=start; p<=end; p+=inc) + { + // change here to adjust problem size + m = p, + n = p, + k = p; + + double best_run_time = run_kernel(kernel_id, nreps, m, n, k); + + print_perf_data(m, n, k, best_run_time); + } +} + +int main(int argc, char *argv[]) +{ + // initialize a square problem set range + int start = 80; + int end = 4000; + int inc = 80; + + // number of times the kernel will be run + int nreps = 5; + + // run a respective kernel + get_perf( FLOAT16, nreps, start, end, inc); + get_perf(BFLOAT16, nreps, start, end, inc); + get_perf( INT16, nreps, start, end, inc); + get_perf( INT8, nreps, start, end, inc); + get_perf( INT4, nreps, start, end, inc); + + return 0; +} diff --git a/sandbox/power10/p10_testsuite/performance.h b/sandbox/power10/p10_testsuite/performance.h new file mode 100644 index 0000000000..26c36f6155 --- /dev/null +++ b/sandbox/power10/p10_testsuite/performance.h @@ -0,0 +1,58 @@ + +// function name template +// each function that will gather perform will be named test_api +#define GEN_PERF_FUNC_NAME_(ch) test_ ## ch ## api +#define GEN_PERF_FUNC_NAME(ch) GEN_PERF_FUNC_NAME_(ch) + +/* + Macro template for getting the best GEMM kernel runtime out of `num_runs` + for matrices of size (m x n x k). +*/ +#define GET_PERF_API_TEMP(ch, kernel, input_t, output_t) \ +double GEN_PERF_FUNC_NAME(ch) ( \ + int num_runs, \ + int m, \ + int n, \ + int k \ +) \ +{ \ + input_t *A,*B; \ + output_t *C; \ + output_t alpha,beta; \ +\ + A = (input_t*) malloc(m*k*sizeof(input_t)); \ + B = (input_t*) malloc(n*k*sizeof(input_t)); \ + C = (output_t*) malloc(m*n*sizeof(output_t)); \ + \ + alpha = 1; \ + beta = 1; \ + \ + double best = 1e9; \ + \ + for (int irep=0; irep 1, we skip this test. + if [ "${im}" = "eigen" ] && \ + [ "${op}" != "gemm" ] && \ + [ "${nt}" != "1" ]; then + continue; + fi + + # Find the threading suffix by probing the executable. + binname=$(ls ${exec_root}_${dt}${op}_${psize}_${im}_*.x) + suf_ext=${binname##*_} + suf=${suf_ext%%.*} + + #echo "found file: ${binname} with suffix ${suf}" + + # Set the number of threads according to th. + if [ "${suf}" = "1s" ] || [ "${suf}" = "2s" ]; then + + # Set the threading parameters based on the implementation + # that we are preparing to run. + if [ "${im}" = "asm_blis" ] || \ + [ "${im}" = "1m_blis" ]; then + unset OMP_NUM_THREADS + export BLIS_JC_NT=${jc_nt} + export BLIS_PC_NT=${pc_nt} + export BLIS_IC_NT=${ic_nt} + export BLIS_JR_NT=${jr_nt} + export BLIS_IR_NT=${ir_nt} + elif [ "${im}" = "openblas" ]; then + unset OMP_NUM_THREADS + export OPENBLAS_NUM_THREADS=${nt} + elif [ "${im}" = "eigen" ]; then + export OMP_NUM_THREADS=${nt} + elif [ "${im}" = "vendor" ]; then + unset OMP_NUM_THREADS + export MKL_NUM_THREADS=${nt} + fi + export nt_use=${nt} + + # Multithreaded OpenBLAS seems to have a problem running + # properly if GOMP_CPU_AFFINITY is set. So we temporarily + # unset it here if we are about to execute OpenBLAS, but + # otherwise restore it. + if [ ${im} = "openblas" ]; then + unset GOMP_CPU_AFFINITY + else + export GOMP_CPU_AFFINITY="${GOMP_CPU_AFFINITYsave}" + fi + else + + export BLIS_JC_NT=1 + export BLIS_PC_NT=1 + export BLIS_IC_NT=1 + export BLIS_JR_NT=1 + export BLIS_IR_NT=1 + export OMP_NUM_THREADS=1 + export OPENBLAS_NUM_THREADS=1 + export MKL_NUM_THREADS=1 + export nt_use=1 + fi + + # Construct the name of the test executable. + exec_name="${exec_root}_${dt}${op}_${psize}_${im}_${suf}.x" + + # Construct the name of the output file. + out_file="${out_root}_${suf}_${dt}${op}_${im}.m" + + #echo "Running (nt = ${nt_use}) ./${exec_name} > ${out_file}" + echo "Running: ${runcmd} ./${exec_name} > ${out_file}" + + # Run executable. + #./${exec_name} > ${out_file} + #numactl -i all ./${exec_name} > ${out_file} + eval "${runcmd} ./${exec_name} > ${out_file}" + + sleep ${delay} + + done + done + done +done + diff --git a/test/1m4m/test_gemm.c b/test/1m4m/test_gemm.c new file mode 100644 index 0000000000..f9a855125f --- /dev/null +++ b/test/1m4m/test_gemm.c @@ -0,0 +1,421 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#ifdef EIGEN + #define BLIS_DISABLE_BLAS_DEFS + #include "blis.h" + #include + #include + using namespace Eigen; +#else + #include "blis.h" +#endif + +#define COL_STORAGE +//#define ROW_STORAGE + +//#define PRINT + +int main( int argc, char** argv ) +{ + obj_t a, b, c; + obj_t c_save; + obj_t alpha, beta; + dim_t m, n, k; + dim_t p; + dim_t p_begin, p_max, p_inc; + int m_input, n_input, k_input; + ind_t ind; + num_t dt; + char dt_ch; + int r, n_repeats; + trans_t transa; + trans_t transb; + f77_char f77_transa; + f77_char f77_transb; + + double dtime; + double dtime_save; + double gflops; + + //bli_init(); + + bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); + + n_repeats = 3; + + dt = DT; + + ind = IND; + +#if 1 + p_begin = P_BEGIN; + p_max = P_MAX; + p_inc = P_INC; + + m_input = -1; + n_input = -1; + k_input = -1; +#else + p_begin = 40; + p_max = 2000; + p_inc = 40; + + m_input = -1; + n_input = -1; + k_input = -1; +#endif + + + // Supress compiler warnings about unused variable 'ind'. + ( void )ind; + +#if 0 + + cntx_t* cntx; + + ind_t ind_mod = ind; + + // Initialize a context for the current induced method and datatype. + cntx = bli_gks_query_ind_cntx( ind_mod, dt ); + + // Set k to the kc blocksize for the current datatype. + k_input = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); + +#elif 0 + + #ifdef BLIS + if ( ind == BLIS_1M ) k_input = 128; + else k_input = 256; + #else + k_input = 192; + #endif + +#endif + + // Choose the char corresponding to the requested datatype. + if ( bli_is_float( dt ) ) dt_ch = 's'; + else if ( bli_is_double( dt ) ) dt_ch = 'd'; + else if ( bli_is_scomplex( dt ) ) dt_ch = 'c'; + else dt_ch = 'z'; + + transa = BLIS_NO_TRANSPOSE; + transb = BLIS_NO_TRANSPOSE; + + bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + + // Begin with initializing the last entry to zero so that + // matlab allocates space for the entire array once up-front. + for ( p = p_begin; p + p_inc <= p_max; p += p_inc ) ; + + printf( "data_%s_%cgemm_%s", THR_STR, dt_ch, STR ); + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )0, + ( unsigned long )0, + ( unsigned long )0, 0.0 ); + + + //for ( p = p_begin; p <= p_max; p += p_inc ) + for ( p = p_max; p_begin <= p; p -= p_inc ) + { + + if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); + else m = ( dim_t ) m_input; + if ( n_input < 0 ) n = p / ( dim_t )abs(n_input); + else n = ( dim_t ) n_input; + if ( k_input < 0 ) k = p / ( dim_t )abs(k_input); + else k = ( dim_t ) k_input; + + bli_obj_create( dt, 1, 1, 0, 0, &alpha ); + bli_obj_create( dt, 1, 1, 0, 0, &beta ); + + #ifdef COL_STORAGE + bli_obj_create( dt, m, k, 0, 0, &a ); + bli_obj_create( dt, k, n, 0, 0, &b ); + bli_obj_create( dt, m, n, 0, 0, &c ); + bli_obj_create( dt, m, n, 0, 0, &c_save ); + #else + bli_obj_create( dt, m, k, k, 1, &a ); + bli_obj_create( dt, k, n, n, 1, &b ); + bli_obj_create( dt, m, n, n, 1, &c ); + bli_obj_create( dt, m, n, n, 1, &c_save ); + #endif + + bli_randm( &a ); + bli_randm( &b ); + bli_randm( &c ); + + bli_obj_set_conjtrans( transa, &a ); + bli_obj_set_conjtrans( transb, &b ); + + bli_setsc( (1.0/1.0), 0.0, &alpha ); + bli_setsc( (1.0/1.0), 0.0, &beta ); + + bli_copym( &c, &c_save ); + +#ifdef BLIS + bli_ind_disable_all_dt( dt ); + bli_ind_enable_dt( ind, dt ); +#endif + +#ifdef EIGEN + double alpha_r, alpha_i; + + bli_getsc( &alpha, &alpha_r, &alpha_i ); + + void* ap = bli_obj_buffer_at_off( &a ); + void* bp = bli_obj_buffer_at_off( &b ); + void* cp = bli_obj_buffer_at_off( &c ); + + #ifdef COL_STORAGE + const int os_a = bli_obj_col_stride( &a ); + const int os_b = bli_obj_col_stride( &b ); + const int os_c = bli_obj_col_stride( &c ); + #else + const int os_a = bli_obj_row_stride( &a ); + const int os_b = bli_obj_row_stride( &b ); + const int os_c = bli_obj_row_stride( &c ); + #endif + + Stride stride_a( os_a, 1 ); + Stride stride_b( os_b, 1 ); + Stride stride_c( os_c, 1 ); + + #ifdef COL_STORAGE + #if defined(IS_FLOAT) + typedef Matrix MatrixXf_; + #elif defined (IS_DOUBLE) + typedef Matrix MatrixXd_; + #elif defined (IS_SCOMPLEX) + typedef Matrix, Dynamic, Dynamic, ColMajor> MatrixXcf_; + #elif defined (IS_DCOMPLEX) + typedef Matrix, Dynamic, Dynamic, ColMajor> MatrixXcd_; + #endif + #else + #if defined(IS_FLOAT) + typedef Matrix MatrixXf_; + #elif defined (IS_DOUBLE) + typedef Matrix MatrixXd_; + #elif defined (IS_SCOMPLEX) + typedef Matrix, Dynamic, Dynamic, RowMajor> MatrixXcf_; + #elif defined (IS_DCOMPLEX) + typedef Matrix, Dynamic, Dynamic, RowMajor> MatrixXcd_; + #endif + #endif + #if defined(IS_FLOAT) + Map > A( ( float* )ap, m, k, stride_a ); + Map > B( ( float* )bp, k, n, stride_b ); + Map > C( ( float* )cp, m, n, stride_c ); + #elif defined (IS_DOUBLE) + Map > A( ( double* )ap, m, k, stride_a ); + Map > B( ( double* )bp, k, n, stride_b ); + Map > C( ( double* )cp, m, n, stride_c ); + #elif defined (IS_SCOMPLEX) + Map > A( ( std::complex* )ap, m, k, stride_a ); + Map > B( ( std::complex* )bp, k, n, stride_b ); + Map > C( ( std::complex* )cp, m, n, stride_c ); + #elif defined (IS_DCOMPLEX) + Map > A( ( std::complex* )ap, m, k, stride_a ); + Map > B( ( std::complex* )bp, k, n, stride_b ); + Map > C( ( std::complex* )cp, m, n, stride_c ); + #endif +#endif + + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + bli_copym( &c_save, &c ); + + dtime = bli_clock(); + +#ifdef PRINT + bli_printm( "a", &a, "%4.1f", "" ); + bli_printm( "b", &b, "%4.1f", "" ); + bli_printm( "c", &c, "%4.1f", "" ); +#endif + +#if defined(BLIS) + + bli_gemm( &alpha, + &a, + &b, + &beta, + &c ); + +#elif defined(EIGEN) + + C.noalias() += alpha_r * A * B; + +#else // if defined(BLAS) + + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* bp = ( float* )bli_obj_buffer( &b ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); + + sgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* bp = ( double* )bli_obj_buffer( &b ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); + + dgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* bp = ( scomplex* )bli_obj_buffer( &b ); + scomplex* betap = ( scomplex* )bli_obj_buffer( &beta ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); + + cgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* bp = ( dcomplex* )bli_obj_buffer( &b ); + dcomplex* betap = ( dcomplex* )bli_obj_buffer( &beta ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); + + zgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } +#endif + +#ifdef PRINT + bli_printm( "c after", &c, "%4.1f", "" ); + exit(1); +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 4.0; + + printf( "data_%s_%cgemm_%s", THR_STR, dt_ch, STR ); + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )m, + ( unsigned long )k, + ( unsigned long )n, gflops ); + //fflush( stdout ); + + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + bli_obj_free( &c_save ); + } + + //bli_finalize(); + + return 0; +} + diff --git a/test/3/Makefile b/test/3/Makefile index 86dc25033e..568b7ffb00 100644 --- a/test/3/Makefile +++ b/test/3/Makefile @@ -46,6 +46,7 @@ # .PHONY: all \ + check-env check-env-mk check-lib \ clean cleanx @@ -81,22 +82,11 @@ endif # -# --- BLAS and LAPACK implementations ------------------------------------------ +# --- BLAS implementations ----------------------------------------------------- # -# BLIS library and header path. This is simply wherever it was installed. -#BLIS_LIB_PATH := $(INSTALL_PREFIX)/lib -#BLIS_INC_PATH := $(INSTALL_PREFIX)/include/blis - -# BLIS library. -#BLIS_LIB := $(BLIS_LIB_PATH)/libblis.a - # BLAS library path(s). This is where the BLAS libraries reside. -HOME_LIB_PATH := $(HOME)/flame/lib -#VENDOR_LIB_PATH := /opt/apps/intel/13/composer_xe_2013.2.146/mkl/lib/intel64 -MKL_LIB_PATH := $(HOME)/intel/mkl/lib/intel64 -#VENDOR_LIB_PATH := ${MKLROOT}/lib/intel64 -#ICC_LIB_PATH := /opt/apps/intel/13/composer_xe_2013.2.146/compiler/lib/intel64 +HOME_LIB_PATH := $(HOME)/flame/lib # OpenBLAS OPENBLAS_LIB := $(HOME_LIB_PATH)/libopenblas.a @@ -106,7 +96,13 @@ OPENBLASP_LIB := $(HOME_LIB_PATH)/libopenblasp.a #ATLAS_LIB := $(HOME_LIB_PATH)/libf77blas.a \ # $(HOME_LIB_PATH)/libatlas.a +# Eigen +EIGEN_INC := $(HOME)/flame/eigen/include/eigen3 +EIGEN_LIB := $(HOME_LIB_PATH)/libeigen_blas_static.a +EIGENP_LIB := $(EIGEN_LIB) + # MKL +MKL_LIB_PATH := $(HOME)/intel/mkl/lib/intel64 MKL_LIB := -L$(MKL_LIB_PATH) \ -lmkl_intel_lp64 \ -lmkl_core \ @@ -174,31 +170,31 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) + +# Define a set of CFLAGS for use with C++ and Eigen. +CXXFLAGS := $(subst -std=c99,-std=c++11,$(CFLAGS)) +CXXFLAGS += -I$(EIGEN_INC) + +# Create a copy of CXXFLAGS without -fopenmp in order to disable multithreading. +CXXFLAGS_ST := -march=native $(subst -fopenmp,,$(CXXFLAGS)) +CXXFLAGS_MT := -march=native $(CXXFLAGS) # Which library? BLI_DEF := -DBLIS BLA_DEF := -DBLAS +EIG_DEF := -DEIGEN # Complex implementation type -D3MHW := -DIND=BLIS_3MH -D3M1 := -DIND=BLIS_3M1 -D4MHW := -DIND=BLIS_4MH -D4M1B := -DIND=BLIS_4M1B -D4M1A := -DIND=BLIS_4M1A D1M := -DIND=BLIS_1M DNAT := -DIND=BLIS_NAT # Implementation string -#STR_3MHW := -DSTR=\"3mhw\" -#STR_3M1 := -DSTR=\"3m1\" -#STR_4MHW := -DSTR=\"4mhw\" -#STR_4M1B := -DSTR=\"4m1b\" -#STR_4M1A := -DSTR=\"4m1a\" #STR_1M := -DSTR=\"1m\" STR_NAT := -DSTR=\"asm_blis\" STR_OBL := -DSTR=\"openblas\" +STR_EIG := -DSTR=\"eigen\" STR_VEN := -DSTR=\"vendor\" # Single or multithreaded string @@ -220,13 +216,14 @@ PDEF_2S := -DP_BEGIN=$(P2_BEGIN) -DP_INC=$(P2_INC) -DP_MAX=$(P2_MAX) all: all-st all-1s all-2s blis: blis-st blis-1s blis-2s openblas: openblas-st openblas-1s openblas-2s +eigen: eigen-st eigen-1s eigen-2s vendor: vendor-st vendor-1s vendor-2s mkl: vendor armpl: vendor -all-st: blis-st openblas-st mkl-st -all-1s: blis-1s openblas-1s mkl-1s -all-2s: blis-2s openblas-2s mkl-2s +all-st: blis-st openblas-st mkl-st eigen-st +all-1s: blis-1s openblas-1s mkl-1s eigen-1s +all-2s: blis-2s openblas-2s mkl-2s eigen-2s blis-st: blis-nat-st blis-1s: blis-nat-1s @@ -236,9 +233,10 @@ blis-2s: blis-nat-2s blis-nat: blis-nat-st blis-nat-1s blis-nat-2s # Define the datatypes, operations, and implementations. -DTS := s d c z -OPS := gemm hemm herk trmm trsm -IMPLS := asm_blis openblas vendor +DTS := s d c z +OPS := gemm hemm herk trmm trsm +BIMPLS := asm_blis openblas vendor +EIMPLS := eigen # Define functions to construct object filenames from the datatypes and # operations given an implementation. We define one function for single- @@ -263,6 +261,13 @@ OPENBLAS_1S_BINS := $(patsubst %.o,%.x,$(OPENBLAS_1S_OBJS)) OPENBLAS_2S_OBJS := $(call get-2s-objs,openblas) OPENBLAS_2S_BINS := $(patsubst %.o,%.x,$(OPENBLAS_2S_OBJS)) +EIGEN_ST_OBJS := $(call get-st-objs,eigen) +EIGEN_ST_BINS := $(patsubst %.o,%.x,$(EIGEN_ST_OBJS)) +EIGEN_1S_OBJS := $(call get-1s-objs,eigen) +EIGEN_1S_BINS := $(patsubst %.o,%.x,$(EIGEN_1S_OBJS)) +EIGEN_2S_OBJS := $(call get-2s-objs,eigen) +EIGEN_2S_BINS := $(patsubst %.o,%.x,$(EIGEN_2S_OBJS)) + VENDOR_ST_OBJS := $(call get-st-objs,vendor) VENDOR_ST_BINS := $(patsubst %.o,%.x,$(VENDOR_ST_OBJS)) VENDOR_1S_OBJS := $(call get-1s-objs,vendor) @@ -271,17 +276,21 @@ VENDOR_2S_OBJS := $(call get-2s-objs,vendor) VENDOR_2S_BINS := $(patsubst %.o,%.x,$(VENDOR_2S_OBJS)) # Define some targets associated with the above object/binary files. -blis-nat-st: $(BLIS_NAT_ST_BINS) -blis-nat-1s: $(BLIS_NAT_1S_BINS) -blis-nat-2s: $(BLIS_NAT_2S_BINS) +blis-nat-st: check-env $(BLIS_NAT_ST_BINS) +blis-nat-1s: check-env $(BLIS_NAT_1S_BINS) +blis-nat-2s: check-env $(BLIS_NAT_2S_BINS) -openblas-st: $(OPENBLAS_ST_BINS) -openblas-1s: $(OPENBLAS_1S_BINS) -openblas-2s: $(OPENBLAS_2S_BINS) +openblas-st: check-env $(OPENBLAS_ST_BINS) +openblas-1s: check-env $(OPENBLAS_1S_BINS) +openblas-2s: check-env $(OPENBLAS_2S_BINS) -vendor-st: $(VENDOR_ST_BINS) -vendor-1s: $(VENDOR_1S_BINS) -vendor-2s: $(VENDOR_2S_BINS) +eigen-st: check-env $(EIGEN_ST_BINS) +eigen-1s: check-env $(EIGEN_1S_BINS) +eigen-2s: check-env $(EIGEN_2S_BINS) + +vendor-st: check-env $(VENDOR_ST_BINS) +vendor-1s: check-env $(VENDOR_1S_BINS) +vendor-2s: check-env $(VENDOR_2S_BINS) mkl-st: vendor-st mkl-1s: vendor-1s @@ -293,53 +302,99 @@ armpl-2s: vendor-2s # Mark the object files as intermediate so that make will remove them # automatically after building the binaries on which they depend. -.INTERMEDIATE: $(BLIS_NAT_ST_OBJS) $(OPENBLAS_ST_OBJS) $(VENDOR_ST_OBJS) -.INTERMEDIATE: $(BLIS_NAT_1S_OBJS) $(OPENBLAS_1S_OBJS) $(VENDOR_1S_OBJS) -.INTERMEDIATE: $(BLIS_NAT_2S_OBJS) $(OPENBLAS_2S_OBJS) $(VENDOR_2S_OBJS) +.INTERMEDIATE: $(BLIS_NAT_ST_OBJS) $(BLIS_NAT_1S_OBJS) $(BLIS_NAT_2S_OBJS) +.INTERMEDIATE: $(OPENBLAS_ST_OBJS) $(OPENBLAS_1S_OBJS) $(OPENBLAS_2S_OBJS) +.INTERMEDIATE: $(EIGEN_ST_OBJS) $(EIGEN_1S_OBJS) $(EIGEN_2S_OBJS) +.INTERMEDIATE: $(VENDOR_ST_OBJS) $(VENDOR_1S_OBJS) $(VENDOR_2S_OBJS) -# --Object file rules -- +# -- Object file rules -- #$(TEST_OBJ_PATH)/%.o: $(TEST_SRC_PATH)/%.c # $(CC) $(CFLAGS) -c $< -o $@ # A function to return the datatype cpp macro def from the datatype # character. -get-dt-cpp = -DDT=bli_$(1)type +get-dt-cpp = $(strip \ + $(if $(findstring s,$(1)),-DDT=BLIS_FLOAT -DIS_FLOAT,\ + $(if $(findstring d,$(1)),-DDT=BLIS_DOUBLE -DIS_DOUBLE,\ + $(if $(findstring c,$(1)),-DDT=BLIS_SCOMPLEX -DIS_SCOMPLEX,\ + -DDT=BLIS_DCOMPLEX -DIS_DCOMPLEX)))) # A function to return other cpp macros that help the test driver # identify the implementation. +#get-bl-cpp = $(strip \ +# $(if $(findstring blis,$(1)),$(STR_NAT) $(BLI_DEF),\ +# $(if $(findstring openblas,$(1)),$(STR_OBL) $(BLA_DEF),\ +# $(if $(findstring eigen,$(1)),$(STR_EIG) $(EIG_DEF),\ +# $(STR_VEN) $(BLA_DEF))))) + get-bl-cpp = $(strip \ $(if $(findstring blis,$(1)),$(STR_NAT) $(BLI_DEF),\ $(if $(findstring openblas,$(1)),$(STR_OBL) $(BLA_DEF),\ - $(STR_VEN) $(BLA_DEF)))) + $(if $(and $(findstring eigen,$(1)),\ + $(findstring gemm,$(2))),\ + $(STR_EIG) $(EIG_DEF),\ + $(if $(findstring eigen,$(1)),\ + $(STR_EIG) $(BLA_DEF),\ + $(STR_VEN) $(BLA_DEF)))))) + +# Rules for BLIS and BLAS libraries. define make-st-rule test_$(1)$(2)_$(PS_MAX)_$(3)_st.o: test_$(op).c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3)) $(DNAT) $(STR_ST) -c $$< -o $$@ + $(CC) $(CFLAGS) $(PDEF_ST) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_ST) -c $$< -o $$@ endef define make-1s-rule test_$(1)$(2)_$(P1_MAX)_$(3)_1s.o: test_$(op).c Makefile - $(CC) $(CFLAGS) $(PDEF_1S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3)) $(DNAT) $(STR_1S) -c $$< -o $$@ + $(CC) $(CFLAGS) $(PDEF_1S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_1S) -c $$< -o $$@ endef define make-2s-rule test_$(1)$(2)_$(P2_MAX)_$(3)_2s.o: test_$(op).c Makefile - $(CC) $(CFLAGS) $(PDEF_2S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3)) $(DNAT) $(STR_2S) -c $$< -o $$@ + $(CC) $(CFLAGS) $(PDEF_2S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_2S) -c $$< -o $$@ +endef + +$(foreach dt,$(DTS), \ +$(foreach op,$(OPS), \ +$(foreach im,$(BIMPLS),$(eval $(call make-st-rule,$(dt),$(op),$(im)))))) + +$(foreach dt,$(DTS), \ +$(foreach op,$(OPS), \ +$(foreach im,$(BIMPLS),$(eval $(call make-1s-rule,$(dt),$(op),$(im)))))) + +$(foreach dt,$(DTS), \ +$(foreach op,$(OPS), \ +$(foreach im,$(BIMPLS),$(eval $(call make-2s-rule,$(dt),$(op),$(im)))))) + +# Rules for Eigen. +define make-eigst-rule +test_$(1)$(2)_$(PS_MAX)_$(3)_st.o: test_$(op).c Makefile + $(CXX) $(CXXFLAGS_ST) $(PDEF_ST) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_ST) -c $$< -o $$@ +endef + +define make-eig1s-rule +test_$(1)$(2)_$(P1_MAX)_$(3)_1s.o: test_$(op).c Makefile + $(CXX) $(CXXFLAGS_MT) $(PDEF_1S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_1S) -c $$< -o $$@ +endef + +define make-eig2s-rule +test_$(1)$(2)_$(P2_MAX)_$(3)_2s.o: test_$(op).c Makefile + $(CXX) $(CXXFLAGS_MT) $(PDEF_2S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_2S) -c $$< -o $$@ endef $(foreach dt,$(DTS), \ $(foreach op,$(OPS), \ -$(foreach im,$(IMPLS),$(eval $(call make-st-rule,$(dt),$(op),$(im)))))) +$(foreach im,$(EIMPLS),$(eval $(call make-eigst-rule,$(dt),$(op),$(im)))))) $(foreach dt,$(DTS), \ $(foreach op,$(OPS), \ -$(foreach im,$(IMPLS),$(eval $(call make-1s-rule,$(dt),$(op),$(im)))))) +$(foreach im,$(EIMPLS),$(eval $(call make-eig1s-rule,$(dt),$(op),$(im)))))) $(foreach dt,$(DTS), \ $(foreach op,$(OPS), \ -$(foreach im,$(IMPLS),$(eval $(call make-2s-rule,$(dt),$(op),$(im)))))) +$(foreach im,$(EIMPLS),$(eval $(call make-eig2s-rule,$(dt),$(op),$(im)))))) # -- Executable file rules -- @@ -349,34 +404,59 @@ $(foreach im,$(IMPLS),$(eval $(call make-2s-rule,$(dt),$(op),$(im)))))) # compatibility layer. This prevents BLIS from inadvertently getting called # for the BLAS routines we are trying to test with. +test_%_$(PS_MAX)_asm_blis_st.x: test_%_$(PS_MAX)_asm_blis_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_$(P1_MAX)_asm_blis_1s.x: test_%_$(P1_MAX)_asm_blis_1s.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_$(P2_MAX)_asm_blis_2s.x: test_%_$(P2_MAX)_asm_blis_2s.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + + test_%_$(PS_MAX)_openblas_st.x: test_%_$(PS_MAX)_openblas_st.o $(LIBBLIS_LINK) - $(LINKER) $(strip $< $(OPENBLAS_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + $(CC) $(strip $< $(OPENBLAS_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) test_%_$(P1_MAX)_openblas_1s.x: test_%_$(P1_MAX)_openblas_1s.o $(LIBBLIS_LINK) - $(LINKER) $(strip $< $(OPENBLASP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + $(CC) $(strip $< $(OPENBLASP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) test_%_$(P2_MAX)_openblas_2s.x: test_%_$(P2_MAX)_openblas_2s.o $(LIBBLIS_LINK) - $(LINKER) $(strip $< $(OPENBLASP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + $(CC) $(strip $< $(OPENBLASP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) -test_%_$(PS_MAX)_vendor_st.x: test_%_$(PS_MAX)_vendor_st.o $(LIBBLIS_LINK) - $(LINKER) $(strip $< $(VENDOR_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) +test_%_$(PS_MAX)_eigen_st.x: test_%_$(PS_MAX)_eigen_st.o $(LIBBLIS_LINK) + $(CXX) $(strip $< $(EIGEN_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) -test_%_$(P1_MAX)_vendor_1s.x: test_%_$(P1_MAX)_vendor_1s.o $(LIBBLIS_LINK) - $(LINKER) $(strip $< $(VENDORP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) +test_%_$(P1_MAX)_eigen_1s.x: test_%_$(P1_MAX)_eigen_1s.o $(LIBBLIS_LINK) + $(CXX) $(strip $< $(EIGENP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) -test_%_$(P2_MAX)_vendor_2s.x: test_%_$(P2_MAX)_vendor_2s.o $(LIBBLIS_LINK) - $(LINKER) $(strip $< $(VENDORP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) +test_%_$(P2_MAX)_eigen_2s.x: test_%_$(P2_MAX)_eigen_2s.o $(LIBBLIS_LINK) + $(CXX) $(strip $< $(EIGENP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) -test_%_$(PS_MAX)_asm_blis_st.x: test_%_$(PS_MAX)_asm_blis_st.o $(LIBBLIS_LINK) - $(LINKER) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) +test_%_$(PS_MAX)_vendor_st.x: test_%_$(PS_MAX)_vendor_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(VENDOR_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) -test_%_$(P1_MAX)_asm_blis_1s.x: test_%_$(P1_MAX)_asm_blis_1s.o $(LIBBLIS_LINK) - $(LINKER) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) +test_%_$(P1_MAX)_vendor_1s.x: test_%_$(P1_MAX)_vendor_1s.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(VENDORP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) -test_%_$(P2_MAX)_asm_blis_2s.x: test_%_$(P2_MAX)_asm_blis_2s.o $(LIBBLIS_LINK) - $(LINKER) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) +test_%_$(P2_MAX)_vendor_2s.x: test_%_$(P2_MAX)_vendor_2s.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(VENDORP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + + +# -- Environment check rules -- + +check-env: check-lib + +check-env-mk: +ifeq ($(CONFIG_MK_PRESENT),no) + $(error Cannot proceed: config.mk not detected! Run configure first) +endif + +check-lib: check-env-mk +ifeq ($(wildcard $(LIBBLIS_LINK)),) + $(error Cannot proceed: BLIS library not yet built! Run make first) +endif # -- Clean rules -- diff --git a/test/3/matlab/plot_panel_4x5.m b/test/3/matlab/plot_panel_4x5.m deleted file mode 100644 index 40e212a687..0000000000 --- a/test/3/matlab/plot_panel_4x5.m +++ /dev/null @@ -1,102 +0,0 @@ -function r_val = plot_panel_4x5( cfreq, ... - dflopspercycle, ... - nth, ... - thr_str, ... - dirpath, ... - arch_str, ... - vend_str ) - -%cfreq = 1.8; -%dflopspercycle = 32; - -% Create filename "templates" for the files that contain the performance -% results. -filetemp_blis = '%s/output_%s_%s_asm_blis.m'; -filetemp_open = '%s/output_%s_%s_openblas.m'; -filetemp_vend = '%s/output_%s_%s_vendor.m'; - -% Create a variable name "template" for the variables contained in the -% files outlined above. -vartemp = 'data_%s_%s_%s( :, : )'; - -% Define the datatypes and operations we will be plotting. -dts = [ 's' 'd' 'c' 'z' ]; -ops( 1, : ) = 'gemm'; -ops( 2, : ) = 'hemm'; -ops( 3, : ) = 'herk'; -ops( 4, : ) = 'trmm'; -ops( 5, : ) = 'trsm'; - -% Generate datatype-specific operation names from the set of operations -% and datatypes. -opnames = gen_opnames( ops, dts ); -n_opnames = size( opnames, 1 ); - -fig = figure('Position', [100, 100, 2000, 1500]); -orient( fig, 'portrait' ); -set(gcf,'PaperUnits', 'inches'); -if 1 == 1 % matlab - set(gcf,'PaperSize', [11 15.0]); - set(gcf,'PaperPosition', [0 0 11 15.0]); - set(gcf,'PaperPositionMode','manual'); -else % octave 4.x - set(gcf,'PaperSize', [15 19.0]); - set(gcf,'PaperPositionMode','auto'); -end -set(gcf,'PaperOrientation','landscape'); - - -% Iterate over the list of datatype-specific operation names. -for opi = 1:n_opnames -%for opi = 1:1 - - % Grab the current datatype combination. - opname = opnames( opi, : ); - - str = sprintf( 'Plotting %d: %s', opi, opname ); disp(str); - - % Construct filenames for the data files from templates. - file_blis = sprintf( filetemp_blis, dirpath, thr_str, opname ); - file_open = sprintf( filetemp_open, dirpath, thr_str, opname ); - file_vend = sprintf( filetemp_vend, dirpath, thr_str, opname ); - - % Load the data files. - %str = sprintf( ' Loading %s', file_blis ); disp(str); - run( file_blis ) - %str = sprintf( ' Loading %s', file_open ); disp(str); - run( file_open ) - %str = sprintf( ' Loading %s', file_vend ); disp(str); - run( file_vend ) - - % Construct variable names for the variables in the data files. - var_blis = sprintf( vartemp, thr_str, opname, 'asm_blis' ); - var_open = sprintf( vartemp, thr_str, opname, 'openblas' ); - var_vend = sprintf( vartemp, thr_str, opname, 'vendor' ); - - % Use eval() to instantiate the variable names constructed above, - % copying each to a simplified name. - data_blis = eval( var_blis ); % e.g. data_st_sgemm_asm_blis( :, : ); - data_open = eval( var_open ); % e.g. data_st_sgemm_openblas( :, : ); - data_vend = eval( var_vend ); % e.g. data_st_sgemm_vendor( :, : ); - - % Plot one result in an m x n grid of plots, via the subplot() - % function. - plot_l3_perf( opname, ... - data_blis, ... - data_open, ... - data_vend, vend_str, ... - nth, ... - 4, 5, ... - cfreq, ... - dflopspercycle, ... - opi ); - -end - -% Construct the name of the file to which we will output the graph. -outfile = sprintf( 'l3_perf_%s_nt%d.pdf', arch_str, nth ); - -% Output the graph to pdf format. -%print(gcf, 'gemm_md','-fillpage','-dpdf'); -print(gcf, outfile,'-bestfit','-dpdf'); - diff --git a/test/3/matlab/runme.m b/test/3/matlab/runme.m deleted file mode 100644 index 2da7d74428..0000000000 --- a/test/3/matlab/runme.m +++ /dev/null @@ -1,19 +0,0 @@ -% tx2 -plot_panel_4x5(2.20,8,1, 'st','../results/tx2/20190205/st', 'tx2', 'ARMPL'); close; clear all; -plot_panel_4x5(2.20,8,28,'1s','../results/tx2/20190205/jc4ic7','tx2_jc4ic7','ARMPL'); close; clear all; -plot_panel_4x5(2.20,8,56,'2s','../results/tx2/20190205/jc8ic7','tx2_jc8ic7','ARMPL'); close; clear all; - -% skx -plot_panel_4x5(2.00,32,1, 'st','../results/skx/20190306/st', 'skx', 'MKL'); close; clear all; -plot_panel_4x5(2.00,32,26,'1s','../results/skx/20190306/jc2ic13','skx_jc2ic13','MKL'); close; clear all; -plot_panel_4x5(2.00,32,52,'2s','../results/skx/20190306/jc4ic13','skx_jc4ic13','MKL'); close; clear all; - -% has -plot_panel_4x5(3.25,16,1, 'st','../results/has/20190206/st', 'has', 'MKL'); close; clear all; -plot_panel_4x5(3.00,16,12,'1s','../results/has/20190206/jc2ic3jr2','has_jc2ic3jr2','MKL'); close; clear all; -plot_panel_4x5(3.00,16,24,'2s','../results/has/20190206/jc4ic3jr2','has_jc4ic3jr2','MKL'); close; clear all; - -% epyc -plot_panel_4x5(3.00,8,1, 'st','../results/epyc/20190306/st', 'epyc', 'MKL'); close; clear all; -plot_panel_4x5(2.55,8,32,'1s','../results/epyc/20190306/jc1ic8jr4','epyc_jc1ic8jr4','MKL'); close; clear all; -plot_panel_4x5(2.55,8,64,'2s','../results/epyc/20190306/jc2ic8jr4','epyc_jc2ic8jr4','MKL'); close; clear all; diff --git a/test/3/matlab/gen_opnames.m b/test/3/octave/gen_opnames.m similarity index 98% rename from test/3/matlab/gen_opnames.m rename to test/3/octave/gen_opnames.m index 26b8e84bb2..814e6e9235 100644 --- a/test/3/matlab/gen_opnames.m +++ b/test/3/octave/gen_opnames.m @@ -20,4 +20,3 @@ r_val = opnames; -end diff --git a/test/3/matlab/plot_l3_perf.m b/test/3/octave/plot_l3_perf.m similarity index 51% rename from test/3/matlab/plot_l3_perf.m rename to test/3/octave/plot_l3_perf.m index 8717fb5ebc..4ac7ab73b6 100644 --- a/test/3/matlab/plot_l3_perf.m +++ b/test/3/octave/plot_l3_perf.m @@ -1,21 +1,29 @@ function r_val = plot_l3_perf( opname, ... data_blis, ... data_open, ... + data_eige, ... data_vend, vend_str, ... nth, ... rows, cols, ... cfreq, ... dfps, ... - theid ) + theid, ... + leg_pos_st, leg_pos_mt, ... + sp_margins ) -if 1 -ax1 = subplot( rows, cols, theid ); -hold( ax1, 'on' ); -end +% Define the column in which the performance rates are found. +flopscol = size( data_blis, 2 ); + +% Define which plot id will have the legend. +% NOTE: We can draw the legend on any graph as long as it has already been +% rendered. Since the coordinates are global, we can simply always wait until +% the final graph to draw the legend. +legend_plot_id = cols*rows; % Set line properties. color_blis = 'k'; lines_blis = '-'; markr_blis = ''; color_open = 'r'; lines_open = '--'; markr_open = 'o'; +color_eige = 'm'; lines_eige = '-.'; markr_eige = 'x'; color_vend = 'b'; lines_vend = '-.'; markr_vend = '.'; % Compute the peak performance in terms of the number of double flops @@ -48,14 +56,11 @@ % Set the legend strings. blis_legend = sprintf( 'BLIS' ); open_legend = sprintf( 'OpenBLAS' ); +eige_legend = sprintf( 'Eigen' ); %vend_legend = sprintf( 'MKL' ); %vend_legend = sprintf( 'ARMPL' ); vend_legend = vend_str; -% Determine the final dimension. -%n_points = size( data_blis, 1 ); -%x_end = data_blis( n_points, 1 ); - % Set axes range values. y_scale = 1.00; x_begin = 0; @@ -64,25 +69,30 @@ y_end = max_perf_core * y_scale; % Set axes names. -xaxisname = ' m = n = k'; if nth == 1 yaxisname = 'GFLOPS'; else yaxisname = 'GFLOPS/core'; end - -%flopscol = 4; -flopscol = size( data_blis, 2 ); +% Set the marker size, font size, and other items. msize = 5; -if 1 -fontsize = 13; +if 0 + xaxisname = ' m = n = k'; + fontsize = 12; else -fontsize = 16; + xaxisname = 'm = n = k'; + fontsize = 20; end -linesize = 0.5; +linesize = 0.8; legend_loc = 'southeast'; +%ax1 = subplot( rows, cols, theid ); +ax1 = subplot_tight( rows, cols, theid, sp_margins ); + +% Hold the axes. +hold( ax1, 'on' ); + % -------------------------------------------------------------------- x_axis( :, 1 ) = data_blis( :, 1 ); @@ -90,15 +100,43 @@ data_peak( 1, 1:2 ) = [ 0 max_perf_core ]; data_peak( 2, 1:2 ) = [ x_end max_perf_core ]; +% Plot the data series for BLIS, which is required. blis_ln = line( x_axis( :, 1 ), data_blis( :, flopscol ) / nth, ... - 'Color',color_blis, 'LineStyle',lines_blis, ... - 'LineWidth',linesize ); -open_ln = line( x_axis( :, 1 ), data_open( :, flopscol ) / nth, ... - 'Color',color_open, 'LineStyle',lines_open, ... - 'LineWidth',linesize ); -vend_ln = line( x_axis( :, 1 ), data_vend( :, flopscol ) / nth, ... - 'Color',color_vend, 'LineStyle',lines_vend, ... - 'LineWidth',linesize ); + 'Color',color_blis, 'LineStyle',lines_blis, ... + 'LineWidth',linesize ); + +% Plot the data series for OpenBLAS, if applicable. +if data_open(1,1) ~= -1 + open_ln = line( x_axis( :, 1 ), data_open( :, flopscol ) / nth, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +else + open_ln = line( nan, nan, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +end + +% Plot the data series for a vendor library, if applicable. +if data_vend(1,1) ~= -1 + vend_ln = line( x_axis( :, 1 ), data_vend( :, flopscol ) / nth, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... + 'LineWidth',linesize ); +else + vend_ln = line( nan, nan, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... + 'LineWidth',linesize ); +end + +% Plot the data series for Eigen, if applicable. +if data_eige(1,1) ~= -1 + eige_ln = line( x_axis( :, 1 ), data_eige( :, flopscol ) / nth, ... + 'Color',color_eige, 'LineStyle',lines_eige, ... + 'LineWidth',linesize ); +else + eige_ln = line( nan, nan, ... + 'Color',color_eige, 'LineStyle',lines_eige, ... + 'LineWidth',linesize ); +end xlim( ax1, [x_begin x_end] ); @@ -120,44 +158,21 @@ if rows == 4 && cols == 5 - if nth == 1 && theid == 3 - leg = legend( ... - [ ... - blis_ln ... - open_ln ... - vend_ln ... - ], ... - blis_legend, ... - open_legend, ... - vend_legend, ... - 'Location', legend_loc ); - set( leg,'Box','off' ); - set( leg,'Color','none' ); - set( leg,'FontSize',fontsize-3 ); - set( leg,'Units','inches' ); - set( leg,'Position',[11.20 12.75 0.7 0.3 ] ); % (0,2br) - elseif nth > 1 && theid == 4 - leg = legend( ... - [ ... - blis_ln ... - open_ln ... - vend_ln ... - ], ... - blis_legend, ... - open_legend, ... - vend_legend, ... - 'Location', legend_loc ); - set( leg,'Box','off' ); - set( leg,'Color','none' ); - set( leg,'FontSize',fontsize-3 ); - set( leg,'Units','inches' ); - %set( leg,'Position',[7.70 12.75 0.7 0.3 ] ); % (0,1br) - %set( leg,'Position',[10.47 14.28 0.7 0.3 ] ); % (0,2tl) - set( leg,'Position',[11.20 12.75 0.7 0.3 ] ); % (0,2br) - %set( leg,'Position',[13.95 14.28 0.7 0.3 ] ); % (0,3tl) - %set( leg,'Position',[14.70 12.75 0.7 0.3 ] ); % (0,3br) - %set( leg,'Position',[17.45 14.28 0.7 0.3 ] ); % (0,4tl) - %set( leg,'Position',[18.22 12.75 0.7 0.3 ] ); % (0,4br) + if nth == 1 && theid == legend_plot_id + + leg = legend( [ blis_ln vend_ln open_ln eige_ln ], ... + blis_legend, vend_legend, open_legend, eige_legend, ... + 'Location', legend_loc ); + set( leg,'Box','off','Color','none','Units','inches','FontSize',fontsize ); + set( leg,'Position',leg_pos_st ); + + elseif nth > 1 && theid == legend_plot_id + + leg = legend( [ blis_ln vend_ln open_ln eige_ln ], ... + blis_legend, vend_legend, open_legend, eige_legend, ... + 'Location', legend_loc ); + set( leg,'Box','off','Color','none','Units','inches','FontSize',fontsize ); + set( leg,'Position',leg_pos_mt ); end end @@ -172,23 +187,18 @@ %tpos(1) = tpos(1) + 100; tpos(1) = tpos(1) + 40; set( titl, 'Position', tpos ); % here we nudge it back to centered with box. +set( titl, 'FontSize', fontsize ); if theid > (rows-1)*cols -xlab = xlabel( ax1,xaxisname ); -%tpos = get( xlab, 'Position' ) -%tpos(2) = tpos(2) + 10; -%set( xlab, 'Position', tpos ); + %tpos = get( xlab, 'Position' ) + %tpos(2) = tpos(2) + 10; + %set( xlab, 'Position', tpos ); + xlab = xlabel( ax1,xaxisname ); end if mod(theid-1,cols) == 0 -ylab = ylabel( ax1,yaxisname ); + ylab = ylabel( ax1,yaxisname ); end -%export_fig( filename, colorflag, '-pdf', '-m2', '-painters', '-transparent' ); -%saveas( fig, filename_png ); - -%hold( ax1, 'off' ); - r_val = 0; -end diff --git a/test/3/octave/plot_panel_4x5.m b/test/3/octave/plot_panel_4x5.m new file mode 100644 index 0000000000..3c21feb636 --- /dev/null +++ b/test/3/octave/plot_panel_4x5.m @@ -0,0 +1,128 @@ +function r_val = plot_panel_4x5 ... + ( ... + cfreq, ... + dflopspercycle, ... + nth, ... + thr_str, ... + dirpath, ... + arch_str, ... + vend_leg_str ... + ) + +impl = 'octave'; +%impl = 'matlab'; + +%sp = 'default'; +subp = 'tight'; + +if strcmp( subp, 'default' ) + position = [100 100 2000 1500]; + papersize = [14.2 19.0]; + leg_pos_st = [3.40 8.70 1.9 1.0 ]; % (0,2br) + leg_pos_mt = [13.08 13.09 1.9 1.0 ]; % (0,3tr) + sp_margins = [ 0.070 0.049 ]; +else + position = [100 100 1864 1540]; + papersize = [15.6 19.4]; + %leg_pos_st = [1.15 8.70 2.1 1.2 ]; % (dgemm) + %leg_pos_st = [1.60 8.80 2.1 1.2 ]; % (dgemm) + leg_pos_st = [15.90 13.60 2.1 1.2 ]; % (strsm) + %leg_pos_mt = [12.20 13.60 2.1 1.2 ]; % (strmm) + %leg_pos_mt = [5.30 12.60 2.1 1.2 ]; % (ssymm) + %leg_pos_mt = [8.50 13.62 2.1 1.2 ]; % (ssyrk) + %leg_pos_mt = [5.30 5.10 2.1 1.2 ]; % (chemm) + leg_pos_mt = [15.90 13.60 2.1 1.2 ]; % (strsm) + sp_margins = [ 0.068 0.051 ]; +end + +%fig = figure('Position', [100, 100, 2000, 1500]); +fig = figure('Position', position); +orient( fig, 'portrait' ); +set(gcf,'PaperUnits', 'inches'); +if strcmp( impl, 'octave' ) + %set(gcf,'PaperSize', [14.2 19.0]); + set(gcf,'PaperSize', papersize); + %set(gcf,'PaperPositionMode','auto'); + set(gcf,'PaperPositionMode','auto'); +else % impl == 'matlab' + set(gcf,'PaperSize', [13 20.0]); + set(gcf,'PaperPosition', [0 0 13 20.0]); + set(gcf,'PaperPositionMode','manual'); +end +set(gcf,'PaperOrientation','landscape'); + +% Define the implementation strings. These appear in both the filenames of the +% files that contain the performance results as well as the variable names +% within those files. +blis_str = 'asm_blis'; +open_str = 'openblas'; +vend_str = 'vendor'; +eige_str = 'eigen'; + +% Create filename "templates" for the files that contain the performance +% results. +filetemp = '%s/output_%s_%s_%s.m'; +filetemp_blis = sprintf( filetemp, '%s', '%s', '%s', blis_str ); +filetemp_open = sprintf( filetemp, '%s', '%s', '%s', open_str ); +filetemp_vend = sprintf( filetemp, '%s', '%s', '%s', vend_str ); +filetemp_eige = sprintf( filetemp, '%s', '%s', '%s', eige_str ); + +% Create a variable name "template" for the variables contained in the +% files outlined above. +vartemp = 'data_%s_%s_%s( :, : )'; + +% Define the datatypes and operations we will be plotting. +dts = [ 's' 'd' 'c' 'z' ]; +ops( 1, : ) = 'gemm'; +ops( 2, : ) = 'hemm'; +ops( 3, : ) = 'herk'; +ops( 4, : ) = 'trmm'; +ops( 5, : ) = 'trsm'; + +% Generate datatype-specific operation names from the set of operations +% and datatypes. +opnames = gen_opnames( ops, dts ); +n_opnames = size( opnames, 1 ); + +% Iterate over the list of datatype-specific operation names. +for opi = 1:n_opnames +%for opi = 1:1 + + % Grab the current datatype combination. + opname = opnames( opi, : ); + + str = sprintf( 'Plotting %d: %s', opi, opname ); disp(str); + + data_blis = read_data( filetemp_blis, dirpath, vartemp, thr_str, opname, blis_str ); + data_open = read_data( filetemp_open, dirpath, vartemp, thr_str, opname, open_str ); + data_vend = read_data( filetemp_vend, dirpath, vartemp, thr_str, opname, vend_str ); + data_eige = read_data( filetemp_eige, dirpath, vartemp, thr_str, opname, eige_str ); + + % Plot one result in an m x n grid of plots, via the subplot() + % function. + plot_l3_perf( opname, ... + data_blis, ... + data_open, ... + data_eige, ... + data_vend, vend_leg_str, ... + nth, ... + 4, 5, ... + cfreq, ... + dflopspercycle, ... + opi, ... + leg_pos_st, leg_pos_mt, ... + sp_margins ); + +end + + +% Construct the name of the file to which we will output the graph. +outfile = sprintf( 'l3_perf_%s_nt%d.pdf', arch_str, nth ); + +% Output the graph to pdf format. +if strcmp( impl, 'octave' ) + print( gcf, outfile ); +else + print( gcf, outfile, '-bestfit', '-dpdf' ); +end + diff --git a/test/3/octave/read_data.m b/test/3/octave/read_data.m new file mode 100644 index 0000000000..28307e6afb --- /dev/null +++ b/test/3/octave/read_data.m @@ -0,0 +1,44 @@ +function data = read_data ... + ( ... + filetemp, ... + dirpath, ... + var_templ, ... + thr_str, ... + opname, ... + impl_str ... + ) + +% Construct the full filepath for the data file from the template. +filepath = sprintf( filetemp, dirpath, thr_str, opname ); + +% Attempt to open the file. +fid = fopen( filepath ); + +if fid == -1 + % If the file was not opened successfully, it's probably because + % the file is missing altogether. In these sitautions, we set the + % first element of the data to -1, which will be a signal to the + % plotting function to omit this curve from the graph. + data(1,1) = -1; +else + % If the file was opened successfully, we assume that it either + % contains valid data, or it adheres to the "missing data" format + % whereby the (1,1) element contains -1. In either case, we can + % process it normally and we begin by closing the file since we + % don't need the file descriptor. + fclose( fid ); + + % Load the data file. + run( filepath ) + + % Construct variable names for the variables in the data file. + % Examples: data_st_dgemm_asm_blis + % data_1s_zherk_vendor + var_name = sprintf( var_templ, thr_str, opname, impl_str ); + + % Use eval() to instantiate the variable names constructed above, + % copying each to a simplified name. + data = eval( var_name ); +end + +% Return the 'data' variable. diff --git a/test/3/octave/runthese.m b/test/3/octave/runthese.m new file mode 100644 index 0000000000..6a88d8b32e --- /dev/null +++ b/test/3/octave/runthese.m @@ -0,0 +1,33 @@ +% tx2 +plot_panel_4x5(2.20,8,1, 'st','../results/tx2/20190205/st', 'tx2', 'ARMPL'); close; clear all; +plot_panel_4x5(2.20,8,28,'1s','../results/tx2/20190205/jc4ic7','tx2_jc4ic7','ARMPL'); close; clear all; +plot_panel_4x5(2.20,8,56,'2s','../results/tx2/20190205/jc8ic7','tx2_jc8ic7','ARMPL'); close; clear all; + +% skx +plot_panel_4x5(2.00,32,1, 'st','../results/skx/merged20190306_0328/st', 'skx', 'MKL'); close; clear all; +plot_panel_4x5(2.00,32,26,'1s','../results/skx/merged20190306_0328/jc2ic13','skx_jc2ic13','MKL'); close; clear all; +plot_panel_4x5(2.00,32,52,'2s','../results/skx/merged20190306_0328/jc4ic13','skx_jc4ic13','MKL'); close; clear all; + +% has +plot_panel_4x5(3.25,16,1, 'st','../results/has/merged20190206_0328/st', 'has', 'MKL'); close; clear all; +plot_panel_4x5(3.00,16,12,'1s','../results/has/merged20190206_0328/jc2ic3jr2','has_jc2ic3jr2','MKL'); close; clear all; +plot_panel_4x5(3.00,16,24,'2s','../results/has/merged20190206_0328/jc4ic3jr2','has_jc4ic3jr2','MKL'); close; clear all; + +% zen +plot_panel_4x5(3.00,8,1, 'st','../results/epyc/merged20190306_0319_0328/st', 'epyc', 'MKL'); close; clear all; +plot_panel_4x5(2.55,8,32,'1s','../results/epyc/merged20190306_0319_0328/jc1ic8jr4','epyc_jc1ic8jr4','MKL'); close; clear all; +plot_panel_4x5(2.55,8,64,'2s','../results/epyc/merged20190306_0319_0328/jc2ic8jr4','epyc_jc2ic8jr4','MKL'); close; clear all; + +% zen2 +plot_panel_4x5(3.40,16,1, 'st','../results/zen2/20200929/st', 'zen2','MKL'); close all; clear all; +plot_panel_4x5(2.60,16,64, '1s','../results/zen2/20200929/jc4ic4jr4','zen2','MKL'); close all; clear all; +plot_panel_4x5(2.60,16,128,'2s','../results/zen2/20200929/jc8ic4jr4','zen2','MKL'); close all; clear all; + +% a64fx +plot_panel_4x5(2.20,32,1, 'st','../results/a64fx/20210520/st', 'a64fx','Fujitsu SSL2'); close all; clear all; +plot_panel_4x5(2.20,32,12,'1s','../results/a64fx/20210520/jc1ic1jr12','a64fx','Fujitsu SSL2'); close all; clear all; +plot_panel_4x5(2.20,32,48,'2s','../results/a64fx/20210520/jc1ic4jr12','a64fx','Fujitsu SSL2'); close all; clear all; + +% nn1 +plot_panel_4x5(2.50,8,1, 'st','../results/neoverse_n1/20210715/st', 'nn1','ARMPL'); close; clear all; +plot_panel_4x5(2.50,8,64,'1s','../results/neoverse_n1/20210715/nt64','nn1','ARMPL'); close; clear all; diff --git a/test/3/octave/subplot_tight.m b/test/3/octave/subplot_tight.m new file mode 100644 index 0000000000..d84ea31888 --- /dev/null +++ b/test/3/octave/subplot_tight.m @@ -0,0 +1,126 @@ +% +% Copyright (c) 2016, Nikolay S. +% All rights reserved. +% +% Redistribution and use in source and binary forms, with or without +% modification, are permitted provided that the following conditions are +% met: +% +% * Redistributions of source code must retain the above copyright +% notice, this list of conditions and the following disclaimer. +% * Redistributions in binary form must reproduce the above copyright +% notice, this list of conditions and the following disclaimer in +% the documentation and/or other materials provided with the distribution +% +% THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +% AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +% IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +% ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +% LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +% CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +% SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +% INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +% CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +% ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +% POSSIBILITY OF SUCH DAMAGE. +% + +function vargout=subplot_tight(m, n, p, margins, varargin) +%% subplot_tight +% A subplot function substitude with margins user tunabble parameter. +% +%% Syntax +% h=subplot_tight(m, n, p); +% h=subplot_tight(m, n, p, margins); +% h=subplot_tight(m, n, p, margins, subplotArgs...); +% +%% Description +% Our goal is to grant the user the ability to define the margins between neighbouring +% subplots. Unfotrtunately Matlab subplot function lacks this functionality, and the +% margins between subplots can reach 40% of figure area, which is pretty lavish. While at +% the begining the function was implememnted as wrapper function for Matlab function +% subplot, it was modified due to axes del;etion resulting from what Matlab subplot +% detected as overlapping. Therefore, the current implmenetation makes no use of Matlab +% subplot function, using axes instead. This can be problematic, as axis and subplot +% parameters are quie different. Set isWrapper to "True" to return to wrapper mode, which +% fully supports subplot format. +% +%% Input arguments (defaults exist): +% margins- two elements vector [vertical,horizontal] defining the margins between +% neighbouring axes. Default value is 0.04 +% +%% Output arguments +% same as subplot- none, or axes handle according to function call. +% +%% Issues & Comments +% - Note that if additional elements are used in order to be passed to subplot, margins +% parameter must be defined. For default margins value use empty element- []. +% - +% +%% Example +% close all; +% img=imread('peppers.png'); +% figSubplotH=figure('Name', 'subplot'); +% figSubplotTightH=figure('Name', 'subplot_tight'); +% nElems=17; +% subplotRows=ceil(sqrt(nElems)-1); +% subplotRows=max(1, subplotRows); +% subplotCols=ceil(nElems/subplotRows); +% for iElem=1:nElems +% figure(figSubplotH); +% subplot(subplotRows, subplotCols, iElem); +% imshow(img); +% figure(figSubplotTightH); +% subplot_tight(subplotRows, subplotCols, iElem, [0.0001]); +% imshow(img); +% end +% +%% See also +% - subplot +% +%% Revision history +% First version: Nikolay S. 2011-03-29. +% Last update: Nikolay S. 2012-05-24. +% +% *List of Changes:* +% 2012-05-24 +% Non wrapping mode (based on axes command) added, to deal with an issue of disappearing +% subplots occuring with massive axes. + +%% Default params +isWrapper=false; +if (nargin<4) || isempty(margins) + margins=[0.04,0.04]; % default margins value- 4% of figure +end +if length(margins)==1 + margins(2)=margins; +end + +%note n and m are switched as Matlab indexing is column-wise, while subplot indexing is row-wise :( +[subplot_col,subplot_row]=ind2sub([n,m],p); + + +height=(1-(m+1)*margins(1))/m; % single subplot height +width=(1-(n+1)*margins(2))/n; % single subplot width + +% note subplot suppors vector p inputs- so a merged subplot of higher dimentions will be created +subplot_cols=1+max(subplot_col)-min(subplot_col); % number of column elements in merged subplot +subplot_rows=1+max(subplot_row)-min(subplot_row); % number of row elements in merged subplot + +merged_height=subplot_rows*( height+margins(1) )- margins(1); % merged subplot height +merged_width= subplot_cols*( width +margins(2) )- margins(2); % merged subplot width + +merged_bottom=(m-max(subplot_row))*(height+margins(1)) +margins(1); % merged subplot bottom position +merged_left=min(subplot_col)*(width+margins(2))-width; % merged subplot left position +pos=[merged_left, merged_bottom, merged_width, merged_height]; + + +if isWrapper + h=subplot(m, n, p, varargin{:}, 'Units', 'Normalized', 'Position', pos); +else + h=axes('Position', pos, varargin{:}); +end + +if nargout==1 + vargout=h; +end diff --git a/test/3/runme.sh b/test/3/runme.sh index aeed7d98b9..56c1928097 100755 --- a/test/3/runme.sh +++ b/test/3/runme.sh @@ -5,11 +5,12 @@ exec_root="test" out_root="output" delay=0.1 -sys="blis" +#sys="blis" #sys="stampede2" #sys="lonestar5" #sys="ul252" #sys="ul264" +sys="ul2128" # Bind threads to processors. #export OMP_PROC_BIND=true @@ -18,27 +19,31 @@ sys="blis" if [ ${sys} = "blis" ]; then - export GOMP_CPU_AFFINITY="0 1 2 3" + export GOMP_CPU_AFFINITY="0-3" + numactl="" threads="jc1ic1jr1_2400 - jc2ic2jr1_4000" + jc2ic3jr2_6000 + jc4ic3jr2_8000" elif [ ${sys} = "stampede2" ]; then echo "Need to set GOMP_CPU_AFFINITY." exit 1 + numactl="" threads="jc1ic1jr1_2400 jc4ic6jr1_6000 jc4ic12jr1_8000" elif [ ${sys} = "lonestar5" ]; then - export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23" + export GOMP_CPU_AFFINITY="0-23" # A hack to use libiomp5 with gcc. #export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/apps/intel/16.0.1.150/compilers_and_libraries_2016.1.150/linux/compiler/lib/intel64" + numactl="" threads="jc1ic1jr1_2400 jc2ic3jr2_6000 jc4ic3jr2_8000" @@ -46,8 +51,9 @@ elif [ ${sys} = "lonestar5" ]; then elif [ ${sys} = "ul252" ]; then export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/field/intel/mkl/lib/intel64" - export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51" + export GOMP_CPU_AFFINITY="0-51" + numactl="" threads="jc1ic1jr1_2400 jc2ic13jr1_6000 jc4ic13jr1_8000" @@ -55,36 +61,67 @@ elif [ ${sys} = "ul252" ]; then elif [ ${sys} = "ul264" ]; then export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/field/intel/mkl/lib/intel64" - export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63" + export GOMP_CPU_AFFINITY="0-63" + numactl="numactl --interleave=all" threads="jc1ic1jr1_2400 jc1ic8jr4_6000 jc2ic8jr4_8000" +elif [ ${sys} = "ul2128" ]; then + + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/field/intel/mkl/lib/intel64" + export GOMP_CPU_AFFINITY="0-127" + + numactl="numactl --interleave=all" + threads="jc1ic1jr1_2400 + jc4ic4jr4_6000 + jc8ic4jr4_8000" + #threads="jc4ic4jr4_6000 + # jc8ic4jr4_8000" + #threads="jc1ic1jr1_2400" + #threads="jc4ic4jr4_6000" + #threads="jc8ic4jr4_8000" fi # Datatypes to test. test_dts="d s z c" +#test_dts="s" # Operations to test. test_ops="gemm hemm herk trmm trsm" +#test_ops="herk" # Implementations to test. -impls="all" -#impls="other" #impls="blis" +#impls="openblas" +#impls="vendor" +#impls="other" +#impls="eigen" +impls="all" if [ "${impls}" = "blis" ]; then test_impls="asm_blis" -elif [ "${impls}" = "other" ]; then +elif [ "${impls}" = "openblas" ]; then + + test_impls="openblas" + +elif [ "${impls}" = "vendor" ]; then + + test_impls="vendor" + +elif [ "${impls}" = "eigen" ]; then - test_impls="openblas vendor" + test_impls="eigen" +elif [ "${impls}" = "other" ]; then + + test_impls="openblas vendor eigen" else - test_impls="openblas asm_blis vendor" + test_impls="openblas asm_blis vendor eigen" fi # Save a copy of GOMP_CPU_AFFINITY so that if we have to unset it, we can @@ -138,6 +175,15 @@ for th in ${threads}; do for op in ${test_ops}; do + # Eigen does not support multithreading for hemm, herk, trmm, + # or trsm. So if we're getting ready to execute an Eigen driver + # for one of these operations and nt > 1, we skip this test. + if [ "${im}" = "eigen" ] && \ + [ "${op}" != "gemm" ] && \ + [ "${nt}" != "1" ]; then + continue; + fi + # Find the threading suffix by probing the executable. binname=$(ls ${exec_root}_${dt}${op}_${psize}_${im}_*.x) suf_ext=${binname##*_} @@ -148,13 +194,24 @@ for th in ${threads}; do # Set the number of threads according to th. if [ "${suf}" = "1s" ] || [ "${suf}" = "2s" ]; then - export BLIS_JC_NT=${jc_nt} - export BLIS_PC_NT=${pc_nt} - export BLIS_IC_NT=${ic_nt} - export BLIS_JR_NT=${jr_nt} - export BLIS_IR_NT=${ir_nt} - export OPENBLAS_NUM_THREADS=${nt} - export MKL_NUM_THREADS=${nt} + # Set the threading parameters based on the implementation + # that we are preparing to run. + if [ "${im}" = "asm_blis" ]; then + unset OMP_NUM_THREADS + export BLIS_JC_NT=${jc_nt} + export BLIS_PC_NT=${pc_nt} + export BLIS_IC_NT=${ic_nt} + export BLIS_JR_NT=${jr_nt} + export BLIS_IR_NT=${ir_nt} + elif [ "${im}" = "openblas" ]; then + unset OMP_NUM_THREADS + export OPENBLAS_NUM_THREADS=${nt} + elif [ "${im}" = "eigen" ]; then + export OMP_NUM_THREADS=${nt} + elif [ "${im}" = "vendor" ]; then + unset OMP_NUM_THREADS + export MKL_NUM_THREADS=${nt} + fi export nt_use=${nt} # Multithreaded OpenBLAS seems to have a problem running @@ -173,6 +230,7 @@ for th in ${threads}; do export BLIS_IC_NT=1 export BLIS_JR_NT=1 export BLIS_IR_NT=1 + export OMP_NUM_THREADS=1 export OPENBLAS_NUM_THREADS=1 export MKL_NUM_THREADS=1 export nt_use=1 @@ -185,11 +243,13 @@ for th in ${threads}; do out_file="${out_root}_${suf}_${dt}${op}_${im}.m" #echo "Running (nt = ${nt_use}) ./${exec_name} > ${out_file}" - echo "Running ./${exec_name} > ${out_file}" + echo "Running ${numactl} ./${exec_name} > ${out_file}" - # Run executable. - ./${exec_name} > ${out_file} + # Run executable with or without numactl, depending on how + # the numactl variable was set. + ${numactl} ./${exec_name} > ${out_file} + # Bedtime! sleep ${delay} done diff --git a/test/3/test_gemm.c b/test/3/test_gemm.c index 67c0a845d3..745dae07c4 100644 --- a/test/3/test_gemm.c +++ b/test/3/test_gemm.c @@ -33,7 +33,18 @@ */ #include -#include "blis.h" +#ifdef EIGEN + #define BLIS_DISABLE_BLAS_DEFS + #include "blis.h" + #include + #include + using namespace Eigen; +#else + #include "blis.h" +#endif + +#define COL_STORAGE +//#define ROW_STORAGE //#define PRINT @@ -69,6 +80,7 @@ int main( int argc, char** argv ) ind = IND; +#if 1 p_begin = P_BEGIN; p_max = P_MAX; p_inc = P_INC; @@ -76,6 +88,15 @@ int main( int argc, char** argv ) m_input = -1; n_input = -1; k_input = -1; +#else + p_begin = 40; + p_max = 1000; + p_inc = 40; + + m_input = -1; + n_input = -1; + k_input = -1; +#endif // Supress compiler warnings about unused variable 'ind'. @@ -87,9 +108,6 @@ int main( int argc, char** argv ) ind_t ind_mod = ind; - // A hack to use 3m1 as 1mpb (with 1m as 1mbp). - if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M; - // Initialize a context for the current induced method and datatype. cntx = bli_gks_query_ind_cntx( ind_mod, dt ); @@ -120,13 +138,14 @@ int main( int argc, char** argv ) printf( "data_%s_%cgemm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_max; p += p_inc ) + //for ( p = p_begin; p <= p_max; p += p_inc ) + for ( p = p_max; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); @@ -139,10 +158,17 @@ int main( int argc, char** argv ) bli_obj_create( dt, 1, 1, 0, 0, &alpha ); bli_obj_create( dt, 1, 1, 0, 0, &beta ); + #ifdef COL_STORAGE bli_obj_create( dt, m, k, 0, 0, &a ); bli_obj_create( dt, k, n, 0, 0, &b ); bli_obj_create( dt, m, n, 0, 0, &c ); bli_obj_create( dt, m, n, 0, 0, &c_save ); + #else + bli_obj_create( dt, m, k, k, 1, &a ); + bli_obj_create( dt, k, n, n, 1, &b ); + bli_obj_create( dt, m, n, n, 1, &c ); + bli_obj_create( dt, m, n, n, 1, &c_save ); + #endif bli_randm( &a ); bli_randm( &b ); @@ -155,12 +181,75 @@ int main( int argc, char** argv ) bli_setsc( (1.0/1.0), 0.0, &beta ); bli_copym( &c, &c_save ); - + #if 0 //def BLIS bli_ind_disable_all_dt( dt ); bli_ind_enable_dt( ind, dt ); #endif +#ifdef EIGEN + double alpha_r, alpha_i; + + bli_getsc( &alpha, &alpha_r, &alpha_i ); + + void* ap = bli_obj_buffer_at_off( &a ); + void* bp = bli_obj_buffer_at_off( &b ); + void* cp = bli_obj_buffer_at_off( &c ); + + #ifdef COL_STORAGE + const int os_a = bli_obj_col_stride( &a ); + const int os_b = bli_obj_col_stride( &b ); + const int os_c = bli_obj_col_stride( &c ); + #else + const int os_a = bli_obj_row_stride( &a ); + const int os_b = bli_obj_row_stride( &b ); + const int os_c = bli_obj_row_stride( &c ); + #endif + + Stride stride_a( os_a, 1 ); + Stride stride_b( os_b, 1 ); + Stride stride_c( os_c, 1 ); + + #ifdef COL_STORAGE + #if defined(IS_FLOAT) + typedef Matrix MatrixXf_; + #elif defined (IS_DOUBLE) + typedef Matrix MatrixXd_; + #elif defined (IS_SCOMPLEX) + typedef Matrix, Dynamic, Dynamic, ColMajor> MatrixXcf_; + #elif defined (IS_DCOMPLEX) + typedef Matrix, Dynamic, Dynamic, ColMajor> MatrixXcd_; + #endif + #else + #if defined(IS_FLOAT) + typedef Matrix MatrixXf_; + #elif defined (IS_DOUBLE) + typedef Matrix MatrixXd_; + #elif defined (IS_SCOMPLEX) + typedef Matrix, Dynamic, Dynamic, RowMajor> MatrixXcf_; + #elif defined (IS_DCOMPLEX) + typedef Matrix, Dynamic, Dynamic, RowMajor> MatrixXcd_; + #endif + #endif + #if defined(IS_FLOAT) + Map > A( ( float* )ap, m, k, stride_a ); + Map > B( ( float* )bp, k, n, stride_b ); + Map > C( ( float* )cp, m, n, stride_c ); + #elif defined (IS_DOUBLE) + Map > A( ( double* )ap, m, k, stride_a ); + Map > B( ( double* )bp, k, n, stride_b ); + Map > C( ( double* )cp, m, n, stride_c ); + #elif defined (IS_SCOMPLEX) + Map > A( ( std::complex* )ap, m, k, stride_a ); + Map > B( ( std::complex* )bp, k, n, stride_b ); + Map > C( ( std::complex* )cp, m, n, stride_c ); + #elif defined (IS_DCOMPLEX) + Map > A( ( std::complex* )ap, m, k, stride_a ); + Map > B( ( std::complex* )bp, k, n, stride_b ); + Map > C( ( std::complex* )cp, m, n, stride_c ); + #endif +#endif + dtime_save = DBL_MAX; for ( r = 0; r < n_repeats; ++r ) @@ -175,7 +264,7 @@ int main( int argc, char** argv ) bli_printm( "c", &c, "%4.1f", "" ); #endif -#ifdef BLIS +#if defined(BLIS) bli_gemm( &alpha, &a, @@ -183,21 +272,25 @@ int main( int argc, char** argv ) &beta, &c ); -#else +#elif defined(EIGEN) + + C.noalias() += alpha_r * A * B; + +#else // if defined(BLAS) if ( bli_is_float( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* bp = bli_obj_buffer( &b ); - float* betap = bli_obj_buffer( &beta ); - float* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* bp = ( float* )bli_obj_buffer( &b ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); sgemm_( &f77_transa, &f77_transb, @@ -212,17 +305,17 @@ int main( int argc, char** argv ) } else if ( bli_is_double( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* bp = bli_obj_buffer( &b ); - double* betap = bli_obj_buffer( &beta ); - double* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* bp = ( double* )bli_obj_buffer( &b ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); dgemm_( &f77_transa, &f77_transb, @@ -237,17 +330,17 @@ int main( int argc, char** argv ) } else if ( bli_is_scomplex( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* bp = bli_obj_buffer( &b ); - scomplex* betap = bli_obj_buffer( &beta ); - scomplex* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* bp = ( scomplex* )bli_obj_buffer( &b ); + scomplex* betap = ( scomplex* )bli_obj_buffer( &beta ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); cgemm_( &f77_transa, &f77_transb, @@ -262,17 +355,17 @@ int main( int argc, char** argv ) } else if ( bli_is_dcomplex( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* bp = bli_obj_buffer( &b ); - dcomplex* betap = bli_obj_buffer( &beta ); - dcomplex* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* bp = ( dcomplex* )bli_obj_buffer( &b ); + dcomplex* betap = ( dcomplex* )bli_obj_buffer( &beta ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); zgemm_( &f77_transa, &f77_transb, @@ -301,7 +394,7 @@ int main( int argc, char** argv ) printf( "data_%s_%cgemm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )k, ( unsigned long )n, gflops ); diff --git a/test/3/test_hemm.c b/test/3/test_hemm.c index 46cd3708a4..8df46f0f01 100644 --- a/test/3/test_hemm.c +++ b/test/3/test_hemm.c @@ -86,9 +86,6 @@ int main( int argc, char** argv ) ind_t ind_mod = ind; - // A hack to use 3m1 as 1mpb (with 1m as 1mbp). - if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M; - // Initialize a context for the current induced method and datatype. cntx = bli_gks_query_ind_cntx( ind_mod, dt ); @@ -119,12 +116,13 @@ int main( int argc, char** argv ) printf( "data_%s_%chemm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_max; p += p_inc ) + //for ( p = p_begin; p <= p_max; p += p_inc ) + for ( p = p_max; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); @@ -197,11 +195,11 @@ int main( int argc, char** argv ) f77_int lda = bli_obj_col_stride( &a ); f77_int ldb = bli_obj_col_stride( &b ); f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* bp = bli_obj_buffer( &b ); - float* betap = bli_obj_buffer( &beta ); - float* cp = bli_obj_buffer( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* bp = ( float* )bli_obj_buffer( &b ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); ssymm_( &f77_side, &f77_uploa, @@ -220,11 +218,11 @@ int main( int argc, char** argv ) f77_int lda = bli_obj_col_stride( &a ); f77_int ldb = bli_obj_col_stride( &b ); f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* bp = bli_obj_buffer( &b ); - double* betap = bli_obj_buffer( &beta ); - double* cp = bli_obj_buffer( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* bp = ( double* )bli_obj_buffer( &b ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); dsymm_( &f77_side, &f77_uploa, @@ -243,11 +241,19 @@ int main( int argc, char** argv ) f77_int lda = bli_obj_col_stride( &a ); f77_int ldb = bli_obj_col_stride( &b ); f77_int ldc = bli_obj_col_stride( &c ); - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* bp = bli_obj_buffer( &b ); - scomplex* betap = bli_obj_buffer( &beta ); - scomplex* cp = bli_obj_buffer( &c ); +#ifdef EIGEN + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* bp = ( float* )bli_obj_buffer( &b ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); +#else + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* bp = ( scomplex* )bli_obj_buffer( &b ); + scomplex* betap = ( scomplex* )bli_obj_buffer( &beta ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); +#endif chemm_( &f77_side, &f77_uploa, @@ -266,11 +272,19 @@ int main( int argc, char** argv ) f77_int lda = bli_obj_col_stride( &a ); f77_int ldb = bli_obj_col_stride( &b ); f77_int ldc = bli_obj_col_stride( &c ); - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* bp = bli_obj_buffer( &b ); - dcomplex* betap = bli_obj_buffer( &beta ); - dcomplex* cp = bli_obj_buffer( &c ); +#ifdef EIGEN + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* bp = ( double* )bli_obj_buffer( &b ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); +#else + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* bp = ( dcomplex* )bli_obj_buffer( &b ); + dcomplex* betap = ( dcomplex* )bli_obj_buffer( &beta ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); +#endif zhemm_( &f77_side, &f77_uploa, @@ -301,7 +315,7 @@ int main( int argc, char** argv ) printf( "data_%s_%chemm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )n, gflops ); diff --git a/test/3/test_herk.c b/test/3/test_herk.c index 67ebff128e..65dcb9f6cc 100644 --- a/test/3/test_herk.c +++ b/test/3/test_herk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,7 +36,6 @@ #include #include "blis.h" - //#define PRINT int main( int argc, char** argv ) @@ -89,9 +88,6 @@ int main( int argc, char** argv ) ind_t ind_mod = ind; - // A hack to use 3m1 as 1mpb (with 1m as 1mbp). - if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M; - // Initialize a context for the current induced method and datatype. cntx = bli_gks_query_ind_cntx( ind_mod, dt ); @@ -122,12 +118,13 @@ int main( int argc, char** argv ) printf( "data_%s_%cherk_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_max; p += p_inc ) + //for ( p = p_begin; p <= p_max; p += p_inc ) + for ( p = p_max; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); @@ -190,14 +187,14 @@ int main( int argc, char** argv ) if ( bli_is_float( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* betap = bli_obj_buffer( &beta ); - float* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); ssyrk_( &f77_uploc, &f77_transa, @@ -210,14 +207,14 @@ int main( int argc, char** argv ) } else if ( bli_is_double( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* betap = bli_obj_buffer( &beta ); - double* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); dsyrk_( &f77_uploc, &f77_transa, @@ -230,14 +227,21 @@ int main( int argc, char** argv ) } else if ( bli_is_scomplex( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - float* betap = bli_obj_buffer( &beta ); - scomplex* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); +#ifdef EIGEN + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); +#else + float* alphap = ( float* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + float* betap = ( float* )bli_obj_buffer( &beta ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); +#endif cherk_( &f77_uploc, &f77_transa, @@ -250,14 +254,21 @@ int main( int argc, char** argv ) } else if ( bli_is_dcomplex( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - double* betap = bli_obj_buffer( &beta ); - dcomplex* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); +#ifdef EIGEN + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); +#else + double* alphap = ( double* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + double* betap = ( double* )bli_obj_buffer( &beta ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); +#endif zherk_( &f77_uploc, &f77_transa, @@ -284,7 +295,7 @@ int main( int argc, char** argv ) printf( "data_%s_%cherk_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )k, gflops ); diff --git a/test/3/test_trmm.c b/test/3/test_trmm.c index a4ae5ef9b4..425630a2a8 100644 --- a/test/3/test_trmm.c +++ b/test/3/test_trmm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,7 +36,6 @@ #include #include "blis.h" - //#define PRINT int main( int argc, char** argv ) @@ -92,9 +91,6 @@ int main( int argc, char** argv ) ind_t ind_mod = ind; - // A hack to use 3m1 as 1mpb (with 1m as 1mbp). - if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M; - // Initialize a context for the current induced method and datatype. cntx = bli_gks_query_ind_cntx( ind_mod, dt ); @@ -137,12 +133,13 @@ int main( int argc, char** argv ) printf( "data_%s_%ctrmm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_max; p += p_inc ) + //for ( p = p_begin; p <= p_max; p += p_inc ) + for ( p = p_max; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); @@ -152,7 +149,7 @@ int main( int argc, char** argv ) bli_obj_create( dt, 1, 1, 0, 0, &alpha ); - if ( bli_does_trans( side ) ) + if ( bli_is_left( side ) ) bli_obj_create( dt, m, m, 0, 0, &a ); else bli_obj_create( dt, n, n, 0, 0, &a ); @@ -207,9 +204,9 @@ int main( int argc, char** argv ) f77_int kk = bli_obj_width( &c ); f77_int lda = bli_obj_col_stride( &a ); f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* cp = bli_obj_buffer( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* cp = ( float* )bli_obj_buffer( &c ); strmm_( &f77_side, &f77_uploa, @@ -227,9 +224,9 @@ int main( int argc, char** argv ) f77_int kk = bli_obj_width( &c ); f77_int lda = bli_obj_col_stride( &a ); f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* cp = bli_obj_buffer( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* cp = ( double* )bli_obj_buffer( &c ); dtrmm_( &f77_side, &f77_uploa, @@ -247,9 +244,15 @@ int main( int argc, char** argv ) f77_int kk = bli_obj_width( &c ); f77_int lda = bli_obj_col_stride( &a ); f77_int ldc = bli_obj_col_stride( &c ); - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* cp = bli_obj_buffer( &c ); +#ifdef EIGEN + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* cp = ( float* )bli_obj_buffer( &c ); +#else + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); +#endif ctrmm_( &f77_side, &f77_uploa, @@ -263,13 +266,19 @@ int main( int argc, char** argv ) } else if ( bli_is_dcomplex( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); +#ifdef EIGEN + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* cp = ( double* )bli_obj_buffer( &c ); +#else + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); +#endif ztrmm_( &f77_side, &f77_uploa, @@ -300,7 +309,7 @@ int main( int argc, char** argv ) printf( "data_%s_%ctrmm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )n, gflops ); diff --git a/test/3/test_trsm.c b/test/3/test_trsm.c index 88202dec51..678be43308 100644 --- a/test/3/test_trsm.c +++ b/test/3/test_trsm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,7 +36,6 @@ #include #include "blis.h" - //#define PRINT int main( int argc, char** argv ) @@ -92,9 +91,6 @@ int main( int argc, char** argv ) ind_t ind_mod = ind; - // A hack to use 3m1 as 1mpb (with 1m as 1mbp). - if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M; - // Initialize a context for the current induced method and datatype. cntx = bli_gks_query_ind_cntx( ind_mod, dt ); @@ -137,12 +133,13 @@ int main( int argc, char** argv ) printf( "data_%s_%ctrsm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_max; p += p_inc ) + //for ( p = p_begin; p <= p_max; p += p_inc ) + for ( p = p_max; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); @@ -152,7 +149,7 @@ int main( int argc, char** argv ) bli_obj_create( dt, 1, 1, 0, 0, &alpha ); - if ( bli_does_trans( side ) ) + if ( bli_is_left( side ) ) bli_obj_create( dt, m, m, 0, 0, &a ); else bli_obj_create( dt, n, n, 0, 0, &a ); @@ -211,9 +208,9 @@ int main( int argc, char** argv ) f77_int kk = bli_obj_width( &c ); f77_int lda = bli_obj_col_stride( &a ); f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* cp = bli_obj_buffer( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* cp = ( float* )bli_obj_buffer( &c ); strsm_( &f77_side, &f77_uploa, @@ -231,9 +228,9 @@ int main( int argc, char** argv ) f77_int kk = bli_obj_width( &c ); f77_int lda = bli_obj_col_stride( &a ); f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* cp = bli_obj_buffer( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* cp = ( double* )bli_obj_buffer( &c ); dtrsm_( &f77_side, &f77_uploa, @@ -251,9 +248,15 @@ int main( int argc, char** argv ) f77_int kk = bli_obj_width( &c ); f77_int lda = bli_obj_col_stride( &a ); f77_int ldc = bli_obj_col_stride( &c ); - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* cp = bli_obj_buffer( &c ); +#ifdef EIGEN + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* cp = ( float* )bli_obj_buffer( &c ); +#else + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); +#endif ctrsm_( &f77_side, &f77_uploa, @@ -267,13 +270,19 @@ int main( int argc, char** argv ) } else if ( bli_is_dcomplex( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); +#ifdef EIGEN + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* cp = ( double* )bli_obj_buffer( &c ); +#else + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); +#endif ztrsm_( &f77_side, &f77_uploa, @@ -304,7 +313,7 @@ int main( int argc, char** argv ) printf( "data_%s_%ctrsm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )n, gflops ); diff --git a/test/Makefile b/test/Makefile index 799900b586..361cd2ff8c 100644 --- a/test/Makefile +++ b/test/Makefile @@ -1,11 +1,11 @@ # # -# BLIS +# BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2017, Advanced Micro Devices, Inc. +# Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -47,6 +47,7 @@ .PHONY: all \ blis openblas atlas mkl \ + check-env check-env-mk check-lib \ clean cleanx @@ -96,16 +97,11 @@ endif BLAS_LIB_PATH := $(HOME)/flame/lib #MKL_LIB_PATH := /opt/apps/intel/13/composer_xe_2013.2.146/mkl/lib/intel64 #MKL_LIB_PATH := $(HOME)/intel/mkl/lib/intel64 -MKL_LIB_PATH := ${MKLROOT}/lib/intel64 -#ESSL_LIB_PATH := $(HOME)/path/to/essl/changeme +MKL_LIB_PATH := $(HOME)/intel/mkl/lib/intel64 # OpenBLAS OPENBLAS_LIB := $(BLAS_LIB_PATH)/libopenblas.a -# ATLAS -ATLAS_LIB := $(BLAS_LIB_PATH)/libf77blas.a \ - $(BLAS_LIB_PATH)/libatlas.a - # MKL MKL_LIB := -L$(MKL_LIB_PATH) \ -lmkl_intel_lp64 \ @@ -113,18 +109,6 @@ MKL_LIB := -L$(MKL_LIB_PATH) \ -lmkl_sequential \ -lpthread -lm -ldl -# ESSL -# Note: ESSL is named differently for SMP and/or BG -ESSL_TYPE := # This is the 32b library on POWER -#ESSL_TYPE := 6464 # This is the 64b library on POWER -#ESSL_TYPE := bg # This is the 32b single-threaded library on Blue Gene -#ESSL_TYPE := smpbg # This is the 32b multi-threaded library on Blue Gene -ESSL_LIB := $(ESSL_LIB_PATH)/libessl$(ESSL_TYPE).a - -# Accelerate -MAC_LIB := -framework Accelerate - - # # --- General build definitions ------------------------------------------------ @@ -150,7 +134,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) @@ -158,117 +142,32 @@ LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # --- Targets/rules ------------------------------------------------------------ # -# Complete list of possible targets when defining 'all': -# -# blis openblas atlas mkl mac essl -# -#all: blis openblas atlas mkl +# Define the operations we will test. +TEST_OPS := dotv axpyv axpbyv\ + gemv ger hemv her her2 trmv trsv \ + gemm gemm3m gemm_batch hemm herk her2k trmm trsm + +# Optionally test gemmt, which some libraries might not implement. +ifeq ($(BUILD_GEMMT),yes) +TEST_OPS := $(TEST_OPS) gemmt +endif + +# Define a function to create the executable names. +test-bins = $(foreach op, $(TEST_OPS), test_$(op)_$(1).x) + +# Create the list of executables for each implementation. +TEST_BINS_BLIS := $(call test-bins,blis) +TEST_BINS_OPENBLAS := $(call test-bins,openblas) +TEST_BINS_MKL := $(call test-bins,mkl) + + all: blis openblas mkl -blis: test_dotv_blis.x \ - test_axpyv_blis.x \ - test_gemv_blis.x \ - test_ger_blis.x \ - test_hemv_blis.x \ - test_her_blis.x \ - test_her2_blis.x \ - test_trmv_blis.x \ - test_trsv_blis.x \ - \ - test_gemm_blis.x \ - test_hemm_blis.x \ - test_herk_blis.x \ - test_her2k_blis.x \ - test_trmm_blis.x \ - test_trsm_blis.x - -openblas: \ - test_dotv_openblas.x \ - test_axpyv_openblas.x \ - test_gemv_openblas.x \ - test_ger_openblas.x \ - test_hemv_openblas.x \ - test_her_openblas.x \ - test_her2_openblas.x \ - test_trmv_openblas.x \ - test_trsv_openblas.x \ - \ - test_gemm_openblas.x \ - test_hemm_openblas.x \ - test_herk_openblas.x \ - test_her2k_openblas.x \ - test_trmm_openblas.x \ - test_trsm_openblas.x - -atlas: \ - test_dotv_atlas.x \ - test_axpyv_atlas.x \ - test_gemv_atlas.x \ - test_ger_atlas.x \ - test_hemv_atlas.x \ - test_her_atlas.x \ - test_her2_atlas.x \ - test_trmv_atlas.x \ - test_trsv_atlas.x \ - \ - test_gemm_atlas.x \ - test_hemm_atlas.x \ - test_herk_atlas.x \ - test_her2k_atlas.x \ - test_trmm_atlas.x \ - test_trsm_atlas.x - -mkl: test_dotv_mkl.x \ - test_axpyv_mkl.x \ - test_gemv_mkl.x \ - test_ger_mkl.x \ - test_hemv_mkl.x \ - test_her_mkl.x \ - test_her2_mkl.x \ - test_trmv_mkl.x \ - test_trsv_mkl.x \ - \ - test_gemm_mkl.x \ - test_hemm_mkl.x \ - test_herk_mkl.x \ - test_her2k_mkl.x \ - test_trmm_mkl.x \ - test_trsm_mkl.x - -essl: test_dotv_essl.x \ - test_axpyv_essl.x \ - test_gemv_essl.x \ - test_ger_essl.x \ - test_hemv_essl.x \ - test_her_essl.x \ - test_her2_essl.x \ - test_trmv_essl.x \ - test_trsv_essl.x \ - \ - test_gemm_essl.x \ - test_hemm_essl.x \ - test_herk_essl.x \ - test_her2k_essl.x \ - test_trmm_essl.x \ - test_trsm_essl.x - -mac: test_dotv_mac.x \ - test_axpyv_mac.x \ - test_gemv_mac.x \ - test_ger_mac.x \ - test_hemv_mac.x \ - test_her_mac.x \ - test_her2_mac.x \ - test_trmv_mac.x \ - test_trsv_mac.x \ - \ - test_gemm_mac.x \ - test_hemm_mac.x \ - test_herk_mac.x \ - test_her2k_mac.x \ - test_trmm_mac.x \ - test_trsm_mac.x +blis: check-env $(TEST_BINS_BLIS) + +openblas: check-env $(TEST_BINS_OPENBLAS) +mkl: check-env $(TEST_BINS_MKL) # --Object file rules -- @@ -276,21 +175,13 @@ mac: test_dotv_mac.x \ $(TEST_OBJ_PATH)/%.o: $(TEST_SRC_PATH)/%.c $(CC) $(CFLAGS) -c $< -o $@ + test_%_openblas.o: test_%.c $(CC) $(CFLAGS) -DBLAS=\"openblas\" -c $< -o $@ -test_%_atlas.o: test_%.c - $(CC) $(CFLAGS) -DBLAS=\"atlas\" -c $< -o $@ - test_%_mkl.o: test_%.c $(CC) $(CFLAGS) -DBLAS=\"mkl\" -c $< -o $@ -test_%_essl.o: test_%.c - $(CC) $(CFLAGS) -DBLAS=\"essl\" -c $< -o $@ - -test_%_mac.o: test_%.c - $(CC) $(CFLAGS) -DBLAS=\"mac\" -c $< -o $@ - test_%_blis.o: test_%.c $(CC) $(CFLAGS) -DBLIS -c $< -o $@ @@ -305,22 +196,28 @@ test_%_blis.o: test_%.c test_%_openblas.x: test_%_openblas.o $(LIBBLIS_LINK) $(LINKER) $< $(OPENBLAS_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ -test_%_atlas.x: test_%_atlas.o $(LIBBLIS_LINK) - $(LINKER) $< $(ATLAS_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ - test_%_mkl.x: test_%_mkl.o $(LIBBLIS_LINK) $(LINKER) $< $(MKL_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ -test_%_essl.x: test_%_essl.o $(LIBBLIS_LINK) - $(LINKER) $< $(ESSL_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ - -test_%_mac.x: test_%_mac.o $(LIBBLIS_LINK) - $(LINKER) $< $(MAC_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ - test_%_blis.x: test_%_blis.o $(LIBBLIS_LINK) $(LINKER) $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@ +# -- Environment check rules -- + +check-env: check-lib + +check-env-mk: +ifeq ($(CONFIG_MK_PRESENT),no) + $(error Cannot proceed: config.mk not detected! Run configure first) +endif + +check-lib: check-env-mk +ifeq ($(wildcard $(LIBBLIS_LINK)),) + $(error Cannot proceed: BLIS library not yet built! Run make first) +endif + + # -- Clean rules -- clean: cleanx diff --git a/test/exec_sizes/Makefile b/test/exec_sizes/Makefile index ca84863539..eefc899186 100644 --- a/test/exec_sizes/Makefile +++ b/test/exec_sizes/Makefile @@ -143,7 +143,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) diff --git a/test/mixeddt/Makefile b/test/mixeddt/Makefile index 7ae4cb9342..20e5378ffb 100644 --- a/test/mixeddt/Makefile +++ b/test/mixeddt/Makefile @@ -127,7 +127,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Which library? @@ -140,11 +140,11 @@ STR_MT := -DTHR_STR=\"mt\" # Problem size specification PDEF_ST := -DP_BEGIN=40 \ - -DP_END=2000 \ + -DP_MAX=2000 \ -DP_INC=40 PDEF_MT := -DP_BEGIN=160 \ - -DP_END=8000 \ + -DP_MAX=8000 \ -DP_INC=160 # Enumerate possible datatypes and computation precisions. diff --git a/test/mixeddt/test_gemm.c b/test/mixeddt/test_gemm.c index ea45a7c141..b6a59a5550 100644 --- a/test/mixeddt/test_gemm.c +++ b/test/mixeddt/test_gemm.c @@ -66,18 +66,18 @@ int main( int argc, char** argv ) num_t dtc = DTC; num_t dtx = DTX; - const bool_t a_real = bli_is_real( dta ); - const bool_t b_real = bli_is_real( dtb ); - const bool_t c_real = bli_is_real( dtc ); - const bool_t a_complex = bli_is_complex( dta ); - const bool_t b_complex = bli_is_complex( dtb ); - const bool_t c_complex = bli_is_complex( dtc ); + const bool a_real = bli_is_real( dta ); + const bool b_real = bli_is_real( dtb ); + const bool c_real = bli_is_real( dtc ); + const bool a_complex = bli_is_complex( dta ); + const bool b_complex = bli_is_complex( dtb ); + const bool c_complex = bli_is_complex( dtc ); // Extract the precision component of the computation datatype. prec_t comp_prec = bli_dt_prec( dtx ); dim_t p_begin = P_BEGIN; - dim_t p_end = P_END; + dim_t p_max = P_MAX; dim_t p_inc = P_INC; int m_input = -1; @@ -122,12 +122,12 @@ int main( int argc, char** argv ) // Begin with initializing the last entry to zero so that // matlab allocates space for the entire array once up-front. - for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ; + for ( p = p_begin; p + p_inc <= p_max; p += p_inc ) ; //printf( "data_%s_%c%c%c%cgemm_%s", THR_STR, dtc_ch, dta_ch, dtb_ch, dtx_ch, STR ); printf( "data_gemm_%s", STR ); printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, ( unsigned long )0, 0.0 ); @@ -143,7 +143,8 @@ int main( int argc, char** argv ) else if ( c_complex && a_complex && b_complex ) flopsmul = 8.0; - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_max; p += p_inc ) + for ( p = p_max; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); @@ -220,7 +221,7 @@ int main( int argc, char** argv ) //printf( "data_%s_%c%c%c%cgemm_%s", THR_STR, dtc_ch, dta_ch, dtb_ch, dtx_ch, STR ); printf( "data_gemm_%s", STR ); printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )k, ( unsigned long )n, gflops ); @@ -270,9 +271,9 @@ void blas_gemm_md( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c ) dom_t comp_dom; obj_t at, bt, ct; obj_t ar, cr; - bool_t needacc; - bool_t force_proj_a = FALSE; - bool_t force_proj_b = FALSE; + bool needacc; + bool force_proj_a = FALSE; + bool force_proj_b = FALSE; diff --git a/test/other/test_copyv.c b/test/other/test_copyv.c new file mode 100644 index 0000000000..9c80eaf736 --- /dev/null +++ b/test/other/test_copyv.c @@ -0,0 +1,218 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + + + +//#define BLIS_ACCURACY_TEST +#ifdef BLIS_ACCURACY_TEST + +bool_t scompare_result(int n, float *x, int incx, float *y, int incy) { + for (int i = 0; i < n; i++) { + if ((*x) != (*y)) { + printf("%4f != %4f at location %d\n", *x, *y, i); + return FALSE; + } + x += incx; + y += incy; + } + return TRUE; +} + +bool_t dcompare_result(int n, double *x, int incx, double *y, int incy) { + for (int i = 0; i < n; i++) { + if ((*x) != (*y)) { + printf("%4f != %4f at location %d\n", *x, *y, i); + return FALSE; + } + x += incx; + y += incy; + } + return TRUE; +} +#endif + + +int main(int argc, char** argv) +{ + obj_t x, y; + dim_t n; + dim_t p; + dim_t p_begin, p_end, p_inc; + int n_input, sizeof_dt; + int r, n_repeats; + num_t dt; + + double dtime; + double dtime_save; + double Gbps; + + //bli_init(); + + n_repeats = 100000; + +#ifndef PRINT + p_begin = 200; + p_end = 100000; + p_inc = 200; + + n_input = -1; +#else + p_begin = 16; + p_end = 16; + p_inc = 1; + + n_input = 16; +#endif + +#if 1 + // dt = BLIS_FLOAT; + dt = BLIS_DOUBLE; +#else + //dt = BLIS_SCOMPLEX; + dt = BLIS_DCOMPLEX; +#endif + + if (dt == BLIS_DOUBLE) + sizeof_dt = sizeof(double); + else if (dt == BLIS_FLOAT) + sizeof_dt = sizeof(float); + + printf("executable\t n\t GBs per sec\n"); + for (p = p_begin; p <= p_end; p += p_inc) + { + + if (n_input < 0) n = p * (dim_t)abs(n_input); + else n = (dim_t)n_input; + + bli_obj_create(dt, n, 1, 0, 0, &x); + bli_obj_create(dt, n, 1, 0, 0, &y); + bli_randm(&x); + + + dtime_save = DBL_MAX; + + for (r = 0; r < n_repeats; ++r) + { + dtime = bli_clock(); + +#ifdef BLIS + bli_copyv(&x, + &y + ); +#else + if (bli_is_float(dt)) + { + f77_int nn = bli_obj_length(&x); + f77_int incx = bli_obj_vector_inc(&x); + float* xp = bli_obj_buffer(&x); + f77_int incy = bli_obj_vector_inc(&y); + float* yp = bli_obj_buffer(&y); + + scopy_(&nn, + xp, &incx, + yp, &incy); + + } + else if (bli_is_double(dt)) + { + + f77_int nn = bli_obj_length(&x); + f77_int incx = bli_obj_vector_inc(&x); + double* xp = bli_obj_buffer(&x); + f77_int incy = bli_obj_vector_inc(&y); + double* yp = bli_obj_buffer(&y); + + dcopy_(&nn, + xp, &incx, + yp, &incy + ); + } +#endif + dtime_save = bli_clock_min_diff(dtime_save, dtime); +#ifdef BLIS_ACCURACY_TEST + if (dt == BLIS_FLOAT) { + int nn = bli_obj_length(&x); + int incx = bli_obj_vector_inc(&x); + float* xp = bli_obj_buffer(&x); + int incy = bli_obj_vector_inc(&y); + float* yp = bli_obj_buffer(&y); + if (scompare_result(nn, xp, incx, yp, incy)) + printf("Copy Successful\n"); + else + printf("ALERT!!! Copy Failed\n"); + } + if (dt == BLIS_DOUBLE) { + int nn = bli_obj_length(&x); + int incx = bli_obj_vector_inc(&x); + double* xp = bli_obj_buffer(&x); + int incy = bli_obj_vector_inc(&y); + double* yp = bli_obj_buffer(&y); + if (dcompare_result(nn, xp, incx, yp, incy)) + printf("Copy Successful\n"); + else + printf("ALERT!!! Copy Failed\n"); + } +#endif + } + // Size of the vectors are incrementd by 1000, to test wide range of inputs. + if (p >= 1000) + p_inc = 1000; + + if (p >= 10000) + p_inc = 10000; + Gbps = (n * sizeof_dt) / (dtime_save * 1.0e9); +#ifdef BLIS + printf("data_copyv_blis\t"); +#else + printf("data_copyv_%s\t", BLAS); +#endif + printf("%4lu\t %7.2f\n", + (unsigned long)n, Gbps); + + bli_obj_free(&x); + bli_obj_free(&y); + } + + // bli_finalize(); + + return 0; +} + diff --git a/test/other/test_gemm.c b/test/other/test_gemm.c new file mode 100644 index 0000000000..55e608c480 --- /dev/null +++ b/test/other/test_gemm.c @@ -0,0 +1,392 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + +//#define FILE_IN_OUT +//#define PRINT +//#define MATRIX_INITIALISATION +int main( int argc, char** argv ) +{ + obj_t a, b, c; + obj_t c_save; + obj_t alpha, beta; + dim_t m, n, k; + dim_t p; + dim_t p_begin, p_end, p_inc; + int m_input, n_input, k_input; + num_t dt; + int r, n_repeats; + trans_t transa; + trans_t transb; + f77_char f77_transa; + f77_char f77_transb; + + double dtime; + double dtime_save; + double gflops; +#ifdef FILE_IN_OUT + FILE* fin = NULL; + FILE* fout = NULL; + char gemm = 's'; + +#endif + //bli_init(); + + //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); + + n_repeats = 3; + +#ifndef PRINT + p_begin = 200; + p_end = 2000; + p_inc = 200; + + m_input = -1; + n_input = -1; + k_input = -1; +#else + p_begin = 16; + p_end = 16; + p_inc = 1; + + m_input = 5; + k_input = 6; + n_input = 4; +#endif + +#if 1 + //dt = BLIS_FLOAT; + dt = BLIS_DOUBLE; +#else + //dt = BLIS_SCOMPLEX; + dt = BLIS_DCOMPLEX; +#endif + + transa = BLIS_NO_TRANSPOSE; + transb = BLIS_NO_TRANSPOSE; + + bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + + +#ifdef FILE_IN_OUT + if (argc < 3) + { + printf("Usage: ./test_gemm_XX.x input.csv output.csv\n"); + exit(1); + } + fin = fopen(argv[1], "r"); + if (fin == NULL) + { + printf("Error opening the file %s\n", argv[1]); + exit(1); + } + fout = fopen(argv[2], "w"); + if (fout == NULL) + { + printf("Error opening output file %s\n", argv[2]); + exit(1); + } + fprintf(fout, "m\t k\t n\t cs_a\t cs_b\t cs_c\t gflops\t GEMM_Algo\n"); + + + printf("~~~~~~~~~~_BLAS\t m\t k\t n\t cs_a\t cs_b\t cs_c \t gflops\t GEMM_Algo\n"); + + inc_t cs_a; + inc_t cs_b; + inc_t cs_c; + + while (fscanf(fin, "%lld %lld %lld %lld %lld %lld\n", &m, &k, &n, &cs_a, &cs_b, &cs_c) == 6) + { + if ((m > cs_a) || (k > cs_b) || (m > cs_c)) continue; // leading dimension should be greater than number of rows + + bli_obj_create( dt, 1, 1, 0, 0, &alpha); + bli_obj_create( dt, 1, 1, 0, 0, &beta ); + + bli_obj_create( dt, m, k, 1, cs_a, &a ); + bli_obj_create( dt, k, n, 1, cs_b, &b ); + bli_obj_create( dt, m, n, 1, cs_c, &c ); + bli_obj_create( dt, m, n, 1, cs_c, &c_save ); +#ifdef MATRIX_INITIALISATION + bli_randm( &a ); + bli_randm( &b ); + bli_randm( &c ); +#endif + bli_obj_set_conjtrans( transa, &a); + bli_obj_set_conjtrans( transb, &b); + + //bli_setsc( 0.0, -1, &alpha ); + //bli_setsc( 0.0, 1, &beta ); + + bli_setsc( -1, 0.0, &alpha ); + bli_setsc( 1, 0.0, &beta ); + +#else + for ( p = p_begin; p <= p_end; p += p_inc ) + { + if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); + else m = ( dim_t ) m_input; + if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); + else n = ( dim_t ) n_input; + if ( k_input < 0 ) k = p * ( dim_t )abs(k_input); + else k = ( dim_t ) k_input; + + bli_obj_create( dt, 1, 1, 0, 0, &alpha ); + bli_obj_create( dt, 1, 1, 0, 0, &beta ); + + bli_obj_create( dt, m, k, 0, 0, &a ); + bli_obj_create( dt, k, n, 0, 0, &b ); + bli_obj_create( dt, m, n, 0, 0, &c ); + bli_obj_create( dt, m, n, 0, 0, &c_save ); + + bli_randm( &a ); + bli_randm( &b ); + bli_randm( &c ); + + bli_obj_set_conjtrans( transa, &a ); + bli_obj_set_conjtrans( transb, &b ); + + bli_setsc( (0.9/1.0), 0.2, &alpha ); + bli_setsc( -(1.1/1.0), 0.3, &beta ); + +#endif + bli_copym( &c, &c_save ); + + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + bli_copym( &c_save, &c ); + + + dtime = bli_clock(); + + +#ifdef PRINT + bli_printm( "a", &a, "%4.1f", "" ); + bli_printm( "b", &b, "%4.1f", "" ); + bli_printm( "c", &c, "%4.1f", "" ); +#endif + +#ifdef BLIS + + bli_gemm( &alpha, + &a, + &b, + &beta, + &c ); + +#else + + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* bp = bli_obj_buffer( &b ); + float* betap = bli_obj_buffer( &beta ); + float* cp = bli_obj_buffer( &c ); + + sgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + double* bp = bli_obj_buffer( &b ); + double* betap = bli_obj_buffer( &beta ); + double* cp = bli_obj_buffer( &c ); + + dgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + scomplex* bp = bli_obj_buffer( &b ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* cp = bli_obj_buffer( &c ); + + cgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* bp = bli_obj_buffer( &b ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* cp = bli_obj_buffer( &c ); + + zgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } +#endif + +#ifdef PRINT + bli_printm( "c after", &c, "%4.1f", "" ); + exit(1); +#endif + + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 4.0; + +#ifdef BLIS + printf( "data_gemm_blis" ); +#else + printf( "data_gemm_%s", BLAS ); +#endif + + +#ifdef FILE_IN_OUT + + if ( bli_is_double( dt ) ) { + + if (((m * n) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES/4)) || ((m < (BLIS_SMALL_M_RECT_MATRIX_THRES/2) ) && (k < (BLIS_SMALL_K_RECT_MATRIX_THRES/2) ))) + gemm = 'S'; // small gemm + else gemm = 'N'; // Normal blis gemm + + } + else if (bli_is_float( dt )) { + if (((m * n) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES)) || ((m < BLIS_SMALL_M_RECT_MATRIX_THRES) && (k < BLIS_SMALL_K_RECT_MATRIX_THRES))) + gemm = 'S'; // small gemm + else gemm = 'N'; // normal blis gemm + } + + + + printf("%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f \t %c\n", \ + ( unsigned long )m, + ( unsigned long )k, + ( unsigned long )n, (unsigned long)cs_a, (unsigned long)cs_b, (unsigned long)cs_c, gflops, gemm ); + + + fprintf(fout, "%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f \t %c\n", \ + ( unsigned long )m, + ( unsigned long )k, + ( unsigned long )n, (unsigned long)cs_a, (unsigned long)cs_b, (unsigned long)cs_c, gflops, gemm ); + fflush(fout); + +#else + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )m, + ( unsigned long )k, + ( unsigned long )n, gflops ); +#endif + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + bli_obj_free( &c_save ); + } + + //bli_finalize(); +#ifdef FILE_IN_OUT + fclose(fin); + fclose(fout); +#endif + return 0; +} + diff --git a/test/other/test_scalv.c b/test/other/test_scalv.c new file mode 100644 index 0000000000..22770b2b60 --- /dev/null +++ b/test/other/test_scalv.c @@ -0,0 +1,154 @@ +/* + + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas at Austin nor the names + of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + + +//#define PRINT + +int main(int argc, char** argv) +{ + obj_t a, alpha; + dim_t n, p; + dim_t p_begin, p_end, p_inc; + int n_input; + num_t dt; + int r, n_repeats; + + double dtime; + double dtime_save; + double gflops; + + //bli_init(); + //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); + + n_repeats = 100000; + +#ifndef PRINT + p_begin = 200; + p_end = 100000; + p_inc = 200; + + n_input = -1; +#else + p_begin = 16; + p_end = 16; + p_inc = 1; + + n_input = 4; +#endif + +#if 1 + dt = BLIS_FLOAT; + //dt = BLIS_DOUBLE; +#else + //dt = BLIS_SCOMPLEX; + dt = BLIS_DCOMPLEX; +#endif +#ifdef BLIS + printf( "data_scalv_blis\t n\t gflops\n" ); +#else + printf( "data_scalv_%s\t n\t gflops\n", BLAS ); +#endif + + for (p = p_begin; p <= p_end; p += p_inc) + { + if (n_input < 0) n = p * (dim_t)abs(n_input); + else n = (dim_t)n_input; + + + bli_obj_create(dt, 1, 1, 0, 0, &alpha); + bli_obj_create(dt, 1, n, 0, 0, &a); + + bli_randm(&a); + bli_setsc((2.0), 0.0, &alpha); + dtime_save = DBL_MAX; + + for (r = 0; r < n_repeats; ++r) + { + dtime = bli_clock(); +#ifdef BLIS + bli_scalm(&BLIS_TWO, &a); +#else + if ( bli_is_float( dt ) ) + { + f77_int nn = bli_obj_length( &a ); + f77_int inca = bli_obj_vector_inc( &a ); + float* scalar = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + + sscal_( &nn, scalar, + ap, &inca ); + } + else if ( bli_is_double( dt ) ) + { + f77_int nn = bli_obj_length( &a ); + f77_int inca = bli_obj_vector_inc( &a ); + double* scalar = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + + dscal_( &nn, scalar, + ap, &inca ); + } +#endif + dtime_save = bli_clock_min_diff(dtime_save, dtime); + } +// Size of the vectors are incrementd by 1000, to test wide range of inputs. + if (p == 10000) + p_inc = 10000; + + if (p == 1000) + p_inc = 1000; + + gflops = n / (dtime_save * 1.0e9); +#ifdef BLIS + printf( "data_scalv_blis\t" ); +#else + printf( "data_scalv_%s\t", BLAS ); +#endif + printf(" %4lu\t %7.2f \n", + (unsigned long)n, gflops); + + bli_obj_free(&alpha); + bli_obj_free(&a); + } + return 0; +} + + + diff --git a/test/other/test_swapv.c b/test/other/test_swapv.c new file mode 100644 index 0000000000..1ccebd7439 --- /dev/null +++ b/test/other/test_swapv.c @@ -0,0 +1,185 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + +// n x incx y incy +//void dswap_( int*, double*, int*, double*, int* ); +//#define PRINT + +int main( int argc, char** argv ) +{ + obj_t x, y; + dim_t n; + dim_t p; + dim_t p_begin, p_end, p_inc; + int n_input; + int r, n_repeats; + num_t dt; + + double dtime; + double dtime_save; + double gflops; + + bli_init(); + + n_repeats = 3; + +#ifndef PRINT + p_begin = 40; + p_end = 8000; + p_inc = 40; + + n_input = -1; +#else + p_begin = 16; + p_end = 16; + p_inc = 1; + + n_input = -1; +#endif + +#if 1 + dt = BLIS_FLOAT; + //dt = BLIS_DOUBLE; +#else + //dt = BLIS_SCOMPLEX; + dt = BLIS_DCOMPLEX; +#endif + + // Begin with initializing the last entry to zero so that + // matlab allocates space for the entire array once up-front. + for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ; +#ifdef BLIS + printf( "data_swapv_blis" ); +#else + printf( "data_swapv_%s", BLAS ); +#endif + printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )0, 0.0 ); + + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) + { + + if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); + else n = ( dim_t ) n_input; + + bli_obj_create( dt, n, 1, 0, 0, &x ); + bli_obj_create( dt, n, 1, 0, 0, &y ); + + bli_randm( &x ); + bli_randm( &y ); + + dtime_save = 1.0e9; + + for ( r = 0; r < n_repeats; ++r ) + { + + dtime = bli_clock(); + +#ifdef PRINT + bli_printm( "x", &x, "%4.1f", "" ); + bli_printm( "y", &y, "%4.1f", "" ); +#endif + +#ifdef BLIS + + bli_swapv( &x, + &y + ); +#else + if ( bli_is_float( dt ) ) + { + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + float* xp = bli_obj_buffer( &x ); + float* yp = bli_obj_buffer( &y ); + + sswap_( &nn, + xp, &incx, + yp, &incy ); + + } + else if ( bli_is_double( dt ) ) + { + + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + double* xp = bli_obj_buffer( &x ); + double* yp = bli_obj_buffer( &y ); + + dswap_( &nn, + xp, &incx, + yp, &incy ); + } +#endif + +#ifdef PRINT + bli_printm( "X after", &x, "%4.1f", "" ); + bli_printm( "Y after", &y, "%4.1f", "" ); + + exit(1); +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + gflops = ( n ) / ( dtime_save * 1.0e9 ); + +#ifdef BLIS + printf( "data_swapv_blis" ); +#else + printf( "data_swapv_%s", BLAS ); +#endif + printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )n, gflops ); + + bli_obj_free( &x ); + bli_obj_free( &y ); + } + + bli_finalize(); + + return 0; +} diff --git a/test/other/test_trsm.c b/test/other/test_trsm.c new file mode 100644 index 0000000000..31cc3bbb1e --- /dev/null +++ b/test/other/test_trsm.c @@ -0,0 +1,443 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + +//#define FILE_IN_OUT +#ifdef FILE_IN_OUT +//#define READ_ALL_PARAMS_FROM_FILE +#endif +//#define PRINT + +int main( int argc, char** argv ) +{ + obj_t a, c; + obj_t c_save; + obj_t alpha; + dim_t m, n; + num_t dt; + int r, n_repeats; + side_t side; + uplo_t uploa; + trans_t transa; + diag_t diaga; + f77_char f77_side; + f77_char f77_uploa; + f77_char f77_transa; + f77_char f77_diaga; + + double dtime; + double dtime_save; + double gflops; + +#ifdef FILE_IN_OUT + FILE* fin = NULL; + FILE* fout = NULL; +#else + dim_t p; + dim_t p_begin, p_end, p_inc; + int m_input, n_input; + + //bli_init(); + + //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); + +#ifndef PRINT + p_begin = 200; + p_end = 2000; + p_inc = 200; + + m_input = -1; + n_input = -1; +#else + p_begin = 16; + p_end = 16; + p_inc = 1; + + m_input = 4; + n_input = 4; +#endif +#endif + + n_repeats = 3; + +#if 1 + //dt = BLIS_FLOAT; + dt = BLIS_DOUBLE; +#else + //dt = BLIS_SCOMPLEX; + dt = BLIS_DCOMPLEX; +#endif + +#ifdef FILE_IN_OUT + if(argc < 3) + { + printf("Usage: ./test_trsm_XX.x input.csv output.csv\n"); + exit(1); + } + fin = fopen(argv[1], "r"); + if(fin == NULL) + { + printf("Error opening the file %s\n", argv[1]); + exit(1); + } + + fout = fopen(argv[2], "w"); + if(fout == NULL) + { + printf("Error opening the file %s\n", argv[2]); + exit(1); + } + inc_t cs_a; + inc_t cs_b; +#ifdef READ_ALL_PARAMS_FROM_FILE + char side_c, uploa_c, transa_c, diaga_c; + + fprintf(fout, "side, uploa, transa, diaga, m\t n\t cs_a\t cs_b\t gflops\n"); + + printf("~~~~~~~_BLAS\t side, uploa, transa, diaga, m\t n\t cs_a\t cs_b\t gflops\n"); + + while(fscanf(fin, "%c %c %c %c %ld %ld %ld %ld\n", &side_c, &uploa_c, &transa_c, &diaga_c, &m, &n, &cs_a, &cs_b) == 8) + { + + if( 'l' == side_c|| 'L' == side_c) + side = BLIS_LEFT; + else if('r' == side_c || 'R' == side_c) + side = BLIS_RIGHT; + else + { + printf("Invalid entry for the argument 'side':%c\n",side_c); + continue; + } + + if('l' == uploa_c || 'L' == uploa_c) + uploa = BLIS_LOWER; + else if('u' == uploa_c || 'U' == uploa_c) + uploa = BLIS_UPPER; + else + { + printf("Invalid entry for the argument 'uplo':%c\n",uploa_c); + continue; + } + + if('t' == transa_c || 'T' == transa_c) + transa = BLIS_TRANSPOSE; + else if('n' == transa_c || 'N' == transa_c) + transa = BLIS_NO_TRANSPOSE; + else + { + printf("Invalid entry for the argument 'transa':%c\n",transa_c); + continue; + } + + if('u' == diaga_c || 'U' == diaga_c) + diaga = BLIS_UNIT_DIAG; + else if('n' == diaga_c || 'N' == diaga_c) + diaga = BLIS_NONUNIT_DIAG; + else + { + printf("Invalid entry for the argument 'diaga':%c\n", diaga_c); + continue; + } +#else + + fprintf(fout, "m\t n\t cs_a\t cs_b\t gflops\n"); + + printf("~~~~~~~_BLAS\t m\t n\t cs_a\t cs_b\t gflops\n"); + + while(fscanf(fin, "%ld %ld %ld %ld\n", &m, &n, &cs_a, &cs_b) == 4) + { + + side = BLIS_LEFT; + //side = BLIS_RIGHT; + + uploa = BLIS_LOWER; + //uploa = BLIS_UPPER; + + transa = BLIS_NO_TRANSPOSE; + + diaga = BLIS_NONUNIT_DIAG; + + +#endif + + bli_param_map_blis_to_netlib_side( side, &f77_side ); + bli_param_map_blis_to_netlib_uplo( uploa, &f77_uploa ); + bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + bli_param_map_blis_to_netlib_diag( diaga, &f77_diaga ); + + if(bli_is_left(side) && ((m > cs_a) || (m > cs_b))) continue; //leading dimension should be greater than number of rows + + if(bli_is_right(side) && ((n > cs_a) || (m > cs_b))) continue; //leading dimension should be greater than number of rows + + if ( bli_is_left( side ) ) + bli_obj_create( dt, m, m, 1, m, &a ); + else + bli_obj_create( dt, n, n, 1, n, &a ); + bli_obj_create( dt, m, n, 1, m, &c ); + bli_obj_create( dt, m, n, 1, m, &c_save ); + +#else + + for ( p = p_end; p >= p_begin; p -= p_inc ) + { + if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); + else m = ( dim_t ) m_input; + if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); + else n = ( dim_t ) n_input; + + + side = BLIS_LEFT; + //side = BLIS_RIGHT; + + uploa = BLIS_LOWER; + //uploa = BLIS_UPPER; + + transa = BLIS_NO_TRANSPOSE; + + diaga = BLIS_NONUNIT_DIAG; + + bli_param_map_blis_to_netlib_side( side, &f77_side ); + bli_param_map_blis_to_netlib_uplo( uploa, &f77_uploa ); + bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + bli_param_map_blis_to_netlib_diag( diaga, &f77_diaga ); + + if ( bli_is_left( side ) ) + bli_obj_create( dt, m, m, 0, 0, &a ); + else + bli_obj_create( dt, n, n, 0, 0, &a ); + bli_obj_create( dt, m, n, 0, 0, &c ); + bli_obj_create( dt, m, n, 0, 0, &c_save ); +#endif + + bli_randm( &a ); + bli_randm( &c ); + + bli_obj_set_struc( BLIS_TRIANGULAR, &a ); + bli_obj_set_uplo( uploa, &a ); + bli_obj_set_conjtrans( transa, &a ); + bli_obj_set_diag( diaga, &a ); + + // Randomize A and zero the unstored triangle to ensure the + // implementation reads only from the stored region. + bli_randm( &a ); + bli_mktrim( &a ); + + // Load the diagonal of A to make it more likely to be invertible. + bli_shiftd( &BLIS_TWO, &a ); + + bli_obj_create( dt, 1, 1, 0, 0, &alpha ); + bli_setsc( (2.0/1.0), 1.0, &alpha ); + + + bli_copym( &c, &c_save ); + + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + bli_copym( &c_save, &c ); + + + dtime = bli_clock(); + + +#ifdef PRINT + bli_invertd( &a ); + bli_printm( "a", &a, "%4.1f", "" ); + bli_invertd( &a ); + bli_printm( "c", &c, "%4.1f", "" ); +#endif + +#ifdef BLIS + + bli_trsm( side, + &alpha, + &a, + &c ); +#else + + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* cp = bli_obj_buffer( &c ); + + strsm_( &f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + double* cp = bli_obj_buffer( &c ); + + dtrsm_( &f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + scomplex* cp = bli_obj_buffer( &c ); + + ctrsm_( &f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* cp = bli_obj_buffer( &c ); + + ztrsm_( &f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc ); + } +#endif + +#ifdef PRINT + bli_printm( "c after", &c, "%9.5f", "" ); + exit(1); +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + if ( bli_is_left( side ) ) + gflops = ( 1.0 * m * m * n ) / ( dtime_save * 1.0e9 ); + else + gflops = ( 1.0 * m * n * n ) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 4.0; + +#ifdef BLIS + printf( "data_trsm_blis" ); +#else + printf( "data_trsm_%s", BLAS ); +#endif + +#ifdef FILE_IN_OUT +#ifdef READ_ALL_PARAMS_FROM_FILE + + printf("%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n",side_c, uploa_c, transa_c, diaga_c, + (unsigned long )m, (unsigned long ) n, + (unsigned long )cs_a, (unsigned long )cs_b, + gflops); + + fprintf(fout,"%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", side_c, uploa_c, transa_c, diaga_c, + (unsigned long )m, (unsigned long ) n, + (unsigned long )cs_a, (unsigned long )cs_b, + gflops); +#else + printf("%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long )m, (unsigned long ) n, + (unsigned long )cs_a, (unsigned long )cs_b, + gflops); + fprintf(fout,"%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long )m, (unsigned long ) n, + (unsigned long )cs_a, (unsigned long )cs_b, + gflops); +#endif +fflush(fout); + +#else + printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )m, + ( unsigned long )n, gflops ); +#endif + bli_obj_free( &alpha ); + + bli_obj_free( &a ); + bli_obj_free( &c ); + bli_obj_free( &c_save ); + } + +#ifdef FILE_IN_OUT + fclose(fin); + fclose(fout); +#endif + //bli_finalize(); + + return 0; +} + diff --git a/test/runme.sh b/test/runme.sh index edef984cba..82289cb6a1 100755 --- a/test/runme.sh +++ b/test/runme.sh @@ -9,8 +9,9 @@ l2_ops="gemv ger hemv her her2 trmv trsv" l3_ops="gemm hemm herk her2k trmm trsm" test_ops="${l2_ops} ${l3_ops}" -# Implementations to test -test_impls="openblas atlas mkl blis" +# Implementations to test. +#test_impls="openblas mkl blis" +test_impls="blis" for im in ${test_impls}; do diff --git a/test/studies/skx/Makefile b/test/studies/skx/Makefile index 29134a4ff7..18a82c0ea2 100644 --- a/test/studies/skx/Makefile +++ b/test/studies/skx/Makefile @@ -168,7 +168,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -g -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Datatype diff --git a/test/studies/thunderx2/Makefile b/test/studies/thunderx2/Makefile index c812161d40..ba45ebbe4d 100644 --- a/test/studies/thunderx2/Makefile +++ b/test/studies/thunderx2/Makefile @@ -158,7 +158,7 @@ CFLAGS := $(call get-frame-cflags-for,$(CONFIG_NAME)) CFLAGS += -g -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -lIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Datatype diff --git a/test/sup/Makefile b/test/sup/Makefile new file mode 100644 index 0000000000..670b945498 --- /dev/null +++ b/test/sup/Makefile @@ -0,0 +1,700 @@ +#!/bin/bash +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2019, Advanced Micro Devices, Inc. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +# +# Makefile +# +# Field G. Van Zee +# +# Makefile for standalone BLIS test drivers. +# + +# +# --- Makefile PHONY target definitions ---------------------------------------- +# + +.PHONY: all \ + st mt \ + blissup-st blisconv-st eigen-st openblas-st vendor-st blasfeo-st libxsmm-st \ + blissup-mt blisconv-mt eigen-mt openblas-mt vendor-mt \ + check-env check-env-mk check-lib \ + clean cleanx + + +# +# --- Determine makefile fragment location ------------------------------------- +# + +# Comments: +# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. +# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in +# the second case because CONFIG_NAME is not yet set. +ifneq ($(strip $(BLIS_INSTALL_PATH)),) +LIB_PATH := $(BLIS_INSTALL_PATH)/lib +INC_PATH := $(BLIS_INSTALL_PATH)/include/blis +SHARE_PATH := $(BLIS_INSTALL_PATH)/share/blis +else +DIST_PATH := ../.. +LIB_PATH = ../../lib/$(CONFIG_NAME) +INC_PATH = ../../include/$(CONFIG_NAME) +SHARE_PATH := ../.. +endif + + +# +# --- Include common makefile definitions -------------------------------------- +# + +# Include the common makefile fragment. +-include $(SHARE_PATH)/common.mk + + +# +# --- BLAS and LAPACK implementations ------------------------------------------ +# + +# BLIS library and header path. This is simply wherever it was installed. +#BLIS_LIB_PATH := $(INSTALL_PREFIX)/lib +#BLIS_INC_PATH := $(INSTALL_PREFIX)/include/blis + +# BLIS library. +#BLIS_LIB := $(BLIS_LIB_PATH)/libblis.a + +# BLAS library path(s). This is where the BLAS libraries reside. +HOME_LIB_PATH := $(HOME)/flame/lib +MKL_LIB_PATH := $(HOME)/intel/mkl/lib/intel64 + +# netlib BLAS +NETLIB_LIB := $(HOME_LIB_PATH)/libblas.a + +# OpenBLAS +OPENBLAS_LIB := $(HOME_LIB_PATH)/libopenblas.a +OPENBLASP_LIB := $(HOME_LIB_PATH)/libopenblasp.a + +# BLASFEO +BLASFEO_LIB := $(HOME_LIB_PATH)/libblasfeo.a + +# libxsmm +LIBXSMM_LIB := $(HOME_LIB_PATH)/libxsmm.a -ldl \ + $(NETLIB_LIB) -lgfortran + +# ATLAS +ATLAS_LIB := $(HOME_LIB_PATH)/libf77blas.a \ + $(HOME_LIB_PATH)/libatlas.a + +# Eigen +EIGEN_INC := $(HOME)/flame/eigen/include/eigen3 +EIGEN_LIB := $(HOME_LIB_PATH)/libeigen_blas_static.a +EIGENP_LIB := $(EIGEN_LIB) + +# MKL +MKL_LIB := -L$(MKL_LIB_PATH) \ + -lmkl_intel_lp64 \ + -lmkl_core \ + -lmkl_sequential \ + -lpthread -lm -ldl +MKLP_LIB := -L$(MKL_LIB_PATH) \ + -lmkl_intel_lp64 \ + -lmkl_core \ + -lmkl_gnu_thread \ + -lpthread -lm -ldl -fopenmp + #-L$(ICC_LIB_PATH) \ + #-lgomp + +VENDOR_LIB := $(MKL_LIB) +VENDORP_LIB := $(MKLP_LIB) + + +# +# --- Problem size definitions ------------------------------------------------- +# + +# The problem size range specification is done separately for single-threaded +# and multithreaded execution. Within each threadedness scenario, we allow for +# separate range specifications for cases with: +# - 3L: three large/variable dimensions and no small/constant dimensions +# - 2L: two large/variable dimensions and one small/constant dimension +# - 1L: one large/variable dimension and two small/constant dimensions + +# -- Single-threaded -- + +PS_BEGIN_3L := 2 +PS_MAX_3L := 400 +PS_INC_3L := 2 + +PS_BEGIN_2L := 4 +PS_MAX_2L := 800 +PS_INC_2L := 4 + +PS_BEGIN_1L := 32 +PS_MAX_1L := 6400 +PS_INC_1L := 32 + +# -- Multithreaded -- + +P1_BEGIN_3L := 4 +P1_MAX_3L := 800 +P1_INC_3L := 4 + +P1_BEGIN_2L := 8 +P1_MAX_2L := 1600 +P1_INC_2L := 8 + +P1_BEGIN_1L := 64 +P1_MAX_1L := 12800 +P1_INC_1L := 64 + + +# +# --- General build definitions ------------------------------------------------ +# + +TEST_SRC_PATH := . +TEST_OBJ_PATH := . + +# Gather all local object files. +TEST_OBJS := $(sort $(patsubst $(TEST_SRC_PATH)/%.c, \ + $(TEST_OBJ_PATH)/%.o, \ + $(wildcard $(TEST_SRC_PATH)/*.c))) + +# Override the value of CINCFLAGS so that the value of CFLAGS returned by +# get-frame-cflags-for() is not cluttered up with include paths needed only +# while building BLIS. +CINCFLAGS := -I$(INC_PATH) + +# Use the "framework" CFLAGS for the configuration family. +CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) + +# Add local header paths to CFLAGS. +CFLAGS += -I$(TEST_SRC_PATH) + +# Locate the libblis library to which we will link. +LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) + +# Define a set of CFLAGS for use with C++ and Eigen. +CXXFLAGS := $(subst -std=c99,-std=c++11,$(CFLAGS)) +CXXFLAGS += -I$(EIGEN_INC) + +# Create a copy of CXXFLAGS without -fopenmp in order to disable multithreading. +CXXFLAGS_ST := -march=native $(subst -fopenmp,,$(CXXFLAGS)) +CXXFLAGS_MT := -march=native $(CXXFLAGS) + +# Single or multithreaded string. +STR_ST := -DTHR_STR=\"st\" +STR_MT := -DTHR_STR=\"mt\" + +# Number of trials per problem size. +N_TRIALS := -DN_TRIALS=3 + +# Problem size specification. +PDEF_ST_1L := -DP_BEGIN=$(PS_BEGIN_1L) -DP_MAX=$(PS_MAX_1L) -DP_INC=$(PS_INC_1L) +PDEF_ST_2L := -DP_BEGIN=$(PS_BEGIN_2L) -DP_MAX=$(PS_MAX_2L) -DP_INC=$(PS_INC_2L) +PDEF_ST_3L := -DP_BEGIN=$(PS_BEGIN_3L) -DP_MAX=$(PS_MAX_3L) -DP_INC=$(PS_INC_3L) + +PDEF_MT_1L := -DP_BEGIN=$(P1_BEGIN_1L) -DP_MAX=$(P1_MAX_1L) -DP_INC=$(P1_INC_1L) +PDEF_MT_2L := -DP_BEGIN=$(P1_BEGIN_2L) -DP_MAX=$(P1_MAX_2L) -DP_INC=$(P1_INC_2L) +PDEF_MT_3L := -DP_BEGIN=$(P1_BEGIN_3L) -DP_MAX=$(P1_MAX_3L) -DP_INC=$(P1_INC_3L) + +ifeq ($(E),1) +ERRCHK := -DERROR_CHECK +else +ERRCHK := -DNO_ERROR_CHECK +endif + +# Enumerate possible datatypes and computation precisions. +#dts := s d c z +DTS := s d + +TRANS := n_n \ + n_t \ + t_n \ + t_t + +# While BLIS supports all combinations of row and column storage for matrices +# C, A, and B, the alternatives mostly only support CBLAS APIs, which inherently +# support only "all row-storage" or "all column-storage". Thus, we disable the +# building of those other drivers so that compilation/linking completes sooner. +#STORS := r_r_r \ +# r_r_c \ +# r_c_r \ +# r_c_c \ +# c_r_r \ +# c_r_c \ +# c_c_r \ +# c_c_c +STORS := r_r_r \ + c_c_c + + +SHAPES := l_l_s \ + l_s_l \ + s_l_l \ + s_s_l \ + s_l_s \ + l_s_s \ + l_l_l + +#LDIMS := s l +LDIMS := s + +# Define the small/constant m, n, and k dimensions for single core and multicore +# experiments. +# st, single real +SMS_ST_S := 6 +SNS_ST_S := 16 +SKS_ST_S := 4 +# mt, single real +SMS_MT_S := 6 +SNS_MT_S := 16 +SKS_MT_S := 10 +# st, double real +SMS_ST_D := 6 +SNS_ST_D := 8 +SKS_ST_D := 4 +# mt, double real +SMS_MT_D := 6 +SNS_MT_D := 8 +SKS_MT_D := 10 + + +# +# --- Function definitions ----------------------------------------------------- +# + +# A function to strip the underscores from a list of strings. Note that the +# word 'a_b_c' is transformed into 'abc', NOT 'a b c'. +stripu = $(subst _,,$(1)) + +# Various functions that help us construct the datatype combinations and then +# extract the needed datatype strings and C preprocessor define flags. +get-1of2 = $(word 1,$(subst _, ,$(1))) +get-2of2 = $(word 2,$(subst _, ,$(1))) + +get-1of3 = $(word 1,$(subst _, ,$(1))) +get-2of3 = $(word 2,$(subst _, ,$(1))) +get-3of3 = $(word 3,$(subst _, ,$(1))) + +# A function to return the correct PDEFS variable given the shape string. +# Note that we have different PDEFS for single-threaded and multithreaded. +get-pdefs-st = $(strip $(subst l_l_l,$(PDEF_ST_3L), \ + $(subst l_l_s,$(PDEF_ST_2L), \ + $(subst l_s_l,$(PDEF_ST_2L), \ + $(subst s_l_l,$(PDEF_ST_2L), \ + $(subst s_s_l,$(PDEF_ST_1L), \ + $(subst s_l_s,$(PDEF_ST_1L), \ + $(subst l_s_s,$(PDEF_ST_1L),$(1))))))))) +get-pdefs-mt = $(strip $(subst l_l_l,$(PDEF_MT_3L), \ + $(subst l_l_s,$(PDEF_MT_2L), \ + $(subst l_s_l,$(PDEF_MT_2L), \ + $(subst s_l_l,$(PDEF_MT_2L), \ + $(subst s_s_l,$(PDEF_MT_1L), \ + $(subst s_l_s,$(PDEF_MT_1L), \ + $(subst l_s_s,$(PDEF_MT_1L),$(1))))))))) + +# Datatype defs. +get-dt-cpp = $(strip \ + $(if $(findstring s,$(1)),-DDT=BLIS_FLOAT -DIS_FLOAT,\ + $(if $(findstring d,$(1)),-DDT=BLIS_DOUBLE -DIS_DOUBLE,\ + $(if $(findstring c,$(1)),-DDT=BLIS_SCOMPLEX -DIS_SCOMPLEX,\ + -DDT=BLIS_DCOMPLEX -DIS_DCOMPLEX)))) + +# Transpose defs. +get-tra-defs-a = $(strip $(subst n,-DTRANSA=BLIS_NO_TRANSPOSE -DA_NOTRANS, \ + $(subst t,-DTRANSA=BLIS_TRANSPOSE -DA_TRANS,$(call get-1of2,$(1))))) +get-tra-defs-b = $(strip $(subst n,-DTRANSB=BLIS_NO_TRANSPOSE -DB_NOTRANS, \ + $(subst t,-DTRANSB=BLIS_TRANSPOSE -DB_TRANS,$(call get-2of2,$(1))))) +get-tra-defs = $(call get-tra-defs-a,$(1)) $(call get-tra-defs-b,$(1)) + +# Storage defs. +get-sto-uch-a = $(strip $(subst r,R, \ + $(subst c,C,$(call get-1of3,$(1))))) +get-sto-uch-b = $(strip $(subst r,R, \ + $(subst c,C,$(call get-2of3,$(1))))) +get-sto-uch-c = $(strip $(subst r,R, \ + $(subst c,C,$(call get-3of3,$(1))))) +get-sto-defs = $(strip \ + -DSTOR3=BLIS_$(call get-sto-uch-a,$(1))$(call get-sto-uch-b,$(1))$(call get-sto-uch-c,$(1)) \ + -DA_STOR_$(call get-sto-uch-a,$(1)) \ + -DB_STOR_$(call get-sto-uch-b,$(1)) \ + -DC_STOR_$(call get-sto-uch-c,$(1))) + +# Dimension defs. +get-shape-defs-cm = $(if $(findstring l,$(1)),-DM_DIM=-1,-DM_DIM=$(2)) +get-shape-defs-cn = $(if $(findstring l,$(1)),-DN_DIM=-1,-DN_DIM=$(2)) +get-shape-defs-ck = $(if $(findstring l,$(1)),-DK_DIM=-1,-DK_DIM=$(2)) +get-shape-defs-m = $(call get-shape-defs-cm,$(call get-1of3,$(1)),$(2)) +get-shape-defs-n = $(call get-shape-defs-cn,$(call get-2of3,$(1)),$(2)) +get-shape-defs-k = $(call get-shape-defs-ck,$(call get-3of3,$(1)),$(2)) + +# arguments: 1: shape (w/ underscores) 2: smallm 3: smalln 4: smallk +get-shape-defs = $(strip $(call get-shape-defs-m,$(1),$(2)) \ + $(call get-shape-defs-n,$(1),$(3)) \ + $(call get-shape-defs-k,$(1),$(4))) + +#$(error l_l_s 6 8 4 = $(call get-shape-defs,l_l_s,6,8,4)) + +# Shape-dimension string. +get-shape-str-ch = $(if $(findstring l,$(1)),p,$(2)) +get-shape-str-m = $(call get-shape-str-ch,$(call get-1of3,$(1)),$(2)) +get-shape-str-n = $(call get-shape-str-ch,$(call get-2of3,$(1)),$(2)) +get-shape-str-k = $(call get-shape-str-ch,$(call get-3of3,$(1)),$(2)) + +# arguments: 1: shape (w/ underscores) 2: smallm 3: smalln 4: smallk +get-shape-dim-str = m$(call get-shape-str-m,$(1),$(2))n$(call get-shape-str-n,$(1),$(3))k$(call get-shape-str-k,$(1),$(4)) + +# Implementation defs. +# Define a function to return the appropriate -DSTR=, -D[BLIS|BLAS] cpp flags. +get-imp-defs = $(strip $(subst blissup,-DSTR=\"$(1)\" -DBLIS -DSUP, \ + $(subst blisconv,-DSTR=\"$(1)\" -DBLIS, \ + $(subst eigen,-DSTR=\"$(1)\" -DEIGEN, \ + $(subst openblas,-DSTR=\"$(1)\" -DCBLAS, \ + $(subst blasfeo,-DSTR=\"$(1)\" -DCBLAS, \ + $(subst libxsmm,-DSTR=\"$(1)\" -DBLAS -DXSMM, \ + $(subst vendor,-DSTR=\"$(1)\" -DCBLAS,$(1))))))))) + +# Leading dimension defs. +# Define a function to return the appropriate leading dimension cpp flag. +get-ldim-defs = $(strip $(subst s,-DLDIM_SMALL, \ + $(subst l,-DLDIM_LARGE,$(1)))) + +get-ldim-str = ld$(1) + +# Strip the underscores from the list of transposition and storage combinations. +TRANS0 = $(call stripu,$(TRANS)) +STORS0 = $(call stripu,$(STORS)) + +# Limit BLAS and Eigen to only using all row-stored, or all column-stored +# matrices. Also, limit libxsmm to using all column-stored matrices since it +# does not offer CBLAS interfaces. +BSTORS0 = rrr ccc +ESTORS0 = rrr ccc +XSTORS0 = ccc + + +# +# --- Object and binary file definitons ---------------------------------------- +# + +get-sms-st = $(strip $(if $(findstring s,$(1)),$(SMS_ST_S),\ + $(if $(findstring d,$(1)),$(SMS_ST_D),0))) +get-sks-st = $(strip $(if $(findstring s,$(1)),$(SKS_ST_S),\ + $(if $(findstring d,$(1)),$(SKS_ST_D),0))) +get-sns-st = $(strip $(if $(findstring s,$(1)),$(SNS_ST_S),\ + $(if $(findstring d,$(1)),$(SNS_ST_D),0))) +get-sms-mt = $(strip $(if $(findstring s,$(1)),$(SMS_MT_S),\ + $(if $(findstring d,$(1)),$(SMS_MT_D),0))) +get-sks-mt = $(strip $(if $(findstring s,$(1)),$(SKS_MT_S),\ + $(if $(findstring d,$(1)),$(SKS_MT_D),0))) +get-sns-mt = $(strip $(if $(findstring s,$(1)),$(SNS_MT_S),\ + $(if $(findstring d,$(1)),$(SNS_MT_D),0))) + +get-sms = $(strip $(if $(findstring st,$(1)),$(call get-sms-st,$(2)),\ + $(if $(findstring mt,$(1)),$(call get-sms-mt,$(2)),0))) +get-sks = $(strip $(if $(findstring st,$(1)),$(call get-sks-st,$(2)),\ + $(if $(findstring mt,$(1)),$(call get-sks-mt,$(2)),0))) +get-sns = $(strip $(if $(findstring st,$(1)),$(call get-sns-st,$(2)),\ + $(if $(findstring mt,$(1)),$(call get-sns-mt,$(2)),0))) + +# Define a function to generate an object file name using the various parameter +# lists. +get-objs = $(foreach dt,$(1),$(foreach tr,$(2),$(foreach st,$(3),$(foreach sh,$(4), \ + $(foreach sm,$(call get-sms,$(7),$(dt)), \ + $(foreach sn,$(call get-sns,$(7),$(dt)), \ + $(foreach sk,$(call get-sks,$(7),$(dt)), \ + $(foreach ld,$(5),test_$(dt)gemm_$(tr)_$(st)_$(call get-shape-dim-str,$(sh),$(sm),$(sn),$(sk))_$(call get-ldim-str,$(ld))_$(6)_$(7).o)))))))) + +# -- Single-threaded -- + +# Build a list of object files and binaries for each single-threaded +# implementation using the get-st-objs() function defined above. +BLISSUP_ST_OBJS := $(call get-objs,$(DTS),$(TRANS0),$(STORS0),$(SHAPES),$(LDIMS),blissup,st) +BLISSUP_ST_BINS := $(patsubst %.o,%.x,$(BLISSUP_ST_OBJS)) + +BLISCONV_ST_OBJS := $(call get-objs,$(DTS),$(TRANS0),$(STORS0),$(SHAPES),$(LDIMS),blisconv,st) +BLISCONV_ST_BINS := $(patsubst %.o,%.x,$(BLISCONV_ST_OBJS)) + +EIGEN_ST_OBJS := $(call get-objs,$(DTS),$(TRANS0),$(ESTORS0),$(SHAPES),$(LDIMS),eigen,st) +EIGEN_ST_BINS := $(patsubst %.o,%.x,$(EIGEN_ST_OBJS)) + +OPENBLAS_ST_OBJS := $(call get-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(LDIMS),openblas,st) +OPENBLAS_ST_BINS := $(patsubst %.o,%.x,$(OPENBLAS_ST_OBJS)) + +BLASFEO_ST_OBJS := $(call get-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(LDIMS),blasfeo,st) +BLASFEO_ST_BINS := $(patsubst %.o,%.x,$(BLASFEO_ST_OBJS)) + +LIBXSMM_ST_OBJS := $(call get-objs,$(DTS),$(TRANS0),$(XSTORS0),$(SHAPES),$(LDIMS),libxsmm,st) +LIBXSMM_ST_BINS := $(patsubst %.o,%.x,$(LIBXSMM_ST_OBJS)) + +VENDOR_ST_OBJS := $(call get-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(LDIMS),vendor,st) +VENDOR_ST_BINS := $(patsubst %.o,%.x,$(VENDOR_ST_OBJS)) + +# Mark the object files as intermediate so that make will remove them +# automatically after building the binaries on which they depend. +.INTERMEDIATE: $(BLISSUP_ST_OBJS) \ + $(BLISCONV_ST_OBJS) \ + $(EIGEN_ST_OBJS) \ + $(OPENBLAS_ST_OBJS) \ + $(BLASFEO_ST_OBJS) \ + $(LIBXSMM_ST_OBJS) \ + $(VENDOR_ST_OBJS) + +# -- Multithreaded -- + +# Build a list of object files and binaries for each multithreaded +# implementation using the get-st-objs() function defined above. +BLISSUP_MT_OBJS := $(call get-objs,$(DTS),$(TRANS0),$(STORS0),$(SHAPES),$(LDIMS),blissup,mt) +BLISSUP_MT_BINS := $(patsubst %.o,%.x,$(BLISSUP_MT_OBJS)) + +BLISCONV_MT_OBJS := $(call get-objs,$(DTS),$(TRANS0),$(STORS0),$(SHAPES),$(LDIMS),blisconv,mt) +BLISCONV_MT_BINS := $(patsubst %.o,%.x,$(BLISCONV_MT_OBJS)) + +EIGEN_MT_OBJS := $(call get-objs,$(DTS),$(TRANS0),$(ESTORS0),$(SHAPES),$(LDIMS),eigen,mt) +EIGEN_MT_BINS := $(patsubst %.o,%.x,$(EIGEN_MT_OBJS)) + +OPENBLAS_MT_OBJS := $(call get-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(LDIMS),openblas,mt) +OPENBLAS_MT_BINS := $(patsubst %.o,%.x,$(OPENBLAS_MT_OBJS)) + +VENDOR_MT_OBJS := $(call get-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(LDIMS),vendor,mt) +VENDOR_MT_BINS := $(patsubst %.o,%.x,$(VENDOR_MT_OBJS)) + +#$(error "objs = $(EIGEN_ST_BINS)" ) + +# Mark the object files as intermediate so that make will remove them +# automatically after building the binaries on which they depend. +.INTERMEDIATE: $(BLISSUP_MT_OBJS) \ + $(BLISCONV_MT_OBJS) \ + $(EIGEN_MT_OBJS) \ + $(OPENBLAS_MT_OBJS) \ + $(VENDOR_MT_OBJS) + + +# +# --- High-level targets/rules ------------------------------------------------- +# + +all: st + +# -- Single-threaded -- + +st: blissup-st blisconv-st \ + eigen-st openblas-st blasfeo-st libxsmm-st vendor-st + +blissup-st: check-env $(BLISSUP_ST_BINS) +blisconv-st: check-env $(BLISCONV_ST_BINS) +eigen-st: check-env $(EIGEN_ST_BINS) +openblas-st: check-env $(OPENBLAS_ST_BINS) +blasfeo-st: check-env $(BLASFEO_ST_BINS) +libxsmm-st: check-env $(LIBXSMM_ST_BINS) +vendor-st: check-env $(VENDOR_ST_BINS) + +# -- Multithreaded -- + +mt: blissup-mt blisconv-mt \ + eigen-mt openblas-mt vendor-mt + +blissup-mt: check-env $(BLISSUP_MT_BINS) +blisconv-mt: check-env $(BLISCONV_MT_BINS) +eigen-mt: check-env $(EIGEN_MT_BINS) +openblas-mt: check-env $(OPENBLAS_MT_BINS) +vendor-mt: check-env $(VENDOR_MT_BINS) + + +# --- Object file rules -------------------------------------------------------- + +# Define the implementations for which we will instantiate compilation rules. +BIMPLS_ST := blissup blisconv openblas blasfeo libxsmm vendor +BIMPLS_MT := blissup blisconv openblas vendor +EIMPLS := eigen + +# -- Single-threaded BLAS -- + +# 1 2 3 4 567 8 9 +# test_dgemm_nn_rrr_mpn6kp_lds_blissup_st.x + +# Define the function that will be used to instantiate compilation rules +# for the various single-threaded implementations. +define make-st-rule +test_$(1)gemm_$(call stripu,$(2))_$(call stripu,$(3))_$(call get-shape-dim-str,$(4),$(5),$(6),$(7))_$(call get-ldim-str,$(8))_$(9)_st.o: test_gemm.c Makefile + $(CC) $(CFLAGS) $(ERRCHK) $(N_TRIALS) $(call get-pdefs-st,$(4)) $(call get-dt-cpp,$(1)) $(call get-tra-defs,$(2)) $(call get-sto-defs,$(3)) $(call get-shape-defs,$(4),$(5),$(6),$(7)) $(call get-ldim-defs,$(8)) $(call get-imp-defs,$(9)) $(STR_ST) -c $$< -o $$@ +endef + +# Instantiate the rule function make-st-rule() for each BLIS/BLAS/CBLAS +# implementation. +$(foreach dt,$(DTS), \ +$(foreach tr,$(TRANS), \ +$(foreach st,$(STORS), \ +$(foreach sh,$(SHAPES), \ +$(foreach sm,$(call get-sms,st,$(dt)), \ +$(foreach sn,$(call get-sns,st,$(dt)), \ +$(foreach sk,$(call get-sks,st,$(dt)), \ +$(foreach ld,$(LDIMS), \ +$(foreach impl,$(BIMPLS_ST), \ +$(eval $(call make-st-rule,$(dt),$(tr),$(st),$(sh),$(sm),$(sn),$(sk),$(ld),$(impl)))))))))))) + +# -- Multithreaded BLAS -- + +# Define the function that will be used to instantiate compilation rules +# for the various multithreaded implementations. +define make-mt-rule +test_$(1)gemm_$(call stripu,$(2))_$(call stripu,$(3))_$(call get-shape-dim-str,$(4),$(5),$(6),$(7))_$(call get-ldim-str,$(8))_$(9)_mt.o: test_gemm.c Makefile + $(CC) $(CFLAGS) $(ERRCHK) $(N_TRIALS) $(call get-pdefs-mt,$(4)) $(call get-dt-cpp,$(1)) $(call get-tra-defs,$(2)) $(call get-sto-defs,$(3)) $(call get-shape-defs,$(4),$(5),$(6),$(7)) $(call get-ldim-defs,$(8)) $(call get-imp-defs,$(9)) $(STR_MT) -c $$< -o $$@ +endef + +# Instantiate the rule function make-mt-rule() for each BLIS/BLAS/CBLAS +# implementation. +$(foreach dt,$(DTS), \ +$(foreach tr,$(TRANS), \ +$(foreach st,$(STORS), \ +$(foreach sh,$(SHAPES), \ +$(foreach sm,$(call get-sms,mt,$(dt)), \ +$(foreach sn,$(call get-sns,mt,$(dt)), \ +$(foreach sk,$(call get-sks,mt,$(dt)), \ +$(foreach ld,$(LDIMS), \ +$(foreach impl,$(BIMPLS_MT), \ +$(eval $(call make-mt-rule,$(dt),$(tr),$(st),$(sh),$(sm),$(sn),$(sk),$(ld),$(impl)))))))))))) + +# -- Single-threaded Eigen -- + +# Define the function that will be used to instantiate compilation rules +# for the single-threaded Eigen implementation. +define make-eigst-rule +test_$(1)gemm_$(call stripu,$(2))_$(call stripu,$(3))_$(call get-shape-dim-str,$(4),$(5),$(6),$(7))_$(call get-ldim-str,$(8))_$(9)_st.o: test_gemm.c Makefile + $(CXX) $(CXXFLAGS_ST) $(ERRCHK) $(N_TRIALS) $(call get-pdefs-st,$(4)) $(call get-dt-cpp,$(1)) $(call get-tra-defs,$(2)) $(call get-sto-defs,$(3)) $(call get-shape-defs,$(4),$(5),$(6),$(7)) $(call get-ldim-defs,$(8)) $(call get-imp-defs,$(9)) $(STR_ST) -c $$< -o $$@ +endef + +# Instantiate the rule function make-st-rule() for each Eigen implementation. +$(foreach dt,$(DTS), \ +$(foreach tr,$(TRANS), \ +$(foreach st,$(STORS), \ +$(foreach sh,$(SHAPES), \ +$(foreach sm,$(call get-sms,st,$(dt)), \ +$(foreach sn,$(call get-sns,st,$(dt)), \ +$(foreach sk,$(call get-sks,st,$(dt)), \ +$(foreach ld,$(LDIMS), \ +$(foreach impl,$(EIMPLS), \ +$(eval $(call make-eigst-rule,$(dt),$(tr),$(st),$(sh),$(sm),$(sn),$(sk),$(ld),$(impl)))))))))))) + +# -- Multithreaded Eigen -- + +# Define the function that will be used to instantiate compilation rules +# for the multithreaded Eigen implementation. +define make-eigmt-rule +test_$(1)gemm_$(call stripu,$(2))_$(call stripu,$(3))_$(call get-shape-dim-str,$(4),$(5),$(6),$(7))_$(call get-ldim-str,$(8))_$(9)_mt.o: test_gemm.c Makefile + $(CXX) $(CXXFLAGS_MT) $(ERRCHK) $(N_TRIALS) $(call get-pdefs-mt,$(4)) $(call get-dt-cpp,$(1)) $(call get-tra-defs,$(2)) $(call get-sto-defs,$(3)) $(call get-shape-defs,$(4),$(5),$(6),$(7)) $(call get-ldim-defs,$(8)) $(call get-imp-defs,$(9)) $(STR_MT) -c $$< -o $$@ +endef + +# Instantiate the rule function make-st-rule() for each Eigen implementation. +$(foreach dt,$(DTS), \ +$(foreach tr,$(TRANS), \ +$(foreach st,$(STORS), \ +$(foreach sh,$(SHAPES), \ +$(foreach sm,$(call get-sms,mt,$(dt)), \ +$(foreach sn,$(call get-sns,mt,$(dt)), \ +$(foreach sk,$(call get-sks,mt,$(dt)), \ +$(foreach ld,$(LDIMS), \ +$(foreach impl,$(EIMPLS), \ +$(eval $(call make-eigmt-rule,$(dt),$(tr),$(st),$(sh),$(sm),$(sn),$(sk),$(ld),$(impl)))))))))))) + + +# --- Executable file rules ---------------------------------------------------- + +# NOTE: For the BLAS test drivers, we place the BLAS libraries before BLIS +# on the link command line in case BLIS was configured with the BLAS +# compatibility layer. This prevents BLIS from inadvertently getting called +# for the BLAS routines we are trying to test with. + +# -- Single-threaded -- + +test_%_blissup_st.x: test_%_blissup_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_blisconv_st.x: test_%_blisconv_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_eigen_st.x: test_%_eigen_st.o $(LIBBLIS_LINK) + $(CXX) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_openblas_st.x: test_%_openblas_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(OPENBLAS_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_blasfeo_st.x: test_%_blasfeo_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(BLASFEO_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_libxsmm_st.x: test_%_libxsmm_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBXSMM_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_vendor_st.x: test_%_vendor_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(VENDOR_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +# -- Multithreaded -- + +test_%_blissup_mt.x: test_%_blissup_mt.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_blisconv_mt.x: test_%_blisconv_mt.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_eigen_mt.x: test_%_eigen_mt.o $(LIBBLIS_LINK) + $(CXX) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_openblas_mt.x: test_%_openblas_mt.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(OPENBLASP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_vendor_mt.x: test_%_vendor_mt.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(VENDORP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + + +# -- Environment check rules -- + +check-env: check-lib + +check-env-mk: +ifeq ($(CONFIG_MK_PRESENT),no) + $(error Cannot proceed: config.mk not detected! Run configure first) +endif + +check-lib: check-env-mk +ifeq ($(wildcard $(LIBBLIS_LINK)),) + $(error Cannot proceed: BLIS library not yet built! Run make first) +endif + + +# -- Clean rules -- + +clean: cleanx + +cleanx: + - $(RM_F) *.x *.o + diff --git a/test/sup/octave/bkup/gen_opsupnames.m b/test/sup/octave/bkup/gen_opsupnames.m new file mode 100644 index 0000000000..40258677d0 --- /dev/null +++ b/test/sup/octave/bkup/gen_opsupnames.m @@ -0,0 +1,50 @@ +function [ r_val1, r_val2 ] = gen_opsupnames( ops, stor, smalldims, ldim, pack ) + +nops = size( ops, 1 ); + +smallm = smalldims( 1 ); +smalln = smalldims( 2 ); +smallk = smalldims( 3 ); + +i = 1; + +for io = 1:nops + + op = ops( io, : ); + + % sprintf'ing directly into an array of strings, as in: + % + % opsupnames( i+0, : ) = sprintf( '%s_%s_m%dnpkp_%s_%s', ... ); + % + % doesn't work when the string lengths as they would if any of the constant + % dimensions is greater than 9. + str0 = sprintf( '%s_%s_m%dnpkp_%s_%s', op, stor, smallm, ldim, pack ); + str1 = sprintf( '%s_%s_mpn%dkp_%s_%s', op, stor, smalln, ldim, pack ); + str2 = sprintf( '%s_%s_mpnpk%d_%s_%s', op, stor, smallk, ldim, pack ); + str3 = sprintf( '%s_%s_mpn%dk%d_%s_%s', op, stor, smalln, smallk, ldim, pack ); + str4 = sprintf( '%s_%s_m%dnpk%d_%s_%s', op, stor, smallm, smallk, ldim, pack ); + str5 = sprintf( '%s_%s_m%dn%dkp_%s_%s', op, stor, smallm, smalln, ldim, pack ); + str6 = sprintf( '%s_%s_mpnpkp_%s_%s', op, stor, ldim, pack ); + + opsupnames( i+0, : ) = sprintf( '%-31s', str0 ); + opsupnames( i+1, : ) = sprintf( '%-31s', str1 ); + opsupnames( i+2, : ) = sprintf( '%-31s', str2 ); + opsupnames( i+3, : ) = sprintf( '%-31s', str3 ); + opsupnames( i+4, : ) = sprintf( '%-31s', str4 ); + opsupnames( i+5, : ) = sprintf( '%-31s', str5 ); + opsupnames( i+6, : ) = sprintf( '%-31s', str6 ); + + opnames( i+0, : ) = sprintf( '%s', op ); + opnames( i+1, : ) = sprintf( '%s', op ); + opnames( i+2, : ) = sprintf( '%s', op ); + opnames( i+3, : ) = sprintf( '%s', op ); + opnames( i+4, : ) = sprintf( '%s', op ); + opnames( i+5, : ) = sprintf( '%s', op ); + opnames( i+6, : ) = sprintf( '%s', op ); + + i = i + 7; +end + +r_val1 = opsupnames; +r_val2 = opnames; + diff --git a/test/sup/octave/bkup/plot_l3sup_perf.m b/test/sup/octave/bkup/plot_l3sup_perf.m new file mode 100644 index 0000000000..a54541f37a --- /dev/null +++ b/test/sup/octave/bkup/plot_l3sup_perf.m @@ -0,0 +1,316 @@ +function r_val = plot_l3sup_perf( opname, ... + smalldims, ... + data_blissup, ... + data_blisconv, ... + data_eigen, ... + data_open, ... + data_vend, vend_str, ... + data_bfeo, ... + data_xsmm, ... + nth, ... + rows, cols, ... + cfreq, ... + dfps, ... + theid, impl ) + +% Define a single-/multithreadedness predicate for convenience. +if nth == 1 + is_st = 1; +else + is_st = 0; +end + +% Define the column in which the performance rates are found. +flopscol = size( data_blissup, 2 ); + +% Check if blasfeo data is available. +has_bfeo = 1; +if data_bfeo( 1, flopscol ) == 0.0 + has_bfeo = 0; +end + +% Check if libxsmm data is available. +has_xsmm = 1; +if data_xsmm( 1, flopscol ) == 0.0 + has_xsmm = 0; +end + +% Define which plot id will have the legend. +%if is_st +% if has_xsmm == 1 +% legend_plot_id = 2*cols + 1*5; +% else +% legend_plot_id = 1*cols + 1*5; +% end +%else +% legend_plot_id = 0*cols + 1*6; +%end +legend_plot_id = cols*rows; + +% Hold the axes. +if 1 + ax1 = subplot( rows, cols, theid ); + hold( ax1, 'on' ); +end + +% Set line properties. +color_blissup = 'k'; lines_blissup = '-'; markr_blissup = ''; +color_blisconv = 'k'; lines_blisconv = ':'; markr_blisconv = ''; +color_eigen = 'm'; lines_eigen = '-.'; markr_eigen = 'o'; +color_open = 'r'; lines_open = '--'; markr_open = 'o'; +color_vend = 'b'; lines_vend = '-.'; markr_vend = '.'; +color_bfeo = 'c'; lines_bfeo = '-'; markr_bfeo = 'o'; +color_xsmm = 'g'; lines_xsmm = '-'; markr_xsmm = 'o'; + +% Compute the peak performance in terms of the number of double flops +% executable per cycle and the clock rate. +if opname(1) == 's' || opname(1) == 'c' + flopspercycle = dfps * 2; +else + flopspercycle = dfps; +end +max_perf_core = (flopspercycle * cfreq) * 1; + +% Escape underscores in the title. +title_opname = strrep( opname, '_', '\_' ); + +% Print the title to a string. +titlename = '%s'; +titlename = sprintf( titlename, title_opname ); + +% Set the legend strings. +blissup_lg = sprintf( 'BLIS sup' ); +blisconv_lg = sprintf( 'BLIS conv' ); +eigen_lg = sprintf( 'Eigen' ); +open_lg = sprintf( 'OpenBLAS' ); +vend_lg = vend_str; +bfeo_lg = sprintf( 'BLASFEO' ); +xsmm_lg = sprintf( 'libxsmm' ); + +% Set axes range values. +y_scale = 1.00; +x_begin = 0; +%x_end is set below. +y_begin = 0; +y_end = max_perf_core * y_scale; + +% Set axes names. +if nth == 1 + yaxisname = 'GFLOPS'; +else + yaxisname = 'GFLOPS/core'; +end + + +%flopscol = 4; +msize = 5; +if 1 + fontsize = 12; +else + fontsize = 16; +end +linesize = 0.5; +legend_loc = 'southeast'; + +% -------------------------------------------------------------------- + +% Automatically detect a column with the increasing problem size. +% Then set the maximum x-axis value. +for psize_col = 1:3 + if data_blissup( 1, psize_col ) ~= data_blissup( 2, psize_col ) + break; + end +end +x_axis( :, 1 ) = data_blissup( :, psize_col ); + +% Compute the number of data points we have in the x-axis. Note that we +% only use half the data points for the m = n = k column of graphs. +%if mod(theid-1,cols) == 6 +% np = size( data_blissup, 1 ) / 2; +%else +% np = size( data_blissup, 1 ); +%end +np = size( data_blissup, 1 ); + +% Grab the last x-axis value. +x_end = data_blissup( np, psize_col ); + +%data_peak( 1, 1:2 ) = [ 0 max_perf_core ]; +%data_peak( 2, 1:2 ) = [ x_end max_perf_core ]; + +blissup_ln = line( x_axis( 1:np, 1 ), data_blissup( 1:np, flopscol ) / nth, ... + 'Color',color_blissup, 'LineStyle',lines_blissup, ... + 'LineWidth',linesize ); +blisconv_ln = line( x_axis( 1:np, 1 ), data_blisconv( 1:np, flopscol ) / nth, ... + 'Color',color_blisconv, 'LineStyle',lines_blisconv, ... + 'LineWidth',linesize ); +eigen_ln = line( x_axis( 1:np, 1 ), data_eigen( 1:np, flopscol ) / nth, ... + 'Color',color_eigen, 'LineStyle',lines_eigen, ... + 'LineWidth',linesize ); +open_ln = line( x_axis( 1:np, 1 ), data_open( 1:np, flopscol ) / nth, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +vend_ln = line( x_axis( 1:np, 1 ), data_vend( 1:np, flopscol ) / nth, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... + 'LineWidth',linesize ); +if has_bfeo == 1 + bfeo_ln = line( x_axis( 1:np, 1 ), data_bfeo( 1:np, flopscol ) / nth, ... + 'Color',color_bfeo, 'LineStyle',lines_bfeo, ... + 'LineWidth',linesize ); +else + bfeo_ln = line( nan, nan, ... + 'Color',color_bfeo, 'LineStyle',lines_bfeo, ... + 'LineWidth',linesize ); +end +if has_xsmm == 1 + xsmm_ln = line( x_axis( 1:np, 1 ), data_xsmm( 1:np, flopscol ) / nth, ... + 'Color',color_xsmm, 'LineStyle',lines_xsmm, ... + 'LineWidth',linesize ); +else + xsmm_ln = line( nan, nan, ... + 'Color',color_xsmm, 'LineStyle',lines_xsmm, ... + 'LineWidth',linesize ); +end + + +xlim( ax1, [x_begin x_end] ); +ylim( ax1, [y_begin y_end] ); + +if 10000 <= x_end && x_end < 15000 + x_tick2 = x_end - 2000; + x_tick1 = x_tick2/2; + %xticks( ax1, [ x_tick1 x_tick2 ] ); + xticks( ax1, [ 3000 6000 9000 12000 ] ); +elseif 6000 <= x_end && x_end < 10000 + x_tick2 = x_end - 2000; + x_tick1 = x_tick2/2; + %xticks( ax1, [ x_tick1 x_tick2 ] ); + xticks( ax1, [ 2000 4000 6000 8000 ] ); +elseif 4000 <= x_end && x_end < 6000 + x_tick2 = x_end - 1000; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 2000 <= x_end && x_end < 3000 + x_tick2 = x_end - 400; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 500 <= x_end && x_end < 1000 + x_tick3 = x_end*(3/4); + x_tick2 = x_end*(2/4); + x_tick1 = x_end*(1/4); + xticks( ax1, [ x_tick1 x_tick2 x_tick3 ] ); +end + + % xpos ypos + %set( leg,'Position',[11.32 6.36 1.15 0.7 ] ); % (1,4tl) +if nth == 1 && theid == legend_plot_id + if has_xsmm == 1 + % single-threaded, with libxsmm (ccc) + leg = legend( ... + [ blissup_ln blisconv_ln eigen_ln open_ln vend_ln bfeo_ln xsmm_ln ], ... + blissup_lg, blisconv_lg, eigen_lg, open_lg, vend_lg, bfeo_lg, xsmm_lg, ... + 'Location', legend_loc ); + set( leg,'Box','off','Color','none','Units','inches' ); + if impl == 'octave' + set( leg,'FontSize',fontsize ); + set( leg,'Position',[15.35 4.62 1.9 1.20] ); % (1,4tl) + else + set( leg,'FontSize',fontsize-3 ); + set( leg,'Position',[18.20 10.20 1.15 0.7 ] ); % (1,4tl) + end + else + % single-threaded, without libxsmm (rrr, or other) + leg = legend( ... + [ blissup_ln blisconv_ln eigen_ln open_ln vend_ln bfeo_ln ], ... + blissup_lg, blisconv_lg, eigen_lg, open_lg, vend_lg, bfeo_lg, ... + 'Location', legend_loc ); + set( leg,'Box','off','Color','none','Units','inches' ); + if impl == 'octave' + set( leg,'FontSize',fontsize ); + set( leg,'Position',[15.35 7.40 1.9 1.10] ); % (1,4tl) + else + set( leg,'FontSize',fontsize-1 ); + set( leg,'Position',[18.24 10.15 1.15 0.7] ); % (1,4tl) + end + end +elseif nth > 1 && theid == legend_plot_id + % multithreaded + leg = legend( ... + [ blissup_ln blisconv_ln eigen_ln open_ln vend_ln ], ... + blissup_lg, blisconv_lg, eigen_lg, open_lg, vend_lg, ... + 'Location', legend_loc ); + set( leg,'Box','off','Color','none','Units','inches' ); + if impl == 'octave' + set( leg,'FontSize',fontsize ); + set( leg,'Position',[18.20 10.30 1.9 0.95] ); % (1,4tl) + else + set( leg,'FontSize',fontsize-1 ); + set( leg,'Position',[18.24 10.15 1.15 0.7] ); % (1,4tl) + end +end + +set( ax1,'FontSize',fontsize ); +set( ax1,'TitleFontSizeMultiplier',1.0 ); % default is 1.1. +box( ax1, 'on' ); + +titl = title( titlename ); +set( titl, 'FontWeight', 'normal' ); % default font style is now 'bold'. + +% The default is to align the plot title across whole figure, not the box. +% This is a hack to nudge the title back to the center of the box. +if impl == 'octave' + tpos = get( titl, 'Position' ); + % For some reason, the titles in the graphs in the last column start + % off in a different relative position than the graphs in the other + % columns. Here, we manually account for that. + if mod(theid-1,cols) == 6 + tpos(1) = tpos(1) + -10; + else + tpos(1) = tpos(1) + -40; + end + set( titl, 'Position', tpos ); + set( titl, 'FontSize', fontsize ); +else % impl == 'matlab' + tpos = get( titl, 'Position' ); + tpos(1) = tpos(1) + 90; + set( titl, 'Position', tpos ); +end + +sll_str = sprintf( 'm = %u; n = k', smalldims(1) ); +lsl_str = sprintf( 'n = %u; m = k', smalldims(2) ); +lls_str = sprintf( 'k = %u; m = n', smalldims(3) ); +lss_str = sprintf( 'm; n = %u, k = %u', smalldims(2), smalldims(3) ); +sls_str = sprintf( 'n; m = %u, k = %u', smalldims(1), smalldims(3) ); +ssl_str = sprintf( 'k; m = %u, n = %u', smalldims(1), smalldims(2) ); +lll_str = sprintf( 'm = n = k' ); + +% Place labels on the bottom row of graphs. +if theid > (rows-1)*cols + %xlab = xlabel( ax1,xaxisname ); + %tpos = get( xlab, 'Position' ) + %tpos(2) = tpos(2) + 10; + %set( xlab, 'Position', tpos ); + if theid == rows*cols - 6 + xlab = xlabel( ax1, sll_str ); + elseif theid == rows*cols - 5 + xlab = xlabel( ax1, lsl_str ); + elseif theid == rows*cols - 4 + xlab = xlabel( ax1, lls_str ); + elseif theid == rows*cols - 3 + xlab = xlabel( ax1, lss_str ); + elseif theid == rows*cols - 2 + xlab = xlabel( ax1, sls_str ); + elseif theid == rows*cols - 1 + xlab = xlabel( ax1, ssl_str ); + elseif theid == rows*cols - 0 + xlab = xlabel( ax1, lll_str ); + end +end + +% Place labels on the left-hand column of graphs. +if mod(theid-1,cols) == 0 + ylab = ylabel( ax1,yaxisname ); +end + +r_val = 0; + diff --git a/test/sup/octave/bkup/plot_panel_trxsh.m b/test/sup/octave/bkup/plot_panel_trxsh.m new file mode 100644 index 0000000000..ad2aab7f9d --- /dev/null +++ b/test/sup/octave/bkup/plot_panel_trxsh.m @@ -0,0 +1,187 @@ +function r_val = plot_panel_trxsh ... + ( ... + cfreq, ... + dflopspercycle, ... + nth, ... + thr_str, ... + dt_ch, ... + stor_str, ... + smalldims, ... + ldim_str, ... + pack_str, ... + dirpath, ... + arch_str, ... + vend_str, ... + impl ... + ) + +if 1 == 1 + %fig = figure('Position', [100, 100, 2400, 1500]); + fig = figure('Position', [100, 100, 2400, 1200]); + orient( fig, 'portrait' ); + set(gcf,'PaperUnits', 'inches'); + if impl == 'matlab' + set(gcf,'PaperSize', [11.5 20.4]); + set(gcf,'PaperPosition', [0 0 11.5 20.4]); + set(gcf,'PaperPositionMode','manual'); + else % impl == 'octave' % octave 4.x + set(gcf,'PaperSize', [12 22.0]); + set(gcf,'PaperPositionMode','auto'); + end + set(gcf,'PaperOrientation','landscape'); +end + +%cfreq = 1.8; +%dflopspercycle = 32; + +if nth == 1 + is_st = 1; +else + is_st = 0; +end + +% Create filename "templates" for the files that contain the performance +% results. +filetemp_blissup = '%s/output_%s_%s_blissup.m'; +filetemp_blisconv = '%s/output_%s_%s_blisconv.m'; +filetemp_eigen = '%s/output_%s_%s_eigen.m'; +filetemp_open = '%s/output_%s_%s_openblas.m'; +filetemp_vend = '%s/output_%s_%s_vendor.m'; +filetemp_bfeo = '%s/output_%s_%s_blasfeo.m'; +filetemp_xsmm = '%s/output_%s_%s_libxsmm.m'; + +% Create a variable name "template" for the variables contained in the +% files outlined above. +vartemp = 'data_%s_%s_%s( :, : )'; + +% Define the datatypes and operations we will be plotting. +oproot = sprintf( '%cgemm', dt_ch ); +ops( 1, : ) = sprintf( '%s_nn', oproot ); +ops( 2, : ) = sprintf( '%s_nt', oproot ); +ops( 3, : ) = sprintf( '%s_tn', oproot ); +ops( 4, : ) = sprintf( '%s_tt', oproot ); + +% Generate datatype-specific operation names from the set of operations +% and datatypes. +[ opsupnames, opnames ] = gen_opsupnames( ops, stor_str, smalldims, ldim_str, pack_str ); +n_opsupnames = size( opsupnames, 1 ); + +%opsupnames +%opnames +%return + +% Iterate over the list of datatype-specific operation names. +for opi = 1:n_opsupnames +%for opi = 1:1 + + % Grab the current datatype combination. + opsupname = opsupnames( opi, : ); + opname = opnames( opi, : ); + + % Remove leading and trailing whitespace. + opsupname = strtrim( opsupname ); + opname = strtrim( opname ); + + str = sprintf( 'Plotting %2d: %s', opi, opsupname ); disp(str); + + % Construct filenames for the data files from templates. + file_blissup = sprintf( filetemp_blissup, dirpath, thr_str, opsupname ); + file_blisconv = sprintf( filetemp_blisconv, dirpath, thr_str, opsupname ); + file_eigen = sprintf( filetemp_eigen, dirpath, thr_str, opsupname ); + file_open = sprintf( filetemp_open, dirpath, thr_str, opsupname ); + file_vend = sprintf( filetemp_vend, dirpath, thr_str, opsupname ); + file_bfeo = sprintf( filetemp_bfeo, dirpath, thr_str, opsupname ); + + % Load the data files. + %str = sprintf( ' Loading %s', file_blissup ); disp(str); + run( file_blissup ) + run( file_blisconv ) + run( file_eigen ) + run( file_open ) + run( file_vend ) + if is_st + run( file_bfeo ) + end + + % Construct variable names for the variables in the data files. + var_blissup = sprintf( vartemp, thr_str, opname, 'blissup' ); + var_blisconv = sprintf( vartemp, thr_str, opname, 'blisconv' ); + var_eigen = sprintf( vartemp, thr_str, opname, 'eigen' ); + var_open = sprintf( vartemp, thr_str, opname, 'openblas' ); + var_vend = sprintf( vartemp, thr_str, opname, 'vendor' ); + var_bfeo = sprintf( vartemp, thr_str, opname, 'blasfeo' ); + + % Use eval() to instantiate the variable names constructed above, + % copying each to a simplified name. + data_blissup = eval( var_blissup ); % e.g. data_st_dgemm_blissup( :, : ); + data_blisconv = eval( var_blisconv ); % e.g. data_st_dgemm_blisconv( :, : ); + data_eigen = eval( var_eigen ); % e.g. data_st_dgemm_eigen( :, : ); + data_open = eval( var_open ); % e.g. data_st_dgemm_openblas( :, : ); + data_vend = eval( var_vend ); % e.g. data_st_dgemm_vendor( :, : ); + if is_st + data_bfeo = eval( var_bfeo ); % e.g. data_st_dgemm_blasfeo( :, : ); + else + % Set the data variable to zeros using the same dimensions as the other + % variables. + data_bfeo = zeros( size( data_blissup, 1 ), ... + size( data_blissup, 2 ) ); + end + + if is_st && stor_str == 'ccc' + % Only read xsmm data for the column storage case, since that's the + % only format that libxsmm supports. + file_xsmm = sprintf( filetemp_xsmm, dirpath, thr_str, opsupname ); + run( file_xsmm ) + var_xsmm = sprintf( vartemp, thr_str, opname, 'libxsmm' ); + data_xsmm = eval( var_xsmm ); % e.g. data_st_dgemm_libxsmm( :, : ); + else + % Set the data variable to zeros using the same dimensions as the other + % variables. + data_xsmm = zeros( size( data_blissup, 1 ), ... + size( data_blissup, 2 ) ); + end + + % Plot one result in an m x n grid of plots, via the subplot() + % function. + if 1 == 1 + plot_l3sup_perf( opsupname, ... + smalldims, ... + data_blissup, ... + data_blisconv, ... + data_eigen, ... + data_open, ... + data_vend, vend_str, ... + data_bfeo, ... + data_xsmm, ... + nth, ... + 4, 7, ... + cfreq, ... + dflopspercycle, ... + opi, impl ); + + clear data_st_*gemm_*; + clear data_mt_*gemm_*; + clear data_blissup; + clear data_blisconv; + clear data_eigen; + clear data_open; + clear data_vend; + clear data_bfeo; + clear data_xsmm; + + end + +end + +% Construct the name of the file to which we will output the graph. +outfile = sprintf( 'l3sup_%s_%s_%s_nt%d.pdf', oproot, stor_str, arch_str, nth ); + +% Output the graph to pdf format. +%print(gcf, 'gemm_md','-fillpage','-dpdf'); +%print(gcf, outfile,'-bestfit','-dpdf'); +if impl == 'octave' + print(gcf, outfile); +else % if impl == 'matlab' + print(gcf, outfile,'-bestfit','-dpdf'); +end + diff --git a/test/sup/octave/bkup/runthese.m b/test/sup/octave/bkup/runthese.m new file mode 100644 index 0000000000..8e3519f33b --- /dev/null +++ b/test/sup/octave/bkup/runthese.m @@ -0,0 +1,8 @@ +% kabylake +plot_panel_trxsh(3.80,16,1,'st','d','rrr',[ 6 8 4 ],'lds','uaub','../results/kabylake/20200302/mnkt100000_st','kbl','MKL','octave'); close; clear all; + +% haswell +plot_panel_trxsh(3.5,16,1,'st','d','rrr',[ 6 8 4 ],'lds','uaub','../results/haswell/20200302/mnkt100000_st','has','MKL','octave'); close; clear all; + +% epyc +plot_panel_trxsh(3.00, 8,1,'st','d','rrr',[ 6 8 4 ],'lds','uaub','../results/epyc/20200302/mnkt100000_st','epyc','MKL','octave'); close; clear all; diff --git a/test/sup/octave/gen_opsupnames.m b/test/sup/octave/gen_opsupnames.m new file mode 100644 index 0000000000..40258677d0 --- /dev/null +++ b/test/sup/octave/gen_opsupnames.m @@ -0,0 +1,50 @@ +function [ r_val1, r_val2 ] = gen_opsupnames( ops, stor, smalldims, ldim, pack ) + +nops = size( ops, 1 ); + +smallm = smalldims( 1 ); +smalln = smalldims( 2 ); +smallk = smalldims( 3 ); + +i = 1; + +for io = 1:nops + + op = ops( io, : ); + + % sprintf'ing directly into an array of strings, as in: + % + % opsupnames( i+0, : ) = sprintf( '%s_%s_m%dnpkp_%s_%s', ... ); + % + % doesn't work when the string lengths as they would if any of the constant + % dimensions is greater than 9. + str0 = sprintf( '%s_%s_m%dnpkp_%s_%s', op, stor, smallm, ldim, pack ); + str1 = sprintf( '%s_%s_mpn%dkp_%s_%s', op, stor, smalln, ldim, pack ); + str2 = sprintf( '%s_%s_mpnpk%d_%s_%s', op, stor, smallk, ldim, pack ); + str3 = sprintf( '%s_%s_mpn%dk%d_%s_%s', op, stor, smalln, smallk, ldim, pack ); + str4 = sprintf( '%s_%s_m%dnpk%d_%s_%s', op, stor, smallm, smallk, ldim, pack ); + str5 = sprintf( '%s_%s_m%dn%dkp_%s_%s', op, stor, smallm, smalln, ldim, pack ); + str6 = sprintf( '%s_%s_mpnpkp_%s_%s', op, stor, ldim, pack ); + + opsupnames( i+0, : ) = sprintf( '%-31s', str0 ); + opsupnames( i+1, : ) = sprintf( '%-31s', str1 ); + opsupnames( i+2, : ) = sprintf( '%-31s', str2 ); + opsupnames( i+3, : ) = sprintf( '%-31s', str3 ); + opsupnames( i+4, : ) = sprintf( '%-31s', str4 ); + opsupnames( i+5, : ) = sprintf( '%-31s', str5 ); + opsupnames( i+6, : ) = sprintf( '%-31s', str6 ); + + opnames( i+0, : ) = sprintf( '%s', op ); + opnames( i+1, : ) = sprintf( '%s', op ); + opnames( i+2, : ) = sprintf( '%s', op ); + opnames( i+3, : ) = sprintf( '%s', op ); + opnames( i+4, : ) = sprintf( '%s', op ); + opnames( i+5, : ) = sprintf( '%s', op ); + opnames( i+6, : ) = sprintf( '%s', op ); + + i = i + 7; +end + +r_val1 = opsupnames; +r_val2 = opnames; + diff --git a/test/sup/octave/load_data.m b/test/sup/octave/load_data.m new file mode 100644 index 0000000000..0044bc5605 --- /dev/null +++ b/test/sup/octave/load_data.m @@ -0,0 +1,17 @@ +function [ r_val ] = load_data( ... + filetemp, ... + dirpath, ... + thr_str, ... + opsupname, ... + vartemp, ... + opname, ... + impl_str ... + ) + +filepath = sprintf( filetemp, dirpath, thr_str, opsupname ); +run( filepath ) +varname = sprintf( vartemp, thr_str, opname, impl_str ); +data = eval( varname ); % e.g. data_st_dgemm_blissup( :, : ); + +r_val = data; + diff --git a/test/sup/octave/plot_l3sup_perf.m b/test/sup/octave/plot_l3sup_perf.m new file mode 100644 index 0000000000..b4cbb06947 --- /dev/null +++ b/test/sup/octave/plot_l3sup_perf.m @@ -0,0 +1,291 @@ +function r_val = plot_l3sup_perf( opname, ... + smalldims, ... + data_blissup, ... + data_blisconv, ... + data_eigen, ... + data_open, ... + data_vend, vend_str, ... + data_bfeo, ... + data_xsmm, ... + nth, ... + rows, cols, ... + cfreq, ... + dfps, ... + theid, impl, ... + fontsize, ... + leg_pos_st, leg_pos_st_x, leg_pos_mt, ... + sp_margins ) + +% Define the column in which the performance rates are found. +flopscol = size( data_blissup, 2 ); + +% Check if blasfeo data is available. +has_bfeo = 1; +if data_bfeo( 1, flopscol ) == 0.0 + has_bfeo = 0; +end + +% Check if libxsmm data is available. +has_xsmm = 1; +if data_xsmm( 1, flopscol ) == 0.0 + has_xsmm = 0; +end + +% Define which plot id will have the legend. +% NOTE: We can draw the legend on any graph as long as it has already been +% rendered. Since the coordinates are global, we can simply always wait until +% the final graph to draw the legend. +legend_plot_id = cols*rows; + +% Set line properties. +color_blissup = 'k'; lines_blissup = '-'; markr_blissup = ''; +color_blisconv = 'k'; lines_blisconv = ':'; markr_blisconv = ''; +color_eigen = 'm'; lines_eigen = '-.'; markr_eigen = 'o'; +color_open = 'r'; lines_open = '--'; markr_open = 'o'; +color_vend = 'b'; lines_vend = '-.'; markr_vend = '.'; +color_bfeo = 'c'; lines_bfeo = '-'; markr_bfeo = 'o'; +color_xsmm = 'g'; lines_xsmm = '-'; markr_xsmm = 'o'; + +% Compute the peak performance in terms of the number of double flops +% executable per cycle and the clock rate. +if opname(1) == 's' || opname(1) == 'c' + flopspercycle = dfps * 2; +else + flopspercycle = dfps; +end +max_perf_core = (flopspercycle * cfreq) * 1; + +% Escape underscores in the title. +title_opname = strrep( opname, '_', '\_' ); + +% Print the title to a string. +titlename = '%s'; +titlename = sprintf( titlename, title_opname ); + +% Set the legend strings. +blissup_lg = sprintf( 'BLIS sup' ); +blisconv_lg = sprintf( 'BLIS conv' ); +eigen_lg = sprintf( 'Eigen' ); +open_lg = sprintf( 'OpenBLAS' ); +vend_lg = vend_str; +bfeo_lg = sprintf( 'BLASFEO' ); +xsmm_lg = sprintf( 'libxsmm' ); + +% Set axes range values. +y_scale = 1.00; +x_begin = 0; +%x_end is set below. +y_begin = 0; +y_end = max_perf_core * y_scale; + +% Set axes names. +if nth == 1 + yaxisname = 'GFLOPS'; +else + yaxisname = 'GFLOPS/core'; +end + +% Set the marker size, line size, and other items. +msize = 5; +linesize = 0.8; +legend_loc = 'southeast'; + +%ax1 = subplot( rows, cols, theid ); +ax1 = subplot_tight( rows, cols, theid, sp_margins ); + +% Hold the axes. +hold( ax1, 'on' ); + +% -------------------------------------------------------------------- + +% Automatically detect a column with the increasing problem size. +% Then set the maximum x-axis value. +for psize_col = 1:3 + if data_blissup( 1, psize_col ) ~= data_blissup( 2, psize_col ) + break; + end +end +x_axis( :, 1 ) = data_blissup( :, psize_col ); + +% Compute the number of data points we have in the x-axis. Note that we +% only use half the data points for the m = n = k column of graphs. +%if mod(theid-1,cols) == 6 +% np = size( data_blissup, 1 ) / 2; +%else +% np = size( data_blissup, 1 ); +%end +np = size( data_blissup, 1 ); + +% Grab the last x-axis value. +x_end = data_blissup( np, psize_col ); + +%data_peak( 1, 1:2 ) = [ 0 max_perf_core ]; +%data_peak( 2, 1:2 ) = [ x_end max_perf_core ]; + +blissup_ln = line( x_axis( 1:np, 1 ), data_blissup( 1:np, flopscol ) / nth, ... + 'Color',color_blissup, 'LineStyle',lines_blissup, ... + 'LineWidth',linesize ); +blisconv_ln = line( x_axis( 1:np, 1 ), data_blisconv( 1:np, flopscol ) / nth, ... + 'Color',color_blisconv, 'LineStyle',lines_blisconv, ... + 'LineWidth',linesize ); +eigen_ln = line( x_axis( 1:np, 1 ), data_eigen( 1:np, flopscol ) / nth, ... + 'Color',color_eigen, 'LineStyle',lines_eigen, ... + 'LineWidth',linesize ); +open_ln = line( x_axis( 1:np, 1 ), data_open( 1:np, flopscol ) / nth, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +vend_ln = line( x_axis( 1:np, 1 ), data_vend( 1:np, flopscol ) / nth, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... + 'LineWidth',linesize ); +if has_bfeo == 1 + bfeo_ln = line( x_axis( 1:np, 1 ), data_bfeo( 1:np, flopscol ) / nth, ... + 'Color',color_bfeo, 'LineStyle',lines_bfeo, ... + 'LineWidth',linesize ); +else + bfeo_ln = line( nan, nan, ... + 'Color',color_bfeo, 'LineStyle',lines_bfeo, ... + 'LineWidth',linesize ); +end +if has_xsmm == 1 + xsmm_ln = line( x_axis( 1:np, 1 ), data_xsmm( 1:np, flopscol ) / nth, ... + 'Color',color_xsmm, 'LineStyle',lines_xsmm, ... + 'LineWidth',linesize ); +else + xsmm_ln = line( nan, nan, ... + 'Color',color_xsmm, 'LineStyle',lines_xsmm, ... + 'LineWidth',linesize ); +end + + +xlim( ax1, [x_begin x_end] ); +ylim( ax1, [y_begin y_end] ); + +if 10000 <= x_end && x_end < 15000 + x_tick2 = x_end - 2000; + x_tick1 = x_tick2/2; + %xticks( ax1, [ x_tick1 x_tick2 ] ); + xticks( ax1, [ 3000 6000 9000 12000 ] ); +elseif 6000 <= x_end && x_end < 10000 + x_tick2 = x_end - 2000; + x_tick1 = x_tick2/2; + %xticks( ax1, [ x_tick1 x_tick2 ] ); + xticks( ax1, [ 2000 4000 6000 8000 ] ); +elseif 4000 <= x_end && x_end < 6000 + x_tick2 = x_end - 1000; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 2000 <= x_end && x_end < 3000 + x_tick2 = x_end - 400; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 500 <= x_end && x_end < 1000 + x_tick3 = x_end*(3/4); + x_tick2 = x_end*(2/4); + x_tick1 = x_end*(1/4); + xticks( ax1, [ x_tick1 x_tick2 x_tick3 ] ); +end + + % xpos ypos + %set( leg,'Position',[11.32 6.36 1.15 0.7 ] ); % (1,4tl) + +if nth == 1 && theid == legend_plot_id + if has_xsmm == 1 + % single-threaded, with libxsmm (ccc) + leg = legend( ... + [ blissup_ln blisconv_ln eigen_ln open_ln vend_ln bfeo_ln xsmm_ln ], ... + blissup_lg, blisconv_lg, eigen_lg, open_lg, vend_lg, bfeo_lg, xsmm_lg, ... + 'Location', legend_loc ); + set( leg,'Box','off','Color','none','Units','inches' ); + set( leg,'FontSize',fontsize ); + %set( leg,'Position',[15.35 4.62 1.9 1.20] ); + set( leg,'Position',leg_pos_st_x ); + else + % single-threaded, without libxsmm (rrr, or other) + leg = legend( ... + [ blissup_ln blisconv_ln eigen_ln open_ln vend_ln bfeo_ln ], ... + blissup_lg, blisconv_lg, eigen_lg, open_lg, vend_lg, bfeo_lg, ... + 'Location', legend_loc ); + set( leg,'Box','off','Color','none','Units','inches' ); + set( leg,'FontSize',fontsize ); + %set( leg,'Position',[15.35 7.40 1.9 1.10] ); + set( leg,'Position',leg_pos_st ); + end +elseif nth > 1 && theid == legend_plot_id + % multithreaded + leg = legend( ... + [ blissup_ln blisconv_ln eigen_ln open_ln vend_ln ], ... + blissup_lg, blisconv_lg, eigen_lg, open_lg, vend_lg, ... + 'Location', legend_loc ); + set( leg,'Box','off','Color','none','Units','inches' ); + set( leg,'FontSize',fontsize ); + %set( leg,'Position',[18.20 10.30 1.9 0.95] ); + set( leg,'Position',leg_pos_mt ); +end + +set( ax1,'FontSize',fontsize ); +set( ax1,'TitleFontSizeMultiplier',1.0 ); % default is 1.1. +box( ax1, 'on' ); + +titl = title( titlename ); +set( titl, 'FontWeight', 'normal' ); % default font style is now 'bold'. + +% The default is to align the plot title across whole figure, not the box. +% This is a hack to nudge the title back to the center of the box. +if impl == 'octave' + tpos = get( titl, 'Position' ); + % For some reason, the titles in the graphs in certain columns start + % off in a different relative position. Here, we manually fix that. + %modid = mod(theid-1,cols); + %if modid == 0 || modid == 1 || modid == 2 + % tpos(1) = tpos(1) + 0; + %elseif modid == 3 || modid == 4 || modid == 5 + % tpos(1) = tpos(1) + 0; + %else + % tpos(1) = tpos(1) + 0; + %end + set( titl, 'Position', tpos ); + set( titl, 'FontSize', fontsize-1 ); +else % impl == 'matlab' + tpos = get( titl, 'Position' ); + tpos(1) = tpos(1) + 90; + set( titl, 'Position', tpos ); +end + +sll_str = sprintf( 'm = %u; n = k', smalldims(1) ); +lsl_str = sprintf( 'n = %u; m = k', smalldims(2) ); +lls_str = sprintf( 'k = %u; m = n', smalldims(3) ); +lss_str = sprintf( 'm; n = %u, k = %u', smalldims(2), smalldims(3) ); +sls_str = sprintf( 'n; m = %u, k = %u', smalldims(1), smalldims(3) ); +ssl_str = sprintf( 'k; m = %u, n = %u', smalldims(1), smalldims(2) ); +lll_str = sprintf( 'm = n = k' ); + +% Place labels on the bottom row of graphs. +if theid > (rows-1)*cols + %xlab = xlabel( ax1,xaxisname ); + %tpos = get( xlab, 'Position' ) + %tpos(2) = tpos(2) + 10; + %set( xlab, 'Position', tpos ); + if theid == rows*cols - 6 + xlab = xlabel( ax1, sll_str ); + elseif theid == rows*cols - 5 + xlab = xlabel( ax1, lsl_str ); + elseif theid == rows*cols - 4 + xlab = xlabel( ax1, lls_str ); + elseif theid == rows*cols - 3 + xlab = xlabel( ax1, lss_str ); + elseif theid == rows*cols - 2 + xlab = xlabel( ax1, sls_str ); + elseif theid == rows*cols - 1 + xlab = xlabel( ax1, ssl_str ); + elseif theid == rows*cols - 0 + xlab = xlabel( ax1, lll_str ); + end +end + +% Place labels on the left-hand column of graphs. +if mod(theid-1,cols) == 0 + ylab = ylabel( ax1,yaxisname ); +end + +r_val = 0; + diff --git a/test/sup/octave/plot_panel_trxsh.m b/test/sup/octave/plot_panel_trxsh.m new file mode 100644 index 0000000000..b8af81466e --- /dev/null +++ b/test/sup/octave/plot_panel_trxsh.m @@ -0,0 +1,148 @@ +function r_val = plot_panel_trxsh ... + ( ... + cfreq, ... + dflopspercycle, ... + nth, ... + thr_str, ... + dt_ch, ... + stor_str, ... + smalldims, ... + ldim_str, ... + pack_str, ... + dirpath, ... + arch_str, ... + vend_str ... + ) + +impl = 'octave'; + +%subp = 'default'; +subp = 'tight'; + +if strcmp( subp, 'default' ) + position = [100 100 2400 1200]; + papersize = [12 22.0]; + sp_margins = [ 0.070 0.049 ]; +else + position = [100 100 2308 1202]; + papersize = [12.5 24.0]; + fontsize = 14; + leg_pos_st = [10.85 7.43 1.3 1.2 ]; + leg_pos_st_x = [14.15 4.35 1.3 1.4 ]; + leg_pos_mt = [10.85 7.66 1.3 1.0 ]; + sp_margins = [ 0.063 0.033 ]; +end + +%fig = figure('Position', [100, 100, 2400, 1500]); +fig = figure('Position', position); +orient( fig, 'portrait' ); +set(gcf,'PaperUnits', 'inches'); +if impl == 'octave' + set(gcf,'PaperSize', papersize); + set(gcf,'PaperPositionMode','auto'); +else % impl == 'matlab' + set(gcf,'PaperSize', [11.5 20.4]); + set(gcf,'PaperPosition', [0 0 11.5 20.4]); + set(gcf,'PaperPositionMode','manual'); +end +set(gcf,'PaperOrientation','landscape'); + +% Create filename "templates" for the files that contain the performance +% results. +filetemp_blissup = '%s/output_%s_%s_blissup.m'; +filetemp_blisconv = '%s/output_%s_%s_blisconv.m'; +filetemp_eigen = '%s/output_%s_%s_eigen.m'; +filetemp_open = '%s/output_%s_%s_openblas.m'; +filetemp_vend = '%s/output_%s_%s_vendor.m'; +filetemp_bfeo = '%s/output_%s_%s_blasfeo.m'; +filetemp_xsmm = '%s/output_%s_%s_libxsmm.m'; + +% Create a variable name "template" for the variables contained in the +% files outlined above. +vartemp = 'data_%s_%s_%s( :, : )'; + +% Define the datatypes and operations we will be plotting. +oproot = sprintf( '%cgemm', dt_ch ); +ops( 1, : ) = sprintf( '%s_nn', oproot ); +ops( 2, : ) = sprintf( '%s_nt', oproot ); +ops( 3, : ) = sprintf( '%s_tn', oproot ); +ops( 4, : ) = sprintf( '%s_tt', oproot ); + +% Generate datatype-specific operation names from the set of operations +% and datatypes. +[ opsupnames, opnames ] = gen_opsupnames( ops, stor_str, smalldims, ldim_str, pack_str ); +n_opsupnames = size( opsupnames, 1 ); + +%opsupnames +%opnames +%return + +% Iterate over the list of datatype-specific operation names. +for opi = 1:n_opsupnames +%for opi = 1:1 + + % Grab the current datatype combination. + opsupname = opsupnames( opi, : ); + opname = opnames( opi, : ); + + % Remove leading and trailing whitespace. + opsupname = strtrim( opsupname ); + opname = strtrim( opname ); + + % Output progress through the loop. + str = sprintf( 'Plotting %2d: %s', opi, opsupname ); disp(str); + + % Load the data for each dataset. + data_blissup = load_data( filetemp_blissup, dirpath, thr_str, opsupname, vartemp, opname, 'blissup' ); + data_blisconv = load_data( filetemp_blisconv, dirpath, thr_str, opsupname, vartemp, opname, 'blisconv' ); + data_eigen = load_data( filetemp_eigen, dirpath, thr_str, opsupname, vartemp, opname, 'eigen' ); + data_open = load_data( filetemp_open, dirpath, thr_str, opsupname, vartemp, opname, 'openblas' ); + data_vend = load_data( filetemp_vend, dirpath, thr_str, opsupname, vartemp, opname, 'vendor' ); + + % Only read blasfeo data for single-threaded cases. + if nth == 1 + data_bfeo = load_data( filetemp_bfeo, dirpath, thr_str, opsupname, vartemp, opname, 'blasfeo' ); + else + data_bfeo = zeros( size( data_blissup, 1 ), size( data_blissup, 2 ) ); + end + + % Only read libxsmm data for single-threaded cases, and cases that use column + % storage since that's the only format that libxsmm supports. + %if nth == 1 && stor_str == 'ccc' + if nth == 1 && strcmp( stor_str, 'ccc' ) + data_xsmm = load_data( filetemp_xsmm, dirpath, thr_str, opsupname, vartemp, opname, 'libxsmm' ); + else + data_xsmm = zeros( size( data_blissup, 1 ), size( data_blissup, 2 ) ); + end + + % Plot one result in an m x n grid of plots, via the subplot() + % function. + plot_l3sup_perf( opsupname, ... + smalldims, ... + data_blissup, ... + data_blisconv, ... + data_eigen, ... + data_open, ... + data_vend, vend_str, ... + data_bfeo, ... + data_xsmm, ... + nth, ... + 4, 7, ... + cfreq, ... + dflopspercycle, ... + opi, impl, ... + fontsize, ... + leg_pos_st, leg_pos_st_x, leg_pos_mt, ... + sp_margins ); +end + +% Construct the name of the file to which we will output the graph. +outfile = sprintf( 'l3sup_%s_%s_%s_nt%d.pdf', oproot, stor_str, arch_str, nth ); + +% Output the graph to pdf format. +if strcmp( impl, 'octave' ) + print(gcf, outfile); +else + print(gcf, outfile,'-bestfit','-dpdf'); +end + diff --git a/test/sup/octave/runthese.m b/test/sup/octave/runthese.m new file mode 100644 index 0000000000..5ffc0c2138 --- /dev/null +++ b/test/sup/octave/runthese.m @@ -0,0 +1,25 @@ +% kabylake +plot_panel_trxsh(3.80,16,1,'st','d','rrr',[ 6 8 4 ],'lds','uaub','../results/kabylake/20200302/mnkt100000_st','kbl','MKL'); close; clear all; + +% haswell +plot_panel_trxsh(3.5,16,1,'st','d','rrr',[ 6 8 4 ],'lds','uaub','../results/haswell/20200302/mnkt100000_st','has','MKL'); close; clear all; + +% zen +plot_panel_trxsh(3.00, 8,1,'st','d','rrr',[ 6 8 4 ],'lds','uaub','../results/epyc/20200302/mnkt100000_st','zen','MKL'); close; clear all; + +% zen2 +plot_panel_trxsh(3.40,16, 1,'st','d','rrr',[ 6 8 4 ],'lds','uaub','../results/zen2/20201006/mnkt100000_st', 'zen2','MKL'); close; clear all; +plot_panel_trxsh(3.40,16, 1,'st','d','ccc',[ 6 8 4 ],'lds','uaub','../results/zen2/20201006/mnkt100000_st', 'zen2','MKL'); close; clear all; +plot_panel_trxsh(3.40,16, 1,'st','s','rrr',[ 6 16 4 ],'lds','uaub','../results/zen2/20201006/mnkt100000_st', 'zen2','MKL'); close; clear all; +plot_panel_trxsh(3.40,16, 1,'st','s','ccc',[ 6 16 4 ],'lds','uaub','../results/zen2/20201006/mnkt100000_st', 'zen2','MKL'); close; clear all; + +plot_panel_trxsh(2.60,16,32,'mt','d','rrr',[ 6 8 10 ],'lds','uaub','../results/zen2/20201006/mnkt100000_mt32','zen2','MKL'); close; clear all; +plot_panel_trxsh(2.60,16,32,'mt','d','ccc',[ 6 8 10 ],'lds','uaub','../results/zen2/20201006/mnkt100000_mt32','zen2','MKL'); close; clear all; +plot_panel_trxsh(2.60,16,32,'mt','s','rrr',[ 6 16 10 ],'lds','uaub','../results/zen2/20201006/mnkt100000_mt32','zen2','MKL'); close; clear all; +plot_panel_trxsh(2.60,16,32,'mt','s','ccc',[ 6 16 10 ],'lds','uaub','../results/zen2/20201006/mnkt100000_mt32','zen2','MKL'); close; clear all; + + + +plot_panel_trxsh(3.40,16, 1,'st','d','rrr',[ 6 8 4 ],'lds','uaub','../results/zen2/20201006/mnkt100000_st', 'zen2','MKL'); close; clear all; plot_panel_trxsh(3.40,16, 1,'st','d','ccc',[ 6 8 4 ],'lds','uaub','../results/zen2/20201006/mnkt100000_st', 'zen2','MKL'); close; clear all; plot_panel_trxsh(3.40,16, 1,'st','s','rrr',[ 6 16 4 ],'lds','uaub','../results/zen2/20201006/mnkt100000_st', 'zen2','MKL'); close; clear all; plot_panel_trxsh(3.40,16, 1,'st','s','ccc',[ 6 16 4 ],'lds','uaub','../results/zen2/20201006/mnkt100000_st', 'zen2','MKL'); close; clear all; + +plot_panel_trxsh(2.60,16,32,'mt','d','rrr',[ 6 8 10 ],'lds','uaub','../results/zen2/20201006/mnkt100000_mt32','zen2','MKL'); close; clear all; plot_panel_trxsh(2.60,16,32,'mt','d','ccc',[ 6 8 10 ],'lds','uaub','../results/zen2/20201006/mnkt100000_mt32','zen2','MKL'); close; clear all; plot_panel_trxsh(2.60,16,32,'mt','s','rrr',[ 6 16 10 ],'lds','uaub','../results/zen2/20201006/mnkt100000_mt32','zen2','MKL'); close; clear all; plot_panel_trxsh(2.60,16,32,'mt','s','ccc',[ 6 16 10 ],'lds','uaub','../results/zen2/20201006/mnkt100000_mt32','zen2','MKL'); close; clear all; diff --git a/test/sup/octave/subplot_tight.m b/test/sup/octave/subplot_tight.m new file mode 100644 index 0000000000..d84ea31888 --- /dev/null +++ b/test/sup/octave/subplot_tight.m @@ -0,0 +1,126 @@ +% +% Copyright (c) 2016, Nikolay S. +% All rights reserved. +% +% Redistribution and use in source and binary forms, with or without +% modification, are permitted provided that the following conditions are +% met: +% +% * Redistributions of source code must retain the above copyright +% notice, this list of conditions and the following disclaimer. +% * Redistributions in binary form must reproduce the above copyright +% notice, this list of conditions and the following disclaimer in +% the documentation and/or other materials provided with the distribution +% +% THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +% AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +% IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +% ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +% LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +% CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +% SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +% INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +% CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +% ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +% POSSIBILITY OF SUCH DAMAGE. +% + +function vargout=subplot_tight(m, n, p, margins, varargin) +%% subplot_tight +% A subplot function substitude with margins user tunabble parameter. +% +%% Syntax +% h=subplot_tight(m, n, p); +% h=subplot_tight(m, n, p, margins); +% h=subplot_tight(m, n, p, margins, subplotArgs...); +% +%% Description +% Our goal is to grant the user the ability to define the margins between neighbouring +% subplots. Unfotrtunately Matlab subplot function lacks this functionality, and the +% margins between subplots can reach 40% of figure area, which is pretty lavish. While at +% the begining the function was implememnted as wrapper function for Matlab function +% subplot, it was modified due to axes del;etion resulting from what Matlab subplot +% detected as overlapping. Therefore, the current implmenetation makes no use of Matlab +% subplot function, using axes instead. This can be problematic, as axis and subplot +% parameters are quie different. Set isWrapper to "True" to return to wrapper mode, which +% fully supports subplot format. +% +%% Input arguments (defaults exist): +% margins- two elements vector [vertical,horizontal] defining the margins between +% neighbouring axes. Default value is 0.04 +% +%% Output arguments +% same as subplot- none, or axes handle according to function call. +% +%% Issues & Comments +% - Note that if additional elements are used in order to be passed to subplot, margins +% parameter must be defined. For default margins value use empty element- []. +% - +% +%% Example +% close all; +% img=imread('peppers.png'); +% figSubplotH=figure('Name', 'subplot'); +% figSubplotTightH=figure('Name', 'subplot_tight'); +% nElems=17; +% subplotRows=ceil(sqrt(nElems)-1); +% subplotRows=max(1, subplotRows); +% subplotCols=ceil(nElems/subplotRows); +% for iElem=1:nElems +% figure(figSubplotH); +% subplot(subplotRows, subplotCols, iElem); +% imshow(img); +% figure(figSubplotTightH); +% subplot_tight(subplotRows, subplotCols, iElem, [0.0001]); +% imshow(img); +% end +% +%% See also +% - subplot +% +%% Revision history +% First version: Nikolay S. 2011-03-29. +% Last update: Nikolay S. 2012-05-24. +% +% *List of Changes:* +% 2012-05-24 +% Non wrapping mode (based on axes command) added, to deal with an issue of disappearing +% subplots occuring with massive axes. + +%% Default params +isWrapper=false; +if (nargin<4) || isempty(margins) + margins=[0.04,0.04]; % default margins value- 4% of figure +end +if length(margins)==1 + margins(2)=margins; +end + +%note n and m are switched as Matlab indexing is column-wise, while subplot indexing is row-wise :( +[subplot_col,subplot_row]=ind2sub([n,m],p); + + +height=(1-(m+1)*margins(1))/m; % single subplot height +width=(1-(n+1)*margins(2))/n; % single subplot width + +% note subplot suppors vector p inputs- so a merged subplot of higher dimentions will be created +subplot_cols=1+max(subplot_col)-min(subplot_col); % number of column elements in merged subplot +subplot_rows=1+max(subplot_row)-min(subplot_row); % number of row elements in merged subplot + +merged_height=subplot_rows*( height+margins(1) )- margins(1); % merged subplot height +merged_width= subplot_cols*( width +margins(2) )- margins(2); % merged subplot width + +merged_bottom=(m-max(subplot_row))*(height+margins(1)) +margins(1); % merged subplot bottom position +merged_left=min(subplot_col)*(width+margins(2))-width; % merged subplot left position +pos=[merged_left, merged_bottom, merged_width, merged_height]; + + +if isWrapper + h=subplot(m, n, p, varargin{:}, 'Units', 'Normalized', 'Position', pos); +else + h=axes('Position', pos, varargin{:}); +end + +if nargout==1 + vargout=h; +end diff --git a/test/sup/old/octave_mt/gen_opsupnames.m b/test/sup/old/octave_mt/gen_opsupnames.m new file mode 100644 index 0000000000..40258677d0 --- /dev/null +++ b/test/sup/old/octave_mt/gen_opsupnames.m @@ -0,0 +1,50 @@ +function [ r_val1, r_val2 ] = gen_opsupnames( ops, stor, smalldims, ldim, pack ) + +nops = size( ops, 1 ); + +smallm = smalldims( 1 ); +smalln = smalldims( 2 ); +smallk = smalldims( 3 ); + +i = 1; + +for io = 1:nops + + op = ops( io, : ); + + % sprintf'ing directly into an array of strings, as in: + % + % opsupnames( i+0, : ) = sprintf( '%s_%s_m%dnpkp_%s_%s', ... ); + % + % doesn't work when the string lengths as they would if any of the constant + % dimensions is greater than 9. + str0 = sprintf( '%s_%s_m%dnpkp_%s_%s', op, stor, smallm, ldim, pack ); + str1 = sprintf( '%s_%s_mpn%dkp_%s_%s', op, stor, smalln, ldim, pack ); + str2 = sprintf( '%s_%s_mpnpk%d_%s_%s', op, stor, smallk, ldim, pack ); + str3 = sprintf( '%s_%s_mpn%dk%d_%s_%s', op, stor, smalln, smallk, ldim, pack ); + str4 = sprintf( '%s_%s_m%dnpk%d_%s_%s', op, stor, smallm, smallk, ldim, pack ); + str5 = sprintf( '%s_%s_m%dn%dkp_%s_%s', op, stor, smallm, smalln, ldim, pack ); + str6 = sprintf( '%s_%s_mpnpkp_%s_%s', op, stor, ldim, pack ); + + opsupnames( i+0, : ) = sprintf( '%-31s', str0 ); + opsupnames( i+1, : ) = sprintf( '%-31s', str1 ); + opsupnames( i+2, : ) = sprintf( '%-31s', str2 ); + opsupnames( i+3, : ) = sprintf( '%-31s', str3 ); + opsupnames( i+4, : ) = sprintf( '%-31s', str4 ); + opsupnames( i+5, : ) = sprintf( '%-31s', str5 ); + opsupnames( i+6, : ) = sprintf( '%-31s', str6 ); + + opnames( i+0, : ) = sprintf( '%s', op ); + opnames( i+1, : ) = sprintf( '%s', op ); + opnames( i+2, : ) = sprintf( '%s', op ); + opnames( i+3, : ) = sprintf( '%s', op ); + opnames( i+4, : ) = sprintf( '%s', op ); + opnames( i+5, : ) = sprintf( '%s', op ); + opnames( i+6, : ) = sprintf( '%s', op ); + + i = i + 7; +end + +r_val1 = opsupnames; +r_val2 = opnames; + diff --git a/test/sup/old/octave_mt/plot_l3sup_perf.m b/test/sup/old/octave_mt/plot_l3sup_perf.m new file mode 100644 index 0000000000..43a05e87b2 --- /dev/null +++ b/test/sup/old/octave_mt/plot_l3sup_perf.m @@ -0,0 +1,274 @@ +function r_val = plot_l3sup_perf( opname, ... + data_blissup, ... + data_blislpab, ... + data_eigen, ... + data_open, ... + data_vend, vend_str, ... + nth, ... + rows, cols, ... + cfreq, ... + dfps, ... + theid, impl ) + +%if ... %mod(theid-1,cols) == 2 || ... +% ... %mod(theid-1,cols) == 3 || ... +% ... %mod(theid-1,cols) == 4 || ... +% 0 == 1 ... %theid >= 19 +% show_plot = 0; +%else + show_plot = 1; +%end + +%legend_plot_id = 11; +legend_plot_id = 0*cols + 1*6; + +if 1 + ax1 = subplot( rows, cols, theid ); + hold( ax1, 'on' ); +end + +% Set line properties. +color_blissup = 'k'; lines_blissup = '-'; markr_blissup = ''; +color_blislpab = 'k'; lines_blislpab = ':'; markr_blislpab = ''; +color_eigen = 'm'; lines_eigen = '-.'; markr_eigen = 'o'; +color_open = 'r'; lines_open = '--'; markr_open = 'o'; +color_vend = 'b'; lines_vend = '-.'; markr_vend = '.'; + +% Compute the peak performance in terms of the number of double flops +% executable per cycle and the clock rate. +if opname(1) == 's' || opname(1) == 'c' + flopspercycle = dfps * 2; +else + flopspercycle = dfps; +end +max_perf_core = (flopspercycle * cfreq) * 1; + +% Escape underscores in the title. +title_opname = strrep( opname, '_', '\_' ); + +% Print the title to a string. +titlename = '%s'; +titlename = sprintf( titlename, title_opname ); + +% Set the legend strings. +blissup_legend = sprintf( 'BLIS sup' ); +blislpab_legend = sprintf( 'BLIS conv' ); +eigen_legend = sprintf( 'Eigen' ); +open_legend = sprintf( 'OpenBLAS' ); +%vend_legend = sprintf( 'MKL' ); +%vend_legend = sprintf( 'ARMPL' ); +vend_legend = vend_str; + +% Set axes range values. +y_scale = 1.00; +x_begin = 0; +%x_end is set below. +y_begin = 0; +y_end = max_perf_core * y_scale; + +% Set axes names. +if nth == 1 + yaxisname = 'GFLOPS'; +else + yaxisname = 'GFLOPS/core'; +end + + +%flopscol = 4; +flopscol = size( data_blissup, 2 ); +msize = 5; +if 1 + fontsize = 12; +else + fontsize = 16; +end +linesize = 0.5; +legend_loc = 'southeast'; + +% -------------------------------------------------------------------- + +% Automatically detect a column with the increasing problem size. +% Then set the maximum x-axis value. +for psize_col = 1:3 + if data_blissup( 1, psize_col ) ~= data_blissup( 2, psize_col ) + break; + end +end +x_axis( :, 1 ) = data_blissup( :, psize_col ); + +% Compute the number of data points we have in the x-axis. Note that we +% only use half the data points for the m = n = k column of graphs. +%if mod(theid-1,cols) == 6 +% np = size( data_blissup, 1 ) / 2; +%else +% np = size( data_blissup, 1 ); +%end +np = size( data_blissup, 1 ); + +% Grab the last x-axis value. +x_end = data_blissup( np, psize_col ); + +%data_peak( 1, 1:2 ) = [ 0 max_perf_core ]; +%data_peak( 2, 1:2 ) = [ x_end max_perf_core ]; + +if show_plot == 1 +blissup_ln = line( x_axis( 1:np, 1 ), data_blissup( 1:np, flopscol ) / nth, ... + 'Color',color_blissup, 'LineStyle',lines_blissup, ... + 'LineWidth',linesize ); +blislpab_ln = line( x_axis( 1:np, 1 ), data_blislpab( 1:np, flopscol ) / nth, ... + 'Color',color_blislpab, 'LineStyle',lines_blislpab, ... + 'LineWidth',linesize ); +eigen_ln = line( x_axis( 1:np, 1 ), data_eigen( 1:np, flopscol ) / nth, ... + 'Color',color_eigen, 'LineStyle',lines_eigen, ... + 'LineWidth',linesize ); +open_ln = line( x_axis( 1:np, 1 ), data_open( 1:np, flopscol ) / nth, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +vend_ln = line( x_axis( 1:np, 1 ), data_vend( 1:np, flopscol ) / nth, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... + 'LineWidth',linesize ); +elseif theid == legend_plot_id +blissup_ln = line( nan, nan, ... + 'Color',color_blissup, 'LineStyle',lines_blissup, ... + 'LineWidth',linesize ); +blislpab_ln = line( nan, nan, ... + 'Color',color_blislpab, 'LineStyle',lines_blislpab, ... + 'LineWidth',linesize ); +eigen_ln = line( nan, nan, ... + 'Color',color_eigen, 'LineStyle',lines_eigen, ... + 'LineWidth',linesize ); +open_ln = line( nan, nan, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +vend_ln = line( nan, nan, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... + 'LineWidth',linesize ); +end + + +xlim( ax1, [x_begin x_end] ); +ylim( ax1, [y_begin y_end] ); + +if mod(theid-1,cols) == 3 || mod(theid-1,cols) == 4 || mod(theid-1,cols) == 5 + if nth == 12 + ylim( ax1, [y_begin y_end/2] ); + elseif nth > 12 + ylim( ax1, [y_begin y_end/6] ); + end +end + +if 10000 <= x_end && x_end < 15000 + x_tick2 = x_end - 2000; + x_tick1 = x_tick2/2; + %xticks( ax1, [ x_tick1 x_tick2 ] ); + xticks( ax1, [ 4000 8000 12000 ] ); +elseif 6000 <= x_end && x_end < 10000 + x_tick2 = x_end - 2000; + x_tick1 = x_tick2/2; + %xticks( ax1, [ x_tick1 x_tick2 ] ); + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 4000 <= x_end && x_end < 6000 + x_tick2 = x_end - 1000; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 2000 <= x_end && x_end < 3000 + x_tick2 = x_end - 400; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 500 <= x_end && x_end < 1000 + x_tick3 = x_end*(3/4); + x_tick2 = x_end*(2/4); + x_tick1 = x_end*(1/4); + xticks( ax1, [ x_tick1 x_tick2 x_tick3 ] ); +end + +if show_plot == 1 || theid == legend_plot_id + if theid == legend_plot_id + leg = legend( ... + [ ... + blissup_ln ... + blislpab_ln ... + eigen_ln ... + open_ln ... + vend_ln ... + ], ... + blissup_legend, ... + blislpab_legend, ... + eigen_legend, ... + open_legend, ... + vend_legend, ... + 'Location', legend_loc ); + set( leg,'Box','off' ); + set( leg,'Color','none' ); + set( leg,'Units','inches' ); + if impl == 'octave' + set( leg,'FontSize',fontsize ); + %set( leg,'Position',[12.40 10.60 1.9 0.95 ] ); % (1,4tl) + set( leg,'Position',[18.80 10.60 1.9 0.95 ] ); % (1,4tl) + else + set( leg,'FontSize',fontsize-1 ); + set( leg,'Position',[18.24 10.15 1.15 0.7 ] ); % (1,4tl) + end + set( leg,'Box','off' ); + set( leg,'Color','none' ); + set( leg,'Units','inches' ); + % xpos ypos + %set( leg,'Position',[11.32 6.36 1.15 0.7 ] ); % (1,4tl) + end +end + +set( ax1,'FontSize',fontsize ); +set( ax1,'TitleFontSizeMultiplier',1.0 ); % default is 1.1. +box( ax1, 'on' ); + +titl = title( titlename ); +set( titl, 'FontWeight', 'normal' ); % default font style is now 'bold'. + +% The default is to align the plot title across whole figure, not the box. +% This is a hack to nudge the title back to the center of the box. +if impl == 'octave' + tpos = get( titl, 'Position' ); + % For some reason, the titles in the graphs in the last column start + % off in a different relative position than the graphs in the other + % columns. Here, we manually account for that. + if mod(theid-1,cols) == 6 + tpos(1) = tpos(1) + -10; + else + tpos(1) = tpos(1) + -40; + end + set( titl, 'Position', tpos ); + set( titl, 'FontSize', fontsize ); +else % impl == 'matlab' + tpos = get( titl, 'Position' ); + tpos(1) = tpos(1) + 90; + set( titl, 'Position', tpos ); +end + +if theid > (rows-1)*cols + %xlab = xlabel( ax1,xaxisname ); + %tpos = get( xlab, 'Position' ) + %tpos(2) = tpos(2) + 10; + %set( xlab, 'Position', tpos ); + if theid == rows*cols - 6 + xlab = xlabel( ax1, 'm = 6; n = k' ); + elseif theid == rows*cols - 5 + xlab = xlabel( ax1, 'n = 8; m = k' ); + elseif theid == rows*cols - 4 + xlab = xlabel( ax1, 'k = 10; m = n' ); + elseif theid == rows*cols - 3 + xlab = xlabel( ax1, 'm; n = 8, k = 10' ); + elseif theid == rows*cols - 2 + xlab = xlabel( ax1, 'n; m = 6, k = 10' ); + elseif theid == rows*cols - 1 + xlab = xlabel( ax1, 'k; m = 6, n = 8' ); + elseif theid == rows*cols - 0 + xlab = xlabel( ax1, 'm = n = k' ); + end +end + +if mod(theid-1,cols) == 0 + ylab = ylabel( ax1,yaxisname ); +end + +r_val = 0; + diff --git a/test/sup/old/octave_mt/plot_panel_trxsh.m b/test/sup/old/octave_mt/plot_panel_trxsh.m new file mode 100644 index 0000000000..d890d0dd82 --- /dev/null +++ b/test/sup/old/octave_mt/plot_panel_trxsh.m @@ -0,0 +1,146 @@ +function r_val = plot_panel_trxsh ... + ( ... + cfreq, ... + dflopspercycle, ... + nth, ... + thr_str, ... + dt_ch, ... + stor_str, ... + smalldims, ... + ldim_str, ... + pack_str, ... + dirpath, ... + arch_str, ... + vend_str, ... + impl ... + ) + +%cfreq = 1.8; +%dflopspercycle = 32; + +% Create filename "templates" for the files that contain the performance +% results. +filetemp_blissup = '%s/output_%s_%s_blissup.m'; +filetemp_blisconv = '%s/output_%s_%s_blisconv.m'; +filetemp_eigen = '%s/output_%s_%s_eigen.m'; +filetemp_open = '%s/output_%s_%s_openblas.m'; +filetemp_vend = '%s/output_%s_%s_vendor.m'; + +% Create a variable name "template" for the variables contained in the +% files outlined above. +vartemp = 'data_%s_%s_%s( :, : )'; + +% Define the datatypes and operations we will be plotting. +oproot = sprintf( '%cgemm', dt_ch ); +ops( 1, : ) = sprintf( '%s_nn', oproot ); +ops( 2, : ) = sprintf( '%s_nt', oproot ); +ops( 3, : ) = sprintf( '%s_tn', oproot ); +ops( 4, : ) = sprintf( '%s_tt', oproot ); + +% Generate datatype-specific operation names from the set of operations +% and datatypes. +[ opsupnames, opnames ] = gen_opsupnames( ops, stor_str, smalldims, ldim_str, pack_str ); +n_opsupnames = size( opsupnames, 1 ); + +%opsupnames +%opnames +%return + +if 1 == 1 + %fig = figure('Position', [100, 100, 2400, 1500]); + fig = figure('Position', [100, 100, 2400, 1200]); + orient( fig, 'portrait' ); + set(gcf,'PaperUnits', 'inches'); + if impl == 'matlab' + set(gcf,'PaperSize', [11.5 20.4]); + set(gcf,'PaperPosition', [0 0 11.5 20.4]); + set(gcf,'PaperPositionMode','manual'); + else % impl == 'octave' % octave 4.x + set(gcf,'PaperSize', [12 22.0]); + set(gcf,'PaperPositionMode','auto'); + end + set(gcf,'PaperOrientation','landscape'); +end + + +% Iterate over the list of datatype-specific operation names. +for opi = 1:n_opsupnames +%for opi = 1:1 + + % Grab the current datatype combination. + opsupname = opsupnames( opi, : ); + opname = opnames( opi, : ); + + opsupname = strtrim( opsupname ); + opname = strtrim( opname ); + + str = sprintf( 'Plotting %2d: %s', opi, opsupname ); disp(str); + + % Construct filenames for the data files from templates. + file_blissup = sprintf( filetemp_blissup, dirpath, thr_str, opsupname ); + file_blisconv = sprintf( filetemp_blisconv, dirpath, thr_str, opsupname ); + file_eigen = sprintf( filetemp_eigen, dirpath, thr_str, opsupname ); + file_open = sprintf( filetemp_open, dirpath, thr_str, opsupname ); + file_vend = sprintf( filetemp_vend, dirpath, thr_str, opsupname ); + + % Load the data files. + %str = sprintf( ' Loading %s', file_blissup ); disp(str); + run( file_blissup ) + run( file_blisconv ) + run( file_eigen ) + run( file_open ) + run( file_vend ) + + % Construct variable names for the variables in the data files. + var_blissup = sprintf( vartemp, thr_str, opname, 'blissup' ); + var_blisconv = sprintf( vartemp, thr_str, opname, 'blisconv' ); + var_eigen = sprintf( vartemp, thr_str, opname, 'eigen' ); + var_open = sprintf( vartemp, thr_str, opname, 'openblas' ); + var_vend = sprintf( vartemp, thr_str, opname, 'vendor' ); + + % Use eval() to instantiate the variable names constructed above, + % copying each to a simplified name. + data_blissup = eval( var_blissup ); % e.g. data_st_dgemm_blissup( :, : ); + data_blisconv = eval( var_blisconv ); % e.g. data_st_dgemm_blisconv( :, : ); + data_eigen = eval( var_eigen ); % e.g. data_st_dgemm_eigen( :, : ); + data_open = eval( var_open ); % e.g. data_st_dgemm_openblas( :, : ); + data_vend = eval( var_vend ); % e.g. data_st_dgemm_vendor( :, : ); + + % Plot one result in an m x n grid of plots, via the subplot() + % function. + if 1 == 1 + plot_l3sup_perf( opsupname, ... + data_blissup, ... + data_blisconv, ... + data_eigen, ... + data_open, ... + data_vend, vend_str, ... + nth, ... + 4, 7, ... + cfreq, ... + dflopspercycle, ... + opi, impl ); + + clear data_mt_*gemm_*; + clear data_blissup; + clear data_blisconv; + clear data_eigen; + clear data_open; + clear data_vend; + + end + +end + +% Construct the name of the file to which we will output the graph. +outfile = sprintf( 'l3sup_%s_%s_%s_nt%d.pdf', oproot, stor_str, arch_str, nth ); + +% Output the graph to pdf format. +%print(gcf, 'gemm_md','-fillpage','-dpdf'); +%print(gcf, outfile,'-bestfit','-dpdf'); +if impl == 'octave' + print(gcf, outfile); +else % if impl == 'matlab' + print(gcf, outfile,'-bestfit','-dpdf'); +end + diff --git a/test/sup/old/octave_mt/runthese.m b/test/sup/old/octave_mt/runthese.m new file mode 100644 index 0000000000..00ab181d34 --- /dev/null +++ b/test/sup/old/octave_mt/runthese.m @@ -0,0 +1,8 @@ +% kabylake +plot_panel_trxsh(3.80,16,4,'mt','d','rrr',[ 6 8 10 ],'lds','uaub','../results/kabylake/20200302/mnkt100000_mt4','kbl','MKL','octave'); close; clear all; + +% haswell +plot_panel_trxsh(3.1,16,12,'mt','d','rrr',[ 6 8 10 ],'lds','uaub','../results/haswell/20200302/mnkt100000_mt12','has','MKL','octave'); close; clear all; + +% epyc +plot_panel_trxsh(2.55,8,32,'mt','d','rrr',[ 6 8 10 ],'lds','uaub','../results/epyc/20200302/mnkt100000_mt32','epyc','MKL','octave'); close; clear all; diff --git a/test/sup/old/octave_st/gen_opsupnames.m b/test/sup/old/octave_st/gen_opsupnames.m new file mode 100644 index 0000000000..e5dee05402 --- /dev/null +++ b/test/sup/old/octave_st/gen_opsupnames.m @@ -0,0 +1,36 @@ +function [ r_val1, r_val2 ] = gen_opsupnames( ops, stor, smalldims, ldim, pack ) + +nops = size( ops, 1 ); + +smallm = smalldims( 1 ); +smalln = smalldims( 2 ); +smallk = smalldims( 3 ); + +i = 1; + +for io = 1:nops + + op = ops( io, : ); + + opsupnames( i+0, : ) = sprintf( '%s_%s_m%dnpkp_%s_%s', op, stor, smallm, ldim, pack ); + opsupnames( i+1, : ) = sprintf( '%s_%s_mpn%dkp_%s_%s', op, stor, smalln, ldim, pack ); + opsupnames( i+2, : ) = sprintf( '%s_%s_mpnpk%d_%s_%s', op, stor, smallk, ldim, pack ); + opsupnames( i+3, : ) = sprintf( '%s_%s_mpn%dk%d_%s_%s', op, stor, smalln, smallk, ldim, pack ); + opsupnames( i+4, : ) = sprintf( '%s_%s_m%dnpk%d_%s_%s', op, stor, smallm, smallk, ldim, pack ); + opsupnames( i+5, : ) = sprintf( '%s_%s_m%dn%dkp_%s_%s', op, stor, smallm, smalln, ldim, pack ); + opsupnames( i+6, : ) = sprintf( '%s_%s_mpnpkp_%s_%s', op, stor, ldim, pack ); + + opnames( i+0, : ) = sprintf( '%s', op ); + opnames( i+1, : ) = sprintf( '%s', op ); + opnames( i+2, : ) = sprintf( '%s', op ); + opnames( i+3, : ) = sprintf( '%s', op ); + opnames( i+4, : ) = sprintf( '%s', op ); + opnames( i+5, : ) = sprintf( '%s', op ); + opnames( i+6, : ) = sprintf( '%s', op ); + + i = i + 7; +end + +r_val1 = opsupnames; +r_val2 = opnames; + diff --git a/test/sup/old/octave_st/plot_l3sup_perf.m b/test/sup/old/octave_st/plot_l3sup_perf.m new file mode 100644 index 0000000000..8a615ada52 --- /dev/null +++ b/test/sup/old/octave_st/plot_l3sup_perf.m @@ -0,0 +1,328 @@ +function r_val = plot_l3sup_perf( opname, ... + data_blissup, ... + data_blislpab, ... + data_eigen, ... + data_open, ... + data_bfeo, ... + data_xsmm, ... + data_vend, vend_str, ... + nth, ... + rows, cols, ... + cfreq, ... + dfps, ... + theid, impl ) + +%if ... %mod(theid-1,cols) == 2 || ... +% ... %mod(theid-1,cols) == 3 || ... +% ... %mod(theid-1,cols) == 4 || ... +% 0 == 1 ... %theid >= 19 +% show_plot = 0; +%else + show_plot = 1; +%end + +%legend_plot_id = 11; +legend_plot_id = 2*cols + 1*5; + +if 1 + ax1 = subplot( rows, cols, theid ); + hold( ax1, 'on' ); +end + +% Set line properties. +color_blissup = 'k'; lines_blissup = '-'; markr_blissup = ''; +color_blislpab = 'k'; lines_blislpab = ':'; markr_blislpab = ''; +color_eigen = 'm'; lines_eigen = '-.'; markr_eigen = 'o'; +color_open = 'r'; lines_open = '--'; markr_open = 'o'; +color_bfeo = 'c'; lines_bfeo = '-'; markr_bfeo = 'o'; +color_xsmm = 'g'; lines_xsmm = '-'; markr_xsmm = 'o'; +color_vend = 'b'; lines_vend = '-.'; markr_vend = '.'; + +% Compute the peak performance in terms of the number of double flops +% executable per cycle and the clock rate. +if opname(1) == 's' || opname(1) == 'c' + flopspercycle = dfps * 2; +else + flopspercycle = dfps; +end +max_perf_core = (flopspercycle * cfreq) * 1; + +% Escape underscores in the title. +title_opname = strrep( opname, '_', '\_' ); + +% Print the title to a string. +titlename = '%s'; +titlename = sprintf( titlename, title_opname ); + +% Set the legend strings. +blissup_legend = sprintf( 'BLIS sup' ); +blislpab_legend = sprintf( 'BLIS conv' ); +eigen_legend = sprintf( 'Eigen' ); +open_legend = sprintf( 'OpenBLAS' ); +bfeo_legend = sprintf( 'BLASFEO' ); +xsmm_legend = sprintf( 'libxsmm' ); +%vend_legend = sprintf( 'MKL' ); +%vend_legend = sprintf( 'ARMPL' ); +vend_legend = vend_str; + +% Set axes range values. +y_scale = 1.00; +x_begin = 0; +%x_end is set below. +y_begin = 0; +y_end = max_perf_core * y_scale; + +% Set axes names. +if nth == 1 + yaxisname = 'GFLOPS'; +else + yaxisname = 'GFLOPS/core'; +end + + +%flopscol = 4; +flopscol = size( data_blissup, 2 ); +msize = 5; +if 1 + fontsize = 12; +else + fontsize = 16; +end +linesize = 0.5; +legend_loc = 'southeast'; + +% -------------------------------------------------------------------- + +% Automatically detect a column with the increasing problem size. +% Then set the maximum x-axis value. +for psize_col = 1:3 + if data_blissup( 1, psize_col ) ~= data_blissup( 2, psize_col ) + break; + end +end +x_axis( :, 1 ) = data_blissup( :, psize_col ); + +% Compute the number of data points we have in the x-axis. Note that we +% only use half the data points for the m = n = k column of graphs. +%if mod(theid-1,cols) == 6 +% np = size( data_blissup, 1 ) / 2; +%else +% np = size( data_blissup, 1 ); +%end +np = size( data_blissup, 1 ); + +has_xsmm = 1; +if data_xsmm( 1, flopscol ) == 0.0 + has_xsmm = 0; +end + +% Grab the last x-axis value. +x_end = data_blissup( np, psize_col ); + +%data_peak( 1, 1:2 ) = [ 0 max_perf_core ]; +%data_peak( 2, 1:2 ) = [ x_end max_perf_core ]; + +if show_plot == 1 +blissup_ln = line( x_axis( 1:np, 1 ), data_blissup( 1:np, flopscol ) / nth, ... + 'Color',color_blissup, 'LineStyle',lines_blissup, ... + 'LineWidth',linesize ); +blislpab_ln = line( x_axis( 1:np, 1 ), data_blislpab( 1:np, flopscol ) / nth, ... + 'Color',color_blislpab, 'LineStyle',lines_blislpab, ... + 'LineWidth',linesize ); +eigen_ln = line( x_axis( 1:np, 1 ), data_eigen( 1:np, flopscol ) / nth, ... + 'Color',color_eigen, 'LineStyle',lines_eigen, ... + 'LineWidth',linesize ); +open_ln = line( x_axis( 1:np, 1 ), data_open( 1:np, flopscol ) / nth, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +bfeo_ln = line( x_axis( 1:np, 1 ), data_bfeo( 1:np, flopscol ) / nth, ... + 'Color',color_bfeo, 'LineStyle',lines_bfeo, ... + 'LineWidth',linesize ); +if has_xsmm == 1 +xsmm_ln = line( x_axis( 1:np, 1 ), data_xsmm( 1:np, flopscol ) / nth, ... + 'Color',color_xsmm, 'LineStyle',lines_xsmm, ... + 'LineWidth',linesize ); +else +xsmm_ln = line( nan, nan, ... + 'Color',color_xsmm, 'LineStyle',lines_xsmm, ... + 'LineWidth',linesize ); +end +vend_ln = line( x_axis( 1:np, 1 ), data_vend( 1:np, flopscol ) / nth, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... + 'LineWidth',linesize ); +elseif theid == legend_plot_id +blissup_ln = line( nan, nan, ... + 'Color',color_blissup, 'LineStyle',lines_blissup, ... + 'LineWidth',linesize ); +blislpab_ln = line( nan, nan, ... + 'Color',color_blislpab, 'LineStyle',lines_blislpab, ... + 'LineWidth',linesize ); +eigen_ln = line( nan, nan, ... + 'Color',color_eigen, 'LineStyle',lines_eigen, ... + 'LineWidth',linesize ); +open_ln = line( nan, nan, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +bfeo_ln = line( nan, nan, ... + 'Color',color_bfeo, 'LineStyle',lines_bfeo, ... + 'LineWidth',linesize ); +xsmm_ln = line( nan, nan, ... + 'Color',color_xsmm, 'LineStyle',lines_xsmm, ... + 'LineWidth',linesize ); +vend_ln = line( nan, nan, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... + 'LineWidth',linesize ); +end + + +xlim( ax1, [x_begin x_end] ); +ylim( ax1, [y_begin y_end] ); + +if 10000 <= x_end && x_end < 15000 + x_tick2 = x_end - 2000; + x_tick1 = x_tick2/2; + %xticks( ax1, [ x_tick1 x_tick2 ] ); + xticks( ax1, [ 3000 6000 9000 12000 ] ); +elseif 6000 <= x_end && x_end < 10000 + x_tick2 = x_end - 2000; + x_tick1 = x_tick2/2; + %xticks( ax1, [ x_tick1 x_tick2 ] ); + xticks( ax1, [ 2000 4000 6000 8000 ] ); +elseif 4000 <= x_end && x_end < 6000 + x_tick2 = x_end - 1000; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 2000 <= x_end && x_end < 3000 + x_tick2 = x_end - 400; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 500 <= x_end && x_end < 1000 + x_tick3 = x_end*(3/4); + x_tick2 = x_end*(2/4); + x_tick1 = x_end*(1/4); + xticks( ax1, [ x_tick1 x_tick2 x_tick3 ] ); +end + +if show_plot == 1 || theid == legend_plot_id + if nth == 1 && theid == legend_plot_id + if has_xsmm == 1 + leg = legend( ... + [ ... + blissup_ln ... + blislpab_ln ... + eigen_ln ... + open_ln ... + bfeo_ln ... + xsmm_ln ... + vend_ln ... + ], ... + blissup_legend, ... + blislpab_legend, ... + eigen_legend, ... + open_legend, ... + bfeo_legend, ... + xsmm_legend, ... + vend_legend, ... + 'Location', legend_loc ); + set( leg,'Box','off' ); + set( leg,'Color','none' ); + set( leg,'Units','inches' ); + if impl == 'octave' + set( leg,'FontSize',fontsize ); + set( leg,'Position',[15.40 4.75 1.9 1.20] ); % (1,4tl) + else + set( leg,'FontSize',fontsize-3 ); + set( leg,'Position',[18.20 10.20 1.15 0.7 ] ); % (1,4tl) + end + else + leg = legend( ... + [ ... + blissup_ln ... + blislpab_ln ... + eigen_ln ... + open_ln ... + bfeo_ln ... + vend_ln ... + ], ... + blissup_legend, ... + blislpab_legend, ... + eigen_legend, ... + open_legend, ... + bfeo_legend, ... + vend_legend, ... + 'Location', legend_loc ); + set( leg,'Box','off' ); + set( leg,'Color','none' ); + set( leg,'Units','inches' ); + if impl == 'octave' + set( leg,'FontSize',fontsize ); + set( leg,'Position',[15.40 7.65 1.9 1.10] ); % (1,4tl) + else + set( leg,'FontSize',fontsize-1 ); + set( leg,'Position',[18.24 10.15 1.15 0.7] ); % (1,4tl) + end + end + set( leg,'Box','off' ); + set( leg,'Color','none' ); + set( leg,'Units','inches' ); + % xpos ypos + %set( leg,'Position',[11.32 6.36 1.15 0.7 ] ); % (1,4tl) + elseif nth > 1 && theid == legend_plot_id + end +end + +set( ax1,'FontSize',fontsize ); +set( ax1,'TitleFontSizeMultiplier',1.0 ); % default is 1.1. +box( ax1, 'on' ); + +titl = title( titlename ); +set( titl, 'FontWeight', 'normal' ); % default font style is now 'bold'. + +% The default is to align the plot title across whole figure, not the box. +% This is a hack to nudge the title back to the center of the box. +if impl == 'octave' + tpos = get( titl, 'Position' ); + % For some reason, the titles in the graphs in the last column start + % off in a different relative position than the graphs in the other + % columns. Here, we manually account for that. + if mod(theid-1,cols) == 6 + tpos(1) = tpos(1) + -10; + else + tpos(1) = tpos(1) + -40; + end + set( titl, 'Position', tpos ); + set( titl, 'FontSize', fontsize ); +else % impl == 'matlab' + tpos = get( titl, 'Position' ); + tpos(1) = tpos(1) + 90; + set( titl, 'Position', tpos ); +end + +if theid > (rows-1)*cols + %xlab = xlabel( ax1,xaxisname ); + %tpos = get( xlab, 'Position' ) + %tpos(2) = tpos(2) + 10; + %set( xlab, 'Position', tpos ); + if theid == rows*cols - 6 + xlab = xlabel( ax1, 'm = 6; n = k' ); + elseif theid == rows*cols - 5 + xlab = xlabel( ax1, 'n = 8; m = k' ); + elseif theid == rows*cols - 4 + xlab = xlabel( ax1, 'k = 4; m = n' ); + elseif theid == rows*cols - 3 + xlab = xlabel( ax1, 'm; n = 8, k = 4' ); + elseif theid == rows*cols - 2 + xlab = xlabel( ax1, 'n; m = 6, k = 4' ); + elseif theid == rows*cols - 1 + xlab = xlabel( ax1, 'k; m = 6, n = 8' ); + elseif theid == rows*cols - 0 + xlab = xlabel( ax1, 'm = n = k' ); + end +end + +if mod(theid-1,cols) == 0 + ylab = ylabel( ax1,yaxisname ); +end + +r_val = 0; + diff --git a/test/sup/old/octave_st/plot_panel_trxsh.m b/test/sup/old/octave_st/plot_panel_trxsh.m new file mode 100644 index 0000000000..a4bd2fb59f --- /dev/null +++ b/test/sup/old/octave_st/plot_panel_trxsh.m @@ -0,0 +1,167 @@ +function r_val = plot_panel_trxsh ... + ( ... + cfreq, ... + dflopspercycle, ... + nth, ... + thr_str, ... + dt_ch, ... + stor_str, ... + smalldims, ... + ldim_str, ... + pack_str, ... + dirpath, ... + arch_str, ... + vend_str, ... + impl ... + ) + +%cfreq = 1.8; +%dflopspercycle = 32; + +% Create filename "templates" for the files that contain the performance +% results. +filetemp_blissup = '%s/output_%s_%s_blissup.m'; +filetemp_blisconv = '%s/output_%s_%s_blisconv.m'; +filetemp_eigen = '%s/output_%s_%s_eigen.m'; +filetemp_open = '%s/output_%s_%s_openblas.m'; +filetemp_bfeo = '%s/output_%s_%s_blasfeo.m'; +filetemp_xsmm = '%s/output_%s_%s_libxsmm.m'; +filetemp_vend = '%s/output_%s_%s_vendor.m'; + +% Create a variable name "template" for the variables contained in the +% files outlined above. +vartemp = 'data_%s_%s_%s( :, : )'; + +% Define the datatypes and operations we will be plotting. +oproot = sprintf( '%cgemm', dt_ch ); +ops( 1, : ) = sprintf( '%s_nn', oproot ); +ops( 2, : ) = sprintf( '%s_nt', oproot ); +ops( 3, : ) = sprintf( '%s_tn', oproot ); +ops( 4, : ) = sprintf( '%s_tt', oproot ); + +% Generate datatype-specific operation names from the set of operations +% and datatypes. +[ opsupnames, opnames ] = gen_opsupnames( ops, stor_str, smalldims, ldim_str, pack_str ); +n_opsupnames = size( opsupnames, 1 ); + +%opsupnames +%opnames +%return + +if 1 == 1 + %fig = figure('Position', [100, 100, 2400, 1500]); + fig = figure('Position', [100, 100, 2400, 1200]); + orient( fig, 'portrait' ); + set(gcf,'PaperUnits', 'inches'); + if impl == 'matlab' + set(gcf,'PaperSize', [11.5 20.4]); + set(gcf,'PaperPosition', [0 0 11.5 20.4]); + set(gcf,'PaperPositionMode','manual'); + else % impl == 'octave' % octave 4.x + set(gcf,'PaperSize', [12 22.0]); + set(gcf,'PaperPositionMode','auto'); + end + set(gcf,'PaperOrientation','landscape'); +end + + +% Iterate over the list of datatype-specific operation names. +for opi = 1:n_opsupnames +%for opi = 1:1 + + % Grab the current datatype combination. + opsupname = opsupnames( opi, : ); + opname = opnames( opi, : ); + + str = sprintf( 'Plotting %2d: %s', opi, opsupname ); disp(str); + + % Construct filenames for the data files from templates. + file_blissup = sprintf( filetemp_blissup, dirpath, thr_str, opsupname ); + file_blisconv = sprintf( filetemp_blisconv, dirpath, thr_str, opsupname ); + file_eigen = sprintf( filetemp_eigen, dirpath, thr_str, opsupname ); + file_open = sprintf( filetemp_open, dirpath, thr_str, opsupname ); + file_bfeo = sprintf( filetemp_bfeo, dirpath, thr_str, opsupname ); + file_vend = sprintf( filetemp_vend, dirpath, thr_str, opsupname ); + + % Load the data files. + %str = sprintf( ' Loading %s', file_blissup ); disp(str); + run( file_blissup ) + run( file_blisconv ) + run( file_eigen ) + run( file_open ) + run( file_bfeo ) + run( file_vend ) + + % Construct variable names for the variables in the data files. + var_blissup = sprintf( vartemp, thr_str, opname, 'blissup' ); + var_blisconv = sprintf( vartemp, thr_str, opname, 'blisconv' ); + var_eigen = sprintf( vartemp, thr_str, opname, 'eigen' ); + var_open = sprintf( vartemp, thr_str, opname, 'openblas' ); + var_bfeo = sprintf( vartemp, thr_str, opname, 'blasfeo' ); + var_vend = sprintf( vartemp, thr_str, opname, 'vendor' ); + + % Use eval() to instantiate the variable names constructed above, + % copying each to a simplified name. + data_blissup = eval( var_blissup ); % e.g. data_st_dgemm_blissup( :, : ); + data_blisconv = eval( var_blisconv ); % e.g. data_st_dgemm_blisconv( :, : ); + data_eigen = eval( var_eigen ); % e.g. data_st_dgemm_eigen( :, : ); + data_open = eval( var_open ); % e.g. data_st_dgemm_openblas( :, : ); + data_bfeo = eval( var_bfeo ); % e.g. data_st_dgemm_blasfeo( :, : ); + data_vend = eval( var_vend ); % e.g. data_st_dgemm_vendor( :, : ); + + if stor_str == 'ccc' + % Only read xsmm data for the column storage case, since that's the + % only format that libxsmm supports. + file_xsmm = sprintf( filetemp_xsmm, dirpath, thr_str, opsupname ); + run( file_xsmm ) + var_xsmm = sprintf( vartemp, thr_str, opname, 'libxsmm' ); + data_xsmm = eval( var_xsmm ); % e.g. data_st_dgemm_libxsmm( :, : ); + else + % Set the data variable to zeros using the same dimensions as the other + % variables. + data_xsmm = zeros( size( data_blissup, 1 ), ... + size( data_blissup, 2 ) ); + end + + % Plot one result in an m x n grid of plots, via the subplot() + % function. + if 1 == 1 + plot_l3sup_perf( opsupname, ... + data_blissup, ... + data_blisconv, ... + data_eigen, ... + data_open, ... + data_bfeo, ... + data_xsmm, ... + data_vend, vend_str, ... + nth, ... + 4, 7, ... + cfreq, ... + dflopspercycle, ... + opi, impl ); + + clear data_st_*gemm_*; + clear data_blissup; + clear data_blisconv; + clear data_eigen; + clear data_open; + clear data_bfeo; + clear data_xsmm; + clear data_vend; + + end + +end + +% Construct the name of the file to which we will output the graph. +outfile = sprintf( 'l3sup_%s_%s_%s_nt%d.pdf', oproot, stor_str, arch_str, nth ); + +% Output the graph to pdf format. +%print(gcf, 'gemm_md','-fillpage','-dpdf'); +%print(gcf, outfile,'-bestfit','-dpdf'); +if impl == 'octave' + print(gcf, outfile); +else % if impl == 'matlab' + print(gcf, outfile,'-bestfit','-dpdf'); +end + diff --git a/test/sup/old/octave_st/runthese.m b/test/sup/old/octave_st/runthese.m new file mode 100644 index 0000000000..8e3519f33b --- /dev/null +++ b/test/sup/old/octave_st/runthese.m @@ -0,0 +1,8 @@ +% kabylake +plot_panel_trxsh(3.80,16,1,'st','d','rrr',[ 6 8 4 ],'lds','uaub','../results/kabylake/20200302/mnkt100000_st','kbl','MKL','octave'); close; clear all; + +% haswell +plot_panel_trxsh(3.5,16,1,'st','d','rrr',[ 6 8 4 ],'lds','uaub','../results/haswell/20200302/mnkt100000_st','has','MKL','octave'); close; clear all; + +% epyc +plot_panel_trxsh(3.00, 8,1,'st','d','rrr',[ 6 8 4 ],'lds','uaub','../results/epyc/20200302/mnkt100000_st','epyc','MKL','octave'); close; clear all; diff --git a/test/sup/old/supmt/Makefile b/test/sup/old/supmt/Makefile new file mode 100644 index 0000000000..5004aff77a --- /dev/null +++ b/test/sup/old/supmt/Makefile @@ -0,0 +1,636 @@ +#!/bin/bash +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2019, Advanced Micro Devices, Inc. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +# +# Makefile +# +# Field G. Van Zee +# +# Makefile for standalone BLIS test drivers. +# + +# +# --- Makefile PHONY target definitions ---------------------------------------- +# + +.PHONY: all \ + st mt \ + blissup-st blislpab-st eigen-st openblas-st vendor-st blasfeo-st libxsmm-st \ + blissup-mt blislpab-mt eigen-mt openblas-mt vendor-mt \ + clean cleanx + + +# +# --- Determine makefile fragment location ------------------------------------- +# + +# Comments: +# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. +# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in +# the second case because CONFIG_NAME is not yet set. +ifneq ($(strip $(BLIS_INSTALL_PATH)),) +LIB_PATH := $(BLIS_INSTALL_PATH)/lib +INC_PATH := $(BLIS_INSTALL_PATH)/include/blis +SHARE_PATH := $(BLIS_INSTALL_PATH)/share/blis +else +DIST_PATH := ../.. +LIB_PATH = ../../lib/$(CONFIG_NAME) +INC_PATH = ../../include/$(CONFIG_NAME) +SHARE_PATH := ../.. +endif + + +# +# --- Include common makefile definitions -------------------------------------- +# + +# Include the common makefile fragment. +-include $(SHARE_PATH)/common.mk + + +# +# --- BLAS and LAPACK implementations ------------------------------------------ +# + +# BLIS library and header path. This is simply wherever it was installed. +#BLIS_LIB_PATH := $(INSTALL_PREFIX)/lib +#BLIS_INC_PATH := $(INSTALL_PREFIX)/include/blis + +# BLIS library. +#BLIS_LIB := $(BLIS_LIB_PATH)/libblis.a + +# BLAS library path(s). This is where the BLAS libraries reside. +HOME_LIB_PATH := $(HOME)/flame/lib +MKL_LIB_PATH := $(HOME)/intel/mkl/lib/intel64 + +# netlib BLAS +NETLIB_LIB := $(HOME_LIB_PATH)/libblas.a + +# OpenBLAS +OPENBLAS_LIB := $(HOME_LIB_PATH)/libopenblas.a +OPENBLASP_LIB := $(HOME_LIB_PATH)/libopenblasp.a + +# BLASFEO +BLASFEO_LIB := $(HOME_LIB_PATH)/libblasfeo.a + +# libxsmm +LIBXSMM_LIB := $(HOME_LIB_PATH)/libxsmm.a -ldl \ + $(NETLIB_LIB) -lgfortran + +# ATLAS +ATLAS_LIB := $(HOME_LIB_PATH)/libf77blas.a \ + $(HOME_LIB_PATH)/libatlas.a + +# Eigen +EIGEN_INC := $(HOME)/flame/eigen/include/eigen3 +EIGEN_LIB := $(HOME_LIB_PATH)/libeigen_blas_static.a +EIGENP_LIB := $(EIGEN_LIB) + +# MKL +MKL_LIB := -L$(MKL_LIB_PATH) \ + -lmkl_intel_lp64 \ + -lmkl_core \ + -lmkl_sequential \ + -lpthread -lm -ldl +MKLP_LIB := -L$(MKL_LIB_PATH) \ + -lmkl_intel_lp64 \ + -lmkl_core \ + -lmkl_gnu_thread \ + -lpthread -lm -ldl -fopenmp + #-L$(ICC_LIB_PATH) \ + #-lgomp + +VENDOR_LIB := $(MKL_LIB) +VENDORP_LIB := $(MKLP_LIB) + + +# +# --- Problem size definitions ------------------------------------------------- +# + +# The problem size range specification is done separately for single-threaded +# and multithreaded execution. Within each threadedness scenario, we allow for +# separate range specifications for cases with: +# - 3L: three large/variable dimensions and no small/constant dimensions +# - 2L: two large/variable dimensions and one small/constant dimension +# - 1L: one large/variable dimension and two small/constant dimensions + +# -- Single-threaded -- + +PS_BEGIN_3L := 2 +PS_MAX_3L := 400 +PS_INC_3L := 2 + +PS_BEGIN_2L := 4 +PS_MAX_2L := 800 +PS_INC_2L := 4 + +PS_BEGIN_1L := 32 +PS_MAX_1L := 6400 +PS_INC_1L := 32 + +# -- Multithreaded -- + +P1_BEGIN_3L := 4 +P1_MAX_3L := 800 +P1_INC_3L := 4 + +P1_BEGIN_2L := 8 +P1_MAX_2L := 1600 +P1_INC_2L := 8 + +P1_BEGIN_1L := 64 +P1_MAX_1L := 12800 +P1_INC_1L := 64 + + +# +# --- General build definitions ------------------------------------------------ +# + +TEST_SRC_PATH := . +TEST_OBJ_PATH := . + +# Gather all local object files. +TEST_OBJS := $(sort $(patsubst $(TEST_SRC_PATH)/%.c, \ + $(TEST_OBJ_PATH)/%.o, \ + $(wildcard $(TEST_SRC_PATH)/*.c))) + +# Override the value of CINCFLAGS so that the value of CFLAGS returned by +# get-frame-cflags-for() is not cluttered up with include paths needed only +# while building BLIS. +CINCFLAGS := -I$(INC_PATH) + +# Use the "framework" CFLAGS for the configuration family. +CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) + +# Add local header paths to CFLAGS. +CFLAGS += -I$(TEST_SRC_PATH) + +# Locate the libblis library to which we will link. +LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) + +# Define a set of CFLAGS for use with C++ and Eigen. +CXXFLAGS := $(subst -std=c99,-std=c++11,$(CFLAGS)) +CXXFLAGS += -I$(EIGEN_INC) + +# Create a copy of CXXFLAGS without -fopenmp in order to disable multithreading. +CXXFLAGS_ST := -march=native $(subst -fopenmp,,$(CXXFLAGS)) +CXXFLAGS_MT := -march=native $(CXXFLAGS) + +# Single or multithreaded string. +STR_ST := -DTHR_STR=\"st\" +STR_MT := -DTHR_STR=\"mt\" + +# Number of trials per problem size. +N_TRIALS := -DN_TRIALS=3 + +# Problem size specification. +PDEF_ST_1L := -DP_BEGIN=$(PS_BEGIN_1L) -DP_MAX=$(PS_MAX_1L) -DP_INC=$(PS_INC_1L) +PDEF_ST_2L := -DP_BEGIN=$(PS_BEGIN_2L) -DP_MAX=$(PS_MAX_2L) -DP_INC=$(PS_INC_2L) +PDEF_ST_3L := -DP_BEGIN=$(PS_BEGIN_3L) -DP_MAX=$(PS_MAX_3L) -DP_INC=$(PS_INC_3L) + +PDEF_MT_1L := -DP_BEGIN=$(P1_BEGIN_1L) -DP_MAX=$(P1_MAX_1L) -DP_INC=$(P1_INC_1L) +PDEF_MT_2L := -DP_BEGIN=$(P1_BEGIN_2L) -DP_MAX=$(P1_MAX_2L) -DP_INC=$(P1_INC_2L) +PDEF_MT_3L := -DP_BEGIN=$(P1_BEGIN_3L) -DP_MAX=$(P1_MAX_3L) -DP_INC=$(P1_INC_3L) + +ifeq ($(E),1) +ERRCHK := -DERROR_CHECK +else +ERRCHK := -DNO_ERROR_CHECK +endif + +# Enumerate possible datatypes and computation precisions. +#dts := s d c z +DTS := d + +TRANS := n_n \ + n_t \ + t_n \ + t_t + +# While BLIS supports all combinations of row and column storage for matrices +# C, A, and B, the alternatives mostly only support CBLAS APIs, which inherently +# support only "all row-storage" or "all column-storage". Thus, we disable the +# building of those other drivers so that compilation/linking completes sooner. +#STORS := r_r_r \ +# r_r_c \ +# r_c_r \ +# r_c_c \ +# c_r_r \ +# c_r_c \ +# c_c_r \ +# c_c_c +STORS := r_r_r \ + c_c_c + + +SHAPES := l_l_s \ + l_s_l \ + s_l_l \ + s_s_l \ + s_l_s \ + l_s_s \ + l_l_l + +# Define the small/constant m, n, and k dimensions for single core and multicore +# experiments. +SMS_ST := 6 +SNS_ST := 8 +SKS_ST := 4 + +SMS_MT := 6 +SNS_MT := 8 +SKS_MT := 10 + + +# +# --- Function definitions ----------------------------------------------------- +# + +# A function to strip the underscores from a list of strings. +stripu = $(subst _,,$(1)) + +# Various functions that help us construct the datatype combinations and then +# extract the needed datatype strings and C preprocessor define flags. +get-1of2 = $(word 1,$(subst _, ,$(1))) +get-2of2 = $(word 2,$(subst _, ,$(1))) + +get-1of3 = $(word 1,$(subst _, ,$(1))) +get-2of3 = $(word 2,$(subst _, ,$(1))) +get-3of3 = $(word 3,$(subst _, ,$(1))) + +# A function to return the correct PDEFS_ST variable given the shape string. +get-pdefs = $(strip $(subst l_l_l,$(PDEF_MT_3L), \ + $(subst l_l_s,$(PDEF_MT_2L), \ + $(subst l_s_l,$(PDEF_MT_2L), \ + $(subst s_l_l,$(PDEF_MT_2L), \ + $(subst s_s_l,$(PDEF_MT_1L), \ + $(subst s_l_s,$(PDEF_MT_1L), \ + $(subst l_s_s,$(PDEF_MT_1L),$(1))))))))) + +# Datatype defs. +get-dt-cpp = $(strip \ + $(if $(findstring s,$(1)),-DDT=BLIS_FLOAT -DIS_FLOAT,\ + $(if $(findstring d,$(1)),-DDT=BLIS_DOUBLE -DIS_DOUBLE,\ + $(if $(findstring c,$(1)),-DDT=BLIS_SCOMPLEX -DIS_SCOMPLEX,\ + -DDT=BLIS_DCOMPLEX -DIS_DCOMPLEX)))) + +# Transpose defs. +get-tra-defs-a = $(strip $(subst n,-DTRANSA=BLIS_NO_TRANSPOSE -DA_NOTRANS, \ + $(subst t,-DTRANSA=BLIS_TRANSPOSE -DA_TRANS,$(call get-1of2,$(1))))) +get-tra-defs-b = $(strip $(subst n,-DTRANSB=BLIS_NO_TRANSPOSE -DB_NOTRANS, \ + $(subst t,-DTRANSB=BLIS_TRANSPOSE -DB_TRANS,$(call get-2of2,$(1))))) +get-tra-defs = $(call get-tra-defs-a,$(1)) $(call get-tra-defs-b,$(1)) + +# Storage defs. +get-sto-uch-a = $(strip $(subst r,R, \ + $(subst c,C,$(call get-1of3,$(1))))) +get-sto-uch-b = $(strip $(subst r,R, \ + $(subst c,C,$(call get-2of3,$(1))))) +get-sto-uch-c = $(strip $(subst r,R, \ + $(subst c,C,$(call get-3of3,$(1))))) +get-sto-defs = $(strip \ + -DSTOR3=BLIS_$(call get-sto-uch-a,$(1))$(call get-sto-uch-b,$(1))$(call get-sto-uch-c,$(1)) \ + -DA_STOR_$(call get-sto-uch-a,$(1)) \ + -DB_STOR_$(call get-sto-uch-b,$(1)) \ + -DC_STOR_$(call get-sto-uch-c,$(1))) + +# Dimension defs. +get-shape-defs-cm = $(if $(findstring l,$(1)),-DM_DIM=-1,-DM_DIM=$(2)) +get-shape-defs-cn = $(if $(findstring l,$(1)),-DN_DIM=-1,-DN_DIM=$(2)) +get-shape-defs-ck = $(if $(findstring l,$(1)),-DK_DIM=-1,-DK_DIM=$(2)) +get-shape-defs-m = $(call get-shape-defs-cm,$(call get-1of3,$(1)),$(2)) +get-shape-defs-n = $(call get-shape-defs-cn,$(call get-2of3,$(1)),$(2)) +get-shape-defs-k = $(call get-shape-defs-ck,$(call get-3of3,$(1)),$(2)) + +# arguments: 1: shape (w/ underscores) 2: smallm 3: smalln 4: smallk +get-shape-defs = $(strip $(call get-shape-defs-m,$(1),$(2)) \ + $(call get-shape-defs-n,$(1),$(3)) \ + $(call get-shape-defs-k,$(1),$(4))) + +#$(error l_l_s 6 8 4 = $(call get-shape-defs,l_l_s,6,8,4)) + +# Shape-dimension string. +get-shape-str-ch = $(if $(findstring l,$(1)),p,$(2)) +get-shape-str-m = $(call get-shape-str-ch,$(call get-1of3,$(1)),$(2)) +get-shape-str-n = $(call get-shape-str-ch,$(call get-2of3,$(1)),$(2)) +get-shape-str-k = $(call get-shape-str-ch,$(call get-3of3,$(1)),$(2)) + +# arguments: 1: shape (w/ underscores) 2: smallm 3: smalln 4: smallk +get-shape-dim-str = m$(call get-shape-str-m,$(1),$(2))n$(call get-shape-str-n,$(1),$(3))k$(call get-shape-str-k,$(1),$(4)) + +# Implementation defs. +# Define a function to return the appropriate -DSTR= and -D[BLIS|BLAS] flags. +get-imp-defs = $(strip $(subst blissup,-DSTR=\"$(1)\" -DBLIS -DSUP, \ + $(subst blislpab,-DSTR=\"$(1)\" -DBLIS, \ + $(subst eigen,-DSTR=\"$(1)\" -DEIGEN, \ + $(subst openblas,-DSTR=\"$(1)\" -DCBLAS, \ + $(subst blasfeo,-DSTR=\"$(1)\" -DCBLAS, \ + $(subst libxsmm,-DSTR=\"$(1)\" -DBLAS -DXSMM, \ + $(subst vendor,-DSTR=\"$(1)\" -DCBLAS,$(1))))))))) + +TRANS0 = $(call stripu,$(TRANS)) +STORS0 = $(call stripu,$(STORS)) + +# Limit BLAS and Eigen to only using all row-stored, or all column-stored matrices. +# Also, limit libxsmm to using all column-stored matrices since it does not offer +# CBLAS interfaces. +BSTORS0 = rrr ccc +ESTORS0 = rrr ccc +XSTORS0 = ccc + + +# +# --- Object and binary file definitons ---------------------------------------- +# + +# -- Single-threaded -- + +get-st-objs = $(foreach dt,$(1),$(foreach tr,$(2),$(foreach st,$(3),$(foreach sh,$(4),$(foreach sm,$(5),$(foreach sn,$(6),$(foreach sk,$(7),test_$(dt)gemm_$(tr)_$(st)_$(call get-shape-dim-str,$(sh),$(sm),$(sn),$(sk))_$(8)_st.o))))))) + +# Build a list of object files and binaries for each single-threaded +# implementation using the get-st-objs() function defined above. +BLISSUP_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(STORS0),$(SHAPES),$(SMS_ST),$(SNS_ST),$(SKS_ST),blissup) +BLISSUP_ST_BINS := $(patsubst %.o,%.x,$(BLISSUP_ST_OBJS)) + +BLISLPAB_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(STORS0),$(SHAPES),$(SMS_ST),$(SNS_ST),$(SKS_ST),blislpab) +BLISLPAB_ST_BINS := $(patsubst %.o,%.x,$(BLISLPAB_ST_OBJS)) + +EIGEN_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(ESTORS0),$(SHAPES),$(SMS_ST),$(SNS_ST),$(SKS_ST),eigen) +EIGEN_ST_BINS := $(patsubst %.o,%.x,$(EIGEN_ST_OBJS)) + +OPENBLAS_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(SMS_ST),$(SNS_ST),$(SKS_ST),openblas) +OPENBLAS_ST_BINS := $(patsubst %.o,%.x,$(OPENBLAS_ST_OBJS)) + +BLASFEO_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(SMS_ST),$(SNS_ST),$(SKS_ST),blasfeo) +BLASFEO_ST_BINS := $(patsubst %.o,%.x,$(BLASFEO_ST_OBJS)) + +LIBXSMM_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(XSTORS0),$(SHAPES),$(SMS_ST),$(SNS_ST),$(SKS_ST),libxsmm) +LIBXSMM_ST_BINS := $(patsubst %.o,%.x,$(LIBXSMM_ST_OBJS)) + +VENDOR_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(SMS_ST),$(SNS_ST),$(SKS_ST),vendor) +VENDOR_ST_BINS := $(patsubst %.o,%.x,$(VENDOR_ST_OBJS)) + +# Mark the object files as intermediate so that make will remove them +# automatically after building the binaries on which they depend. +.INTERMEDIATE: $(BLISSUP_ST_OBJS) \ + $(BLISLPAB_ST_OBJS) \ + $(EIGEN_ST_OBJS) \ + $(OPENBLAS_ST_OBJS) \ + $(BLASFEO_ST_OBJS) \ + $(LIBXSMM_ST_OBJS) \ + $(VENDOR_ST_OBJS) + +# -- Multithreaded -- + +get-mt-objs = $(foreach dt,$(1),$(foreach tr,$(2),$(foreach st,$(3),$(foreach sh,$(4),$(foreach sm,$(5),$(foreach sn,$(6),$(foreach sk,$(7),test_$(dt)gemm_$(tr)_$(st)_$(call get-shape-dim-str,$(sh),$(sm),$(sn),$(sk))_$(8)_mt.o))))))) + +# Build a list of object files and binaries for each multithreaded +# implementation using the get-st-objs() function defined above. +BLISSUP_MT_OBJS := $(call get-mt-objs,$(DTS),$(TRANS0),$(STORS0),$(SHAPES),$(SMS_MT),$(SNS_MT),$(SKS_MT),blissup) +BLISSUP_MT_BINS := $(patsubst %.o,%.x,$(BLISSUP_MT_OBJS)) + +BLISLPAB_MT_OBJS := $(call get-mt-objs,$(DTS),$(TRANS0),$(STORS0),$(SHAPES),$(SMS_MT),$(SNS_MT),$(SKS_MT),blislpab) +BLISLPAB_MT_BINS := $(patsubst %.o,%.x,$(BLISLPAB_MT_OBJS)) + +EIGEN_MT_OBJS := $(call get-mt-objs,$(DTS),$(TRANS0),$(ESTORS0),$(SHAPES),$(SMS_MT),$(SNS_MT),$(SKS_MT),eigen) +EIGEN_MT_BINS := $(patsubst %.o,%.x,$(EIGEN_MT_OBJS)) + +OPENBLAS_MT_OBJS := $(call get-mt-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(SMS_MT),$(SNS_MT),$(SKS_MT),openblas) +OPENBLAS_MT_BINS := $(patsubst %.o,%.x,$(OPENBLAS_MT_OBJS)) + +VENDOR_MT_OBJS := $(call get-mt-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(SMS_MT),$(SNS_MT),$(SKS_MT),vendor) +VENDOR_MT_BINS := $(patsubst %.o,%.x,$(VENDOR_MT_OBJS)) + +#$(error "objs = $(EIGEN_ST_BINS)" ) + +# Mark the object files as intermediate so that make will remove them +# automatically after building the binaries on which they depend. +.INTERMEDIATE: $(BLISSUP_MT_OBJS) \ + $(BLISLPAB_MT_OBJS) \ + $(EIGEN_MT_OBJS) \ + $(OPENBLAS_MT_OBJS) \ + $(VENDOR_MT_OBJS) + + +# +# --- High-level targets/rules ------------------------------------------------- +# + +all: st + +#blis: blissup-st blislpab-st +#blissup: blissup-st +#blislpab: blislpab-st +#eigen: eigen-st +#openblas: openblas-st +#blasfeo: blasfeo-st +#libxsmm: libxsmm-st +#vendor: vendor-st + +# -- Single-threaded -- + +st: blissup-st blislpab-st \ + eigen-st openblas-st blasfeo-st libxsmm-st vendor-st + +blissup-st: $(BLISSUP_ST_BINS) +blislpab-st: $(BLISLPAB_ST_BINS) +eigen-st: $(EIGEN_ST_BINS) +openblas-st: $(OPENBLAS_ST_BINS) +blasfeo-st: $(BLASFEO_ST_BINS) +libxsmm-st: $(LIBXSMM_ST_BINS) +vendor-st: $(VENDOR_ST_BINS) + +# -- Multithreaded -- + +mt: blissup-mt blislpab-mt \ + eigen-mt openblas-mt vendor-mt + +blissup-mt: $(BLISSUP_MT_BINS) +blislpab-mt: $(BLISLPAB_MT_BINS) +eigen-mt: $(EIGEN_MT_BINS) +openblas-mt: $(OPENBLAS_MT_BINS) +vendor-mt: $(VENDOR_MT_BINS) + + +# --- Object file rules -------------------------------------------------------- + +# Define the implementations for which we will instantiate compilation rules. +BIMPLS_ST := blissup blislpab openblas blasfeo libxsmm vendor +BIMPLS_MT := blissup blislpab openblas vendor +EIMPLS := eigen + +# -- Single-threaded BLAS -- + +# 1 2 3 4 567 8 +# test_dgemm_nn_rrr_mpn6kp_blissup_st.x + +# Define the function that will be used to instantiate compilation rules +# for the various single-threaded implementations. +define make-st-rule +test_$(1)gemm_$(call stripu,$(2))_$(call stripu,$(3))_$(call get-shape-dim-str,$(4),$(5),$(6),$(7))_$(8)_st.o: test_gemm.c Makefile + $(CC) $(CFLAGS) $(ERRCHK) $(N_TRIALS) $(call get-pdefs,$(4)) $(call get-dt-cpp,$(1)) $(call get-tra-defs,$(2)) $(call get-sto-defs,$(3)) $(call get-shape-defs,$(4),$(5),$(6),$(7)) $(call get-imp-defs,$(8)) $(STR_ST) -c $$< -o $$@ +endef + +# Instantiate the rule function make-st-rule() for each BLIS/BLAS/CBLAS +# implementation. +$(foreach dt,$(DTS), \ +$(foreach tr,$(TRANS), \ +$(foreach st,$(STORS), \ +$(foreach sh,$(SHAPES), \ +$(foreach sm,$(SMS_ST), \ +$(foreach sn,$(SNS_ST), \ +$(foreach sk,$(SKS_ST), \ +$(foreach impl,$(BIMPLS_ST), \ +$(eval $(call make-st-rule,$(dt),$(tr),$(st),$(sh),$(sm),$(sn),$(sk),$(impl))))))))))) + +# -- Multithreaded BLAS -- + +# Define the function that will be used to instantiate compilation rules +# for the various multithreaded implementations. +define make-mt-rule +test_$(1)gemm_$(call stripu,$(2))_$(call stripu,$(3))_$(call get-shape-dim-str,$(4),$(5),$(6),$(7))_$(8)_mt.o: test_gemm.c Makefile + $(CC) $(CFLAGS) $(ERRCHK) $(N_TRIALS) $(call get-pdefs,$(4)) $(call get-dt-cpp,$(1)) $(call get-tra-defs,$(2)) $(call get-sto-defs,$(3)) $(call get-shape-defs,$(4),$(5),$(6),$(7)) $(call get-imp-defs,$(8)) $(STR_MT) -c $$< -o $$@ +endef + +# Instantiate the rule function make-mt-rule() for each BLIS/BLAS/CBLAS +# implementation. +$(foreach dt,$(DTS), \ +$(foreach tr,$(TRANS), \ +$(foreach st,$(STORS), \ +$(foreach sh,$(SHAPES), \ +$(foreach sm,$(SMS_MT), \ +$(foreach sn,$(SNS_MT), \ +$(foreach sk,$(SKS_MT), \ +$(foreach impl,$(BIMPLS_MT), \ +$(eval $(call make-mt-rule,$(dt),$(tr),$(st),$(sh),$(sm),$(sn),$(sk),$(impl))))))))))) + +# -- Single-threaded Eigen -- + +# Define the function that will be used to instantiate compilation rules +# for the single-threaded Eigen implementation. +define make-eigst-rule +test_$(1)gemm_$(call stripu,$(2))_$(call stripu,$(3))_$(call get-shape-dim-str,$(4),$(5),$(6),$(7))_$(8)_st.o: test_gemm.c Makefile + $(CXX) $(CXXFLAGS_ST) $(ERRCHK) $(N_TRIALS) $(call get-pdefs,$(4)) $(call get-dt-cpp,$(1)) $(call get-tra-defs,$(2)) $(call get-sto-defs,$(3)) $(call get-shape-defs,$(4),$(5),$(6),$(7)) $(call get-imp-defs,$(8)) $(STR_ST) -c $$< -o $$@ +endef + +# Instantiate the rule function make-st-rule() for each Eigen implementation. +$(foreach dt,$(DTS), \ +$(foreach tr,$(TRANS), \ +$(foreach st,$(STORS), \ +$(foreach sh,$(SHAPES), \ +$(foreach sm,$(SMS_ST), \ +$(foreach sn,$(SNS_ST), \ +$(foreach sk,$(SKS_ST), \ +$(foreach impl,$(EIMPLS), \ +$(eval $(call make-eigst-rule,$(dt),$(tr),$(st),$(sh),$(sm),$(sn),$(sk),$(impl))))))))))) + +# -- Multithreaded Eigen -- + +# Define the function that will be used to instantiate compilation rules +# for the multithreaded Eigen implementation. +define make-eigmt-rule +test_$(1)gemm_$(call stripu,$(2))_$(call stripu,$(3))_$(call get-shape-dim-str,$(4),$(5),$(6),$(7))_$(8)_mt.o: test_gemm.c Makefile + $(CXX) $(CXXFLAGS_MT) $(ERRCHK) $(N_TRIALS) $(call get-pdefs,$(4)) $(call get-dt-cpp,$(1)) $(call get-tra-defs,$(2)) $(call get-sto-defs,$(3)) $(call get-shape-defs,$(4),$(5),$(6),$(7)) $(call get-imp-defs,$(8)) $(STR_MT) -c $$< -o $$@ +endef + +# Instantiate the rule function make-st-rule() for each Eigen implementation. +$(foreach dt,$(DTS), \ +$(foreach tr,$(TRANS), \ +$(foreach st,$(STORS), \ +$(foreach sh,$(SHAPES), \ +$(foreach sm,$(SMS_MT), \ +$(foreach sn,$(SNS_MT), \ +$(foreach sk,$(SKS_MT), \ +$(foreach impl,$(EIMPLS), \ +$(eval $(call make-eigmt-rule,$(dt),$(tr),$(st),$(sh),$(sm),$(sn),$(sk),$(impl))))))))))) + + +# --- Executable file rules ---------------------------------------------------- + +# NOTE: For the BLAS test drivers, we place the BLAS libraries before BLIS +# on the link command line in case BLIS was configured with the BLAS +# compatibility layer. This prevents BLIS from inadvertently getting called +# for the BLAS routines we are trying to test with. + +# -- Single-threaded -- + +test_%_blissup_st.x: test_%_blissup_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_blislpab_st.x: test_%_blislpab_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_eigen_st.x: test_%_eigen_st.o $(LIBBLIS_LINK) + $(CXX) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_openblas_st.x: test_%_openblas_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(OPENBLAS_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_blasfeo_st.x: test_%_blasfeo_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(BLASFEO_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_libxsmm_st.x: test_%_libxsmm_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBXSMM_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_vendor_st.x: test_%_vendor_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(VENDOR_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +# -- Multithreaded -- + +test_%_blissup_mt.x: test_%_blissup_mt.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_blislpab_mt.x: test_%_blislpab_mt.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_eigen_mt.x: test_%_eigen_mt.o $(LIBBLIS_LINK) + $(CXX) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_openblas_mt.x: test_%_openblas_mt.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(OPENBLASP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_vendor_mt.x: test_%_vendor_mt.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(VENDORP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + + +# -- Clean rules -- + +clean: cleanx + +cleanx: + - $(RM_F) *.x *.o + diff --git a/test/sup/old/supmt/octave/gen_opsupnames.m b/test/sup/old/supmt/octave/gen_opsupnames.m new file mode 100644 index 0000000000..a87c06cc27 --- /dev/null +++ b/test/sup/old/supmt/octave/gen_opsupnames.m @@ -0,0 +1,55 @@ +function [ r_val1, r_val2 ] = gen_opsupnames( ops, stor, smalldims ) + +nops = size( ops, 1 ); + +smallm = smalldims( 1 ); +smalln = smalldims( 2 ); +smallk = smalldims( 3 ); + +i = 1; + +for io = 1:nops + + op = ops( io, : ); + + % NOTE: This way of sprintf'ing doesn't work when the string lengths + % vary, as they would if any of the constant dimensions is greater + % than 9. + %opsupnames( i+0, : ) = sprintf( '%s_%s_m%dnpkp ', op, stor, smallm ) + %opsupnames( i+1, : ) = sprintf( '%s_%s_mpn%dkp ', op, stor, smalln ) + %opsupnames( i+2, : ) = sprintf( '%s_%s_mpnpk%d', op, stor, smallk ) + %opsupnames( i+3, : ) = sprintf( '%s_%s_mpn%dk%d', op, stor, smalln, smallk ) + %opsupnames( i+4, : ) = sprintf( '%s_%s_m%dnpk%d', op, stor, smallm, smallk ) + %opsupnames( i+5, : ) = sprintf( '%s_%s_m%dn%dkp ', op, stor, smallm, smalln ) + %opsupnames( i+6, : ) = sprintf( '%s_%s_mpnpkp ', op, stor ) + + str0 = sprintf( '%s_%s_m%dnpkp', op, stor, smallm ); + str1 = sprintf( '%s_%s_mpn%dkp', op, stor, smalln ); + str2 = sprintf( '%s_%s_mpnpk%d', op, stor, smallk ); + str3 = sprintf( '%s_%s_mpn%dk%d', op, stor, smalln, smallk ); + str4 = sprintf( '%s_%s_m%dnpk%d', op, stor, smallm, smallk ); + str5 = sprintf( '%s_%s_m%dn%dkp', op, stor, smallm, smalln ); + str6 = sprintf( '%s_%s_mpnpkp', op, stor ); + + opsupnames( i+0, : ) = sprintf( '%-22s', str0 ); + opsupnames( i+1, : ) = sprintf( '%-22s', str1 ); + opsupnames( i+2, : ) = sprintf( '%-22s', str2 ); + opsupnames( i+3, : ) = sprintf( '%-22s', str3 ); + opsupnames( i+4, : ) = sprintf( '%-22s', str4 ); + opsupnames( i+5, : ) = sprintf( '%-22s', str5 ); + opsupnames( i+6, : ) = sprintf( '%-22s', str6 ); + + opnames( i+0, : ) = sprintf( '%s', op ); + opnames( i+1, : ) = sprintf( '%s', op ); + opnames( i+2, : ) = sprintf( '%s', op ); + opnames( i+3, : ) = sprintf( '%s', op ); + opnames( i+4, : ) = sprintf( '%s', op ); + opnames( i+5, : ) = sprintf( '%s', op ); + opnames( i+6, : ) = sprintf( '%s', op ); + + i = i + 7; +end + +r_val1 = opsupnames; +r_val2 = opnames; + diff --git a/test/sup/old/supmt/octave/plot_l3sup_perf.m b/test/sup/old/supmt/octave/plot_l3sup_perf.m new file mode 100644 index 0000000000..d9ecf593f1 --- /dev/null +++ b/test/sup/old/supmt/octave/plot_l3sup_perf.m @@ -0,0 +1,258 @@ +function r_val = plot_l3sup_perf( opname, ... + data_blissup, ... + data_blislpab, ... + data_eigen, ... + data_open, ... + data_vend, vend_str, ... + nth, ... + rows, cols, ... + cfreq, ... + dfps, ... + theid, impl ) + +%if ... %mod(theid-1,cols) == 2 || ... +% ... %mod(theid-1,cols) == 3 || ... +% ... %mod(theid-1,cols) == 4 || ... +% 0 == 1 ... %theid >= 19 +% show_plot = 0; +%else + show_plot = 1; +%end + +%legend_plot_id = 11; +legend_plot_id = 0*cols + 1*4; + +if 1 + ax1 = subplot( rows, cols, theid ); + hold( ax1, 'on' ); +end + +% Set line properties. +color_blissup = 'k'; lines_blissup = '-'; markr_blissup = ''; +color_blislpab = 'k'; lines_blislpab = ':'; markr_blislpab = ''; +color_eigen = 'm'; lines_eigen = '-.'; markr_eigen = 'o'; +color_open = 'r'; lines_open = '--'; markr_open = 'o'; +color_vend = 'b'; lines_vend = '-.'; markr_vend = '.'; + +% Compute the peak performance in terms of the number of double flops +% executable per cycle and the clock rate. +if opname(1) == 's' || opname(1) == 'c' + flopspercycle = dfps * 2; +else + flopspercycle = dfps; +end +max_perf_core = (flopspercycle * cfreq) * 1; + +% Escape underscores in the title. +title_opname = strrep( opname, '_', '\_' ); + +% Print the title to a string. +titlename = '%s'; +titlename = sprintf( titlename, title_opname ); + +% Set the legend strings. +blissup_legend = sprintf( 'BLIS sup' ); +blislpab_legend = sprintf( 'BLIS conv' ); +eigen_legend = sprintf( 'Eigen' ); +open_legend = sprintf( 'OpenBLAS' ); +%vend_legend = sprintf( 'MKL' ); +%vend_legend = sprintf( 'ARMPL' ); +vend_legend = vend_str; + +% Set axes range values. +y_scale = 1.00; +x_begin = 0; +%x_end is set below. +y_begin = 0; +y_end = max_perf_core * y_scale; + +% Set axes names. +if nth == 1 + yaxisname = 'GFLOPS'; +else + yaxisname = 'GFLOPS/core'; +end + + +%flopscol = 4; +flopscol = size( data_blissup, 2 ); +msize = 5; +if 1 + fontsize = 12; +else + fontsize = 16; +end +linesize = 0.5; +legend_loc = 'southeast'; + +% -------------------------------------------------------------------- + +% Automatically detect a column with the increasing problem size. +% Then set the maximum x-axis value. +for psize_col = 1:3 + if data_blissup( 1, psize_col ) ~= data_blissup( 2, psize_col ) + break; + end +end +x_axis( :, 1 ) = data_blissup( :, psize_col ); + +% Compute the number of data points we have in the x-axis. Note that +% we only use quarter the data points for the m = n = k column of graphs. +if mod(theid-1,cols) == 6 + np = size( data_blissup, 1 ) / 4; +else + np = size( data_blissup, 1 ); +end + +% Grab the last x-axis value. +x_end = data_blissup( np, psize_col ); + +%data_peak( 1, 1:2 ) = [ 0 max_perf_core ]; +%data_peak( 2, 1:2 ) = [ x_end max_perf_core ]; + +if show_plot == 1 +blissup_ln = line( x_axis( 1:np, 1 ), data_blissup( 1:np, flopscol ) / nth, ... + 'Color',color_blissup, 'LineStyle',lines_blissup, ... + 'LineWidth',linesize ); +blislpab_ln = line( x_axis( 1:np, 1 ), data_blislpab( 1:np, flopscol ) / nth, ... + 'Color',color_blislpab, 'LineStyle',lines_blislpab, ... + 'LineWidth',linesize ); +eigen_ln = line( x_axis( 1:np, 1 ), data_eigen( 1:np, flopscol ) / nth, ... + 'Color',color_eigen, 'LineStyle',lines_eigen, ... + 'LineWidth',linesize ); +open_ln = line( x_axis( 1:np, 1 ), data_open( 1:np, flopscol ) / nth, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +vend_ln = line( x_axis( 1:np, 1 ), data_vend( 1:np, flopscol ) / nth, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... + 'LineWidth',linesize ); +elseif theid == legend_plot_id +blissup_ln = line( nan, nan, ... + 'Color',color_blissup, 'LineStyle',lines_blissup, ... + 'LineWidth',linesize ); +blislpab_ln = line( nan, nan, ... + 'Color',color_blislpab, 'LineStyle',lines_blislpab, ... + 'LineWidth',linesize ); +eigen_ln = line( nan, nan, ... + 'Color',color_eigen, 'LineStyle',lines_eigen, ... + 'LineWidth',linesize ); +open_ln = line( nan, nan, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +vend_ln = line( nan, nan, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... + 'LineWidth',linesize ); +end + + +xlim( ax1, [x_begin x_end] ); +ylim( ax1, [y_begin y_end] ); + +if 6000 <= x_end && x_end < 10000 + x_tick2 = x_end - 2000; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 4000 <= x_end && x_end < 6000 + x_tick2 = x_end - 1000; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 2000 <= x_end && x_end < 3000 + x_tick2 = x_end - 400; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 500 <= x_end && x_end < 1000 + x_tick3 = x_end*(3/4); + x_tick2 = x_end*(2/4); + x_tick1 = x_end*(1/4); + xticks( ax1, [ x_tick1 x_tick2 x_tick3 ] ); +end + +if show_plot == 1 || theid == legend_plot_id + if theid == legend_plot_id + leg = legend( ... + [ ... + blissup_ln ... + blislpab_ln ... + eigen_ln ... + open_ln ... + vend_ln ... + ], ... + blissup_legend, ... + blislpab_legend, ... + eigen_legend, ... + open_legend, ... + vend_legend, ... + 'Location', legend_loc ); + set( leg,'Box','off' ); + set( leg,'Color','none' ); + set( leg,'Units','inches' ); + if impl == 'octave' + set( leg,'FontSize',fontsize ); + set( leg,'Position',[12.40 10.60 1.9 0.95 ] ); % (1,4tl) + else + set( leg,'FontSize',fontsize-1 ); + set( leg,'Position',[18.24 10.15 1.15 0.7 ] ); % (1,4tl) + end + set( leg,'Box','off' ); + set( leg,'Color','none' ); + set( leg,'Units','inches' ); + % xpos ypos + %set( leg,'Position',[11.32 6.36 1.15 0.7 ] ); % (1,4tl) + end +end + +set( ax1,'FontSize',fontsize ); +set( ax1,'TitleFontSizeMultiplier',1.0 ); % default is 1.1. +box( ax1, 'on' ); + +titl = title( titlename ); +set( titl, 'FontWeight', 'normal' ); % default font style is now 'bold'. + +% The default is to align the plot title across whole figure, not the box. +% This is a hack to nudge the title back to the center of the box. +if impl == 'octave' + tpos = get( titl, 'Position' ); + % For some reason, the titles in the graphs in the last column start + % off in a different relative position than the graphs in the other + % columns. Here, we manually account for that. + if mod(theid-1,cols) == 6 + tpos(1) = tpos(1) + -10; + else + tpos(1) = tpos(1) + -40; + end + set( titl, 'Position', tpos ); + set( titl, 'FontSize', fontsize ); +else % impl == 'matlab' + tpos = get( titl, 'Position' ); + tpos(1) = tpos(1) + 90; + set( titl, 'Position', tpos ); +end + +if theid > (rows-1)*cols + %xlab = xlabel( ax1,xaxisname ); + %tpos = get( xlab, 'Position' ) + %tpos(2) = tpos(2) + 10; + %set( xlab, 'Position', tpos ); + if theid == rows*cols - 6 + xlab = xlabel( ax1, 'm = 6; n = k' ); + elseif theid == rows*cols - 5 + xlab = xlabel( ax1, 'n = 8; m = k' ); + elseif theid == rows*cols - 4 + xlab = xlabel( ax1, 'k = 10; m = n' ); + elseif theid == rows*cols - 3 + xlab = xlabel( ax1, 'm; n = 8, k = 10' ); + elseif theid == rows*cols - 2 + xlab = xlabel( ax1, 'n; m = 6, k = 10' ); + elseif theid == rows*cols - 1 + xlab = xlabel( ax1, 'k; m = 6, n = 8' ); + elseif theid == rows*cols - 0 + xlab = xlabel( ax1, 'm = n = k' ); + end +end + +if mod(theid-1,cols) == 0 + ylab = ylabel( ax1,yaxisname ); +end + +r_val = 0; + diff --git a/test/sup/old/supmt/octave/plot_panel_trxsh.m b/test/sup/old/supmt/octave/plot_panel_trxsh.m new file mode 100644 index 0000000000..b9fac8ff92 --- /dev/null +++ b/test/sup/old/supmt/octave/plot_panel_trxsh.m @@ -0,0 +1,152 @@ +function r_val = plot_panel_trxsh ... + ( ... + cfreq, ... + dflopspercycle, ... + nth, ... + thr_str, ... + dt_ch, ... + stor_str, ... + smalldims, ... + dirpath, ... + arch_str, ... + vend_str, ... + impl ... + ) + +%cfreq = 1.8; +%dflopspercycle = 32; + +% Create filename "templates" for the files that contain the performance +% results. +filetemp_blissup = '%s/output_%s_%s_blissup.m'; +filetemp_blislpab = '%s/output_%s_%s_blislpab.m'; +filetemp_eigen = '%s/output_%s_%s_eigen.m'; +filetemp_open = '%s/output_%s_%s_openblas.m'; +filetemp_vend = '%s/output_%s_%s_vendor.m'; + +% Create a variable name "template" for the variables contained in the +% files outlined above. +vartemp = 'data_%s_%s_%s( :, : )'; + +% Define the datatypes and operations we will be plotting. +oproot = sprintf( '%cgemm', dt_ch ); +ops( 1, : ) = sprintf( '%s_nn', oproot ); +ops( 2, : ) = sprintf( '%s_nt', oproot ); +ops( 3, : ) = sprintf( '%s_tn', oproot ); +ops( 4, : ) = sprintf( '%s_tt', oproot ); + +% Generate datatype-specific operation names from the set of operations +% and datatypes. +[ opsupnames, opnames ] = gen_opsupnames( ops, stor_str, smalldims ); +n_opsupnames = size( opsupnames, 1 ); + +%opsupnames +%opnames +%return + +if 1 == 1 + %fig = figure('Position', [100, 100, 2400, 1500]); + fig = figure('Position', [100, 100, 2400, 1200]); + orient( fig, 'portrait' ); + set(gcf,'PaperUnits', 'inches'); + if impl == 'matlab' + set(gcf,'PaperSize', [11.5 20.4]); + set(gcf,'PaperPosition', [0 0 11.5 20.4]); + set(gcf,'PaperPositionMode','manual'); + else % impl == 'octave' % octave 4.x + set(gcf,'PaperSize', [12 22.0]); + set(gcf,'PaperPositionMode','auto'); + end + set(gcf,'PaperOrientation','landscape'); +end + + +% Iterate over the list of datatype-specific operation names. +for opi = 1:n_opsupnames +%for opi = 1:1 + + % Grab the current datatype combination. + opsupname = opsupnames( opi, : ); + opname = opnames( opi, : ); + + opsupname = strtrim( opsupname ); + opname = strtrim( opname ); + + str = sprintf( 'Plotting %2d: %s', opi, opsupname ); disp(str); + + % Construct filenames for the data files from templates. + file_blissup = sprintf( filetemp_blissup, dirpath, thr_str, opsupname ); + file_blislpab = sprintf( filetemp_blislpab, dirpath, thr_str, opsupname ); + file_eigen = sprintf( filetemp_eigen, dirpath, thr_str, opsupname ); + file_open = sprintf( filetemp_open, dirpath, thr_str, opsupname ); + file_vend = sprintf( filetemp_vend, dirpath, thr_str, opsupname ); + + % Load the data files. + %str = sprintf( ' Loading %s', file_blissup ); disp(str); + run( file_blissup ) + run( file_blislpab ) + run( file_eigen ) + run( file_open ) + run( file_vend ) + + % Construct variable names for the variables in the data files. + var_blissup = sprintf( vartemp, thr_str, opname, 'blissup' ); + var_blislpab = sprintf( vartemp, thr_str, opname, 'blislpab' ); + var_eigen = sprintf( vartemp, thr_str, opname, 'eigen' ); + var_open = sprintf( vartemp, thr_str, opname, 'openblas' ); + var_vend = sprintf( vartemp, thr_str, opname, 'vendor' ); + + % Use eval() to instantiate the variable names constructed above, + % copying each to a simplified name. + data_blissup = eval( var_blissup ); % e.g. data_st_dgemm_blissup( :, : ); + data_blislpab = eval( var_blislpab ); % e.g. data_st_dgemm_blislpab( :, : ); + data_eigen = eval( var_eigen ); % e.g. data_st_dgemm_eigen( :, : ); + data_open = eval( var_open ); % e.g. data_st_dgemm_openblas( :, : ); + data_vend = eval( var_vend ); % e.g. data_st_dgemm_vendor( :, : ); + + %str = sprintf( ' Reading %s', var_blissup ); disp(str); + %str = sprintf( ' Reading %s', var_blislpab ); disp(str); + %str = sprintf( ' Reading %s', var_eigen ); disp(str); + %str = sprintf( ' Reading %s', var_open ); disp(str); + %str = sprintf( ' Reading %s', var_bfeo ); disp(str); + %str = sprintf( ' Reading %s', var_xsmm ); disp(str); + %str = sprintf( ' Reading %s', var_vend ); disp(str); + + % Plot one result in an m x n grid of plots, via the subplot() + % function. + if 1 == 1 + plot_l3sup_perf( opsupname, ... + data_blissup, ... + data_blislpab, ... + data_eigen, ... + data_open, ... + data_vend, vend_str, ... + nth, ... + 4, 7, ... + cfreq, ... + dflopspercycle, ... + opi, impl ); + + clear data_mt_*gemm_*; + clear data_blissup; + clear data_blislpab; + clear data_eigen; + clear data_open; + clear data_vend; + + end + +end + +% Construct the name of the file to which we will output the graph. +outfile = sprintf( 'l3sup_%s_%s_%s_nt%d.pdf', oproot, stor_str, arch_str, nth ); + +% Output the graph to pdf format. +%print(gcf, 'gemm_md','-fillpage','-dpdf'); +%print(gcf, outfile,'-bestfit','-dpdf'); +if impl == 'octave' + print(gcf, outfile); +else % if impl == 'matlab' + print(gcf, outfile,'-bestfit','-dpdf'); +end + diff --git a/test/sup/old/supmt/octave/runthese.m b/test/sup/old/supmt/octave/runthese.m new file mode 100644 index 0000000000..e11f8b173e --- /dev/null +++ b/test/sup/old/supmt/octave/runthese.m @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + +% kabylake +plot_panel_trxsh(3.80,16,4,'mt','d','rrr',[ 6 8 10 ],'../../sup/results/kabylake/20200302/mnkt100000_mt4','kbl','MKL','octave'); close; clear all; + +% haswell +plot_panel_trxsh(3.1,16,12,'mt','d','rrr',[ 6 8 10 ],'../../sup/results/haswell/20200302/mnkt100000_mt12','has','MKL','octave'); close; clear all; + +% epyc +plot_panel_trxsh(2.55,8,32,'mt','d','rrr',[ 6 8 10 ],'../../sup/results/epyc/20200302/mnkt100000_mt32','epyc','MKL','octave'); close; clear all; diff --git a/test/sup/old/supmt/runme.sh b/test/sup/old/supmt/runme.sh new file mode 100755 index 0000000000..911fbbaa43 --- /dev/null +++ b/test/sup/old/supmt/runme.sh @@ -0,0 +1,204 @@ +#!/bin/bash + +# File pefixes. +exec_root="test" +out_root="output" + +sys="blis" +#sys="lonestar5" +#sys="ul252" +#sys="ul264" + +if [ ${sys} = "blis" ]; then + + export GOMP_CPU_AFFINITY="0-3" + nt=4 + +elif [ ${sys} = "lonestar5" ]; then + + export GOMP_CPU_AFFINITY="0-23" + nt=12 + +elif [ ${sys} = "ul252" ]; then + + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/field/intel/mkl/lib/intel64" + export GOMP_CPU_AFFINITY="0-51" + nt=26 + +elif [ ${sys} = "ul264" ]; then + + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/field/intel/mkl/lib/intel64" + export GOMP_CPU_AFFINITY="0-63" + nt=32 + +fi + +# Delay between test cases. +delay=0.02 + +# Threadedness to test. +#threads="st mt" +threads="st mt" + +# Datatypes to test. +#dts="d s" +dts="d" + +# Operations to test. +ops="gemm" + +# Transpose combintions to test. +trans="nn nt tn tt" + +# Storage combinations to test. +#stors="rrr rrc rcr rcc crr crc ccr ccc" +stors="rrr ccc" + +# Problem shapes to test. +shapes="sll lsl lls lss sls ssl lll" + +# FGVZ: figure out how to probe what's in the directory and +# execute everything that's there? +sms="6" +sns="8" +sks="10" + +# Implementations to test. +impls="vendor blissup blislpab openblas eigen" +#impls="vendor" +#impls="blissup" +#impls="blislpab" +#impls="openblas" +#impls="eigen" + +# Save a copy of GOMP_CPU_AFFINITY so that if we have to unset it, we can +# restore the value. +GOMP_CPU_AFFINITYsave=${GOMP_CPU_AFFINITY} + +# Example: test_dgemm_nn_rrc_m6npkp_blissup_st.x + +for th in ${threads}; do + + for dt in ${dts}; do + + for op in ${ops}; do + + for tr in ${trans}; do + + for st in ${stors}; do + + for sh in ${shapes}; do + + for sm in ${sms}; do + + for sn in ${sns}; do + + for sk in ${sks}; do + + for im in ${impls}; do + + if [ "${th}" = "mt" ]; then + + # Specify the multithreading depending on which + # implementation is about to be tested. + if [ "${im:0:4}" = "blis" ]; then + unset OMP_NUM_THREADS + export BLIS_NUM_THREADS=${nt} + elif [ "${im}" = "openblas" ]; then + unset OMP_NUM_THREADS + export OPENBLAS_NUM_THREADS=${nt} + elif [ "${im}" = "eigen" ]; then + export OMP_NUM_THREADS=${nt} + elif [ "${im}" = "vendor" ]; then + unset OMP_NUM_THREADS + export MKL_NUM_THREADS=${nt} + fi + export nt_use=${nt} + + else # if [ "${th}" = "st" ]; + + # Use single-threaded execution. + export OMP_NUM_THREADS=1 + export BLIS_NUM_THREADS=1 + export OPENBLAS_NUM_THREADS=1 + export MKL_NUM_THREADS=1 + export nt_use=1 + fi + + # Multithreaded OpenBLAS seems to have a problem + # running properly if GOMP_CPU_AFFINITY is set. + # So we temporarily unset it here if we are about + # to execute OpenBLAS, but otherwise restore it. + if [ ${im} = "openblas" ]; then + unset GOMP_CPU_AFFINITY + else + export GOMP_CPU_AFFINITY="${GOMP_CPU_AFFINITYsave}" + fi + + # Limit execution of non-BLIS implementations to + # rrr/ccc storage cases. + if [ "${im:0:4}" != "blis" ] && \ + [ "${st}" != "rrr" ] && \ + [ "${st}" != "ccc" ]; then + continue; + fi + + # Further limit execution of libxsmm to + # ccc storage cases. + if [ "${im:0:7}" = "libxsmm" ] && \ + [ "${st}" != "ccc" ]; then + continue; + fi + + # Extract the shape chars for m, n, k. + chm=${sh:0:1} + chn=${sh:1:1} + chk=${sh:2:1} + + # Construct the shape substring (e.g. m6npkp) + shstr="" + + if [ ${chm} = "s" ]; then + shstr="${shstr}m${sm}" + else + shstr="${shstr}mp" + fi + + if [ ${chn} = "s" ]; then + shstr="${shstr}n${sn}" + else + shstr="${shstr}np" + fi + + if [ ${chk} = "s" ]; then + shstr="${shstr}k${sk}" + else + shstr="${shstr}kp" + fi + + # Ex: test_dgemm_nn_rrc_m6npkp_blissup_st.x + + # Construct the name of the test executable. + exec_name="${exec_root}_${dt}${op}_${tr}_${st}_${shstr}_${im}_${th}.x" + + # Construct the name of the output file. + out_file="${out_root}_${th}_${dt}${op}_${tr}_${st}_${shstr}_${im}.m" + + echo "Running (nt = ${nt_use}) ./${exec_name} > ${out_file}" + + # Run executable. + ./${exec_name} > ${out_file} + + sleep ${delay} + + done + done + done + done + done + done + done + done + done +done + diff --git a/test/sup/old/supmt/test_gemm.c b/test/sup/old/supmt/test_gemm.c new file mode 100644 index 0000000000..23cc564400 --- /dev/null +++ b/test/sup/old/supmt/test_gemm.c @@ -0,0 +1,597 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#ifdef EIGEN + #define BLIS_DISABLE_BLAS_DEFS + #include "blis.h" + #include + //#include + using namespace Eigen; +#else + #include "blis.h" +#endif + +//#define PRINT + +int main( int argc, char** argv ) +{ + rntm_t rntm_g; + + bli_init(); + + // Copy the global rntm_t object so that we can use it later when disabling + // sup. Starting with a copy of the global rntm_t is actually necessary; + // if we start off with a locally-initialized rntm_t, it will not contain + // the ways of parallelism that were conveyed via environment variables, + // which is necessary when running this driver with multiple BLIS threads. + bli_rntm_init_from_global( &rntm_g ); + +#ifndef ERROR_CHECK + bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); +#endif + + + dim_t n_trials = N_TRIALS; + + num_t dt = DT; + +#if 1 + dim_t p_begin = P_BEGIN; + dim_t p_max = P_MAX; + dim_t p_inc = P_INC; +#else + dim_t p_begin = 4; + dim_t p_max = 40; + dim_t p_inc = 4; +#endif + +#if 1 + dim_t m_input = M_DIM; + dim_t n_input = N_DIM; + dim_t k_input = K_DIM; +#else + p_begin = p_inc = 32; + dim_t m_input = 6; + dim_t n_input = -1; + dim_t k_input = -1; +#endif + +#if 1 + trans_t transa = TRANSA; + trans_t transb = TRANSB; +#else + trans_t transa = BLIS_NO_TRANSPOSE; + trans_t transb = BLIS_NO_TRANSPOSE; +#endif + +#if 1 + stor3_t sc = STOR3; +#else + stor3_t sc = BLIS_RRR; +#endif + + + inc_t rs_c, cs_c; + inc_t rs_a, cs_a; + inc_t rs_b, cs_b; + + if ( sc == BLIS_RRR ) { rs_c = cs_c = -1; rs_a = cs_a = -1; rs_b = cs_b = -1; } + else if ( sc == BLIS_RRC ) { rs_c = cs_c = -1; rs_a = cs_a = -1; rs_b = cs_b = 0; } + else if ( sc == BLIS_RCR ) { rs_c = cs_c = -1; rs_a = cs_a = 0; rs_b = cs_b = -1; } + else if ( sc == BLIS_RCC ) { rs_c = cs_c = -1; rs_a = cs_a = 0; rs_b = cs_b = 0; } + else if ( sc == BLIS_CRR ) { rs_c = cs_c = 0; rs_a = cs_a = -1; rs_b = cs_b = -1; } + else if ( sc == BLIS_CRC ) { rs_c = cs_c = 0; rs_a = cs_a = -1; rs_b = cs_b = 0; } + else if ( sc == BLIS_CCR ) { rs_c = cs_c = 0; rs_a = cs_a = 0; rs_b = cs_b = -1; } + else if ( sc == BLIS_CCC ) { rs_c = cs_c = 0; rs_a = cs_a = 0; rs_b = cs_b = 0; } + else { bli_abort(); } + + f77_int cbla_storage; + + if ( sc == BLIS_RRR ) cbla_storage = CblasRowMajor; + else if ( sc == BLIS_CCC ) cbla_storage = CblasColMajor; + else cbla_storage = -1; + + ( void )cbla_storage; + + + char dt_ch; + + // Choose the char corresponding to the requested datatype. + if ( bli_is_float( dt ) ) dt_ch = 's'; + else if ( bli_is_double( dt ) ) dt_ch = 'd'; + else if ( bli_is_scomplex( dt ) ) dt_ch = 'c'; + else dt_ch = 'z'; + + f77_char f77_transa; + f77_char f77_transb; + char transal, transbl; + + bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + + transal = tolower( f77_transa ); + transbl = tolower( f77_transb ); + + f77_int cbla_transa = ( transal == 'n' ? CblasNoTrans : CblasTrans ); + f77_int cbla_transb = ( transbl == 'n' ? CblasNoTrans : CblasTrans ); + + ( void )cbla_transa; + ( void )cbla_transb; + + dim_t p; + + // Begin with initializing the last entry to zero so that + // matlab allocates space for the entire array once up-front. + for ( p = p_begin; p + p_inc <= p_max; p += p_inc ) ; + + printf( "data_%s_%cgemm_%c%c_%s", THR_STR, dt_ch, + transal, transbl, STR ); + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )0, + ( unsigned long )0, + ( unsigned long )0, 0.0 ); + + + //for ( p = p_begin; p <= p_max; p += p_inc ) + for ( p = p_max; p_begin <= p; p -= p_inc ) + { + obj_t a, b, c; + obj_t c_save; + obj_t alpha, beta; + dim_t m, n, k; + + if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); + else m = ( dim_t ) m_input; + if ( n_input < 0 ) n = p / ( dim_t )abs(n_input); + else n = ( dim_t ) n_input; + if ( k_input < 0 ) k = p / ( dim_t )abs(k_input); + else k = ( dim_t ) k_input; + + bli_obj_create( dt, 1, 1, 0, 0, &alpha ); + bli_obj_create( dt, 1, 1, 0, 0, &beta ); + + bli_obj_create( dt, m, n, rs_c, cs_c, &c ); + bli_obj_create( dt, m, n, rs_c, cs_c, &c_save ); + + if ( bli_does_notrans( transa ) ) + bli_obj_create( dt, m, k, rs_a, cs_a, &a ); + else + bli_obj_create( dt, k, m, rs_a, cs_a, &a ); + + if ( bli_does_notrans( transb ) ) + bli_obj_create( dt, k, n, rs_b, cs_b, &b ); + else + bli_obj_create( dt, n, k, rs_b, cs_b, &b ); + + bli_randm( &a ); + bli_randm( &b ); + bli_randm( &c ); + + bli_obj_set_conjtrans( transa, &a ); + bli_obj_set_conjtrans( transb, &b ); + + bli_setsc( (1.0/1.0), 0.0, &alpha ); + bli_setsc( (1.0/1.0), 0.0, &beta ); + + bli_copym( &c, &c_save ); + +#ifdef EIGEN + double alpha_r, alpha_i; + + bli_getsc( &alpha, &alpha_r, &alpha_i ); + + void* ap = bli_obj_buffer_at_off( &a ); + void* bp = bli_obj_buffer_at_off( &b ); + void* cp = bli_obj_buffer_at_off( &c ); + + const int os_a = ( bli_obj_is_col_stored( &a ) ? bli_obj_col_stride( &a ) + : bli_obj_row_stride( &a ) ); + const int os_b = ( bli_obj_is_col_stored( &b ) ? bli_obj_col_stride( &b ) + : bli_obj_row_stride( &b ) ); + const int os_c = ( bli_obj_is_col_stored( &c ) ? bli_obj_col_stride( &c ) + : bli_obj_row_stride( &c ) ); + + Stride stride_a( os_a, 1 ); + Stride stride_b( os_b, 1 ); + Stride stride_c( os_c, 1 ); + + #if defined(IS_FLOAT) + #elif defined (IS_DOUBLE) + #ifdef A_STOR_R + typedef Matrix MatrixXd_A; + #else + typedef Matrix MatrixXd_A; + #endif + #ifdef B_STOR_R + typedef Matrix MatrixXd_B; + #else + typedef Matrix MatrixXd_B; + #endif + #ifdef C_STOR_R + typedef Matrix MatrixXd_C; + #else + typedef Matrix MatrixXd_C; + #endif + + #ifdef A_NOTRANS // A is not transposed + Map > A( ( double* )ap, m, k, stride_a ); + #else // A is transposed + Map > A( ( double* )ap, k, m, stride_a ); + #endif + + #ifdef B_NOTRANS // B is not transposed + Map > B( ( double* )bp, k, n, stride_b ); + #else // B is transposed + Map > B( ( double* )bp, n, k, stride_b ); + #endif + + Map > C( ( double* )cp, m, n, stride_c ); + #endif +#endif + + + double dtime_save = DBL_MAX; + + for ( dim_t r = 0; r < n_trials; ++r ) + { + bli_copym( &c_save, &c ); + + + double dtime = bli_clock(); + + +#ifdef EIGEN + + #ifdef A_NOTRANS + #ifdef B_NOTRANS + C.noalias() += alpha_r * A * B; + #else // B_TRANS + C.noalias() += alpha_r * A * B.transpose(); + #endif + #else // A_TRANS + #ifdef B_NOTRANS + C.noalias() += alpha_r * A.transpose() * B; + #else // B_TRANS + C.noalias() += alpha_r * A.transpose() * B.transpose(); + #endif + #endif + +#endif +#ifdef BLIS + #ifdef SUP + // Allow sup. + bli_gemm( &alpha, + &a, + &b, + &beta, + &c ); + #else + // NOTE: We can't use the static initializer and must instead + // initialize the rntm_t with the copy from the global rntm_t we + // made at the beginning of main(). Please see the comment there + // for more info on why BLIS_RNTM_INITIALIZER doesn't work here. + //rntm_t rntm = BLIS_RNTM_INITIALIZER; + rntm_t rntm = rntm_g; + + // Disable sup and use the expert interface. + bli_rntm_disable_l3_sup( &rntm ); + + bli_gemm_ex( &alpha, + &a, + &b, + &beta, + &c, NULL, &rntm ); + #endif +#endif +#ifdef BLAS + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* bp = ( float* )bli_obj_buffer( &b ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); + + #ifdef XSMM + libxsmm_sgemm( &f77_transa, + #else + sgemm_( &f77_transa, + #endif + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* bp = ( double* )bli_obj_buffer( &b ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); + + #ifdef XSMM + libxsmm_dgemm( &f77_transa, + #else + dgemm_( &f77_transa, + #endif + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* bp = ( scomplex* )bli_obj_buffer( &b ); + scomplex* betap = ( scomplex* )bli_obj_buffer( &beta ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); + + #ifdef XSMM + libxsmm_cgemm( &f77_transa, + #else + cgemm_( &f77_transa, + #endif + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* bp = ( dcomplex* )bli_obj_buffer( &b ); + dcomplex* betap = ( dcomplex* )bli_obj_buffer( &beta ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); + + #ifdef XSMM + libxsmm_zgemm( &f77_transa, + #else + zgemm_( &f77_transa, + #endif + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } +#endif +#ifdef CBLAS + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* bp = bli_obj_buffer( &b ); + float* betap = bli_obj_buffer( &beta ); + float* cp = bli_obj_buffer( &c ); + + cblas_sgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + double* alphap = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + double* bp = bli_obj_buffer( &b ); + double* betap = bli_obj_buffer( &beta ); + double* cp = bli_obj_buffer( &c ); + + cblas_dgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + scomplex* bp = bli_obj_buffer( &b ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* cp = bli_obj_buffer( &c ); + + cblas_cgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* bp = bli_obj_buffer( &b ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* cp = bli_obj_buffer( &c ); + + cblas_zgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc ); + } +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + double gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 4.0; + + printf( "data_%s_%cgemm_%c%c_%s", THR_STR, dt_ch, + transal, transbl, STR ); + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )m, + ( unsigned long )n, + ( unsigned long )k, gflops ); + + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + bli_obj_free( &c_save ); + } + + //bli_finalize(); + + return 0; +} + diff --git a/test/sup/old/supst/Makefile b/test/sup/old/supst/Makefile new file mode 100644 index 0000000000..6ab97b06f7 --- /dev/null +++ b/test/sup/old/supst/Makefile @@ -0,0 +1,496 @@ +#!/bin/bash +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2019, Advanced Micro Devices, Inc. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +# +# Makefile +# +# Field G. Van Zee +# +# Makefile for standalone BLIS test drivers. +# + +# +# --- Makefile PHONY target definitions ---------------------------------------- +# + +.PHONY: all all-st all-mt \ + blis blis-st blis-mt \ + clean cleanx + + + +# +# --- Determine makefile fragment location ------------------------------------- +# + +# Comments: +# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. +# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in +# the second case because CONFIG_NAME is not yet set. +ifneq ($(strip $(BLIS_INSTALL_PATH)),) +LIB_PATH := $(BLIS_INSTALL_PATH)/lib +INC_PATH := $(BLIS_INSTALL_PATH)/include/blis +SHARE_PATH := $(BLIS_INSTALL_PATH)/share/blis +else +DIST_PATH := ../.. +LIB_PATH = ../../lib/$(CONFIG_NAME) +INC_PATH = ../../include/$(CONFIG_NAME) +SHARE_PATH := ../.. +endif + + + +# +# --- Include common makefile definitions -------------------------------------- +# + +# Include the common makefile fragment. +-include $(SHARE_PATH)/common.mk + + + +# +# --- BLAS and LAPACK implementations ------------------------------------------ +# + +# BLIS library and header path. This is simply wherever it was installed. +#BLIS_LIB_PATH := $(INSTALL_PREFIX)/lib +#BLIS_INC_PATH := $(INSTALL_PREFIX)/include/blis + +# BLIS library. +#BLIS_LIB := $(BLIS_LIB_PATH)/libblis.a + +# BLAS library path(s). This is where the BLAS libraries reside. +HOME_LIB_PATH := $(HOME)/flame/lib +MKL_LIB_PATH := $(HOME)/intel/mkl/lib/intel64 + +# netlib BLAS +NETLIB_LIB := $(HOME_LIB_PATH)/libblas.a + +# OpenBLAS +OPENBLAS_LIB := $(HOME_LIB_PATH)/libopenblas.a +OPENBLASP_LIB := $(HOME_LIB_PATH)/libopenblasp.a + +# BLASFEO +BLASFEO_LIB := $(HOME_LIB_PATH)/libblasfeo.a + +# libxsmm +LIBXSMM_LIB := $(HOME_LIB_PATH)/libxsmm.a -ldl \ + $(NETLIB_LIB) -lgfortran + +# ATLAS +ATLAS_LIB := $(HOME_LIB_PATH)/libf77blas.a \ + $(HOME_LIB_PATH)/libatlas.a + +# Eigen +EIGEN_INC := $(HOME)/flame/eigen/include/eigen3 +EIGEN_LIB := $(HOME_LIB_PATH)/libeigen_blas_static.a +EIGENP_LIB := $(EIGEN_LIB) + +# MKL +MKL_LIB := -L$(MKL_LIB_PATH) \ + -lmkl_intel_lp64 \ + -lmkl_core \ + -lmkl_sequential \ + -lpthread -lm -ldl +MKLP_LIB := -L$(MKL_LIB_PATH) \ + -lmkl_intel_lp64 \ + -lmkl_core \ + -lmkl_gnu_thread \ + -lpthread -lm -ldl -fopenmp + #-L$(ICC_LIB_PATH) \ + #-lgomp + +VENDOR_LIB := $(MKL_LIB) +VENDORP_LIB := $(MKLP_LIB) + + +# +# --- Problem size definitions ------------------------------------------------- +# + +# Single core +PS_BEGIN_3L := 2 +PS_MAX_3L := 400 +PS_INC_3L := 2 + +PS_BEGIN_2L := 4 +PS_MAX_2L := 800 +PS_INC_2L := 4 + +PS_BEGIN_1L := 32 +PS_MAX_1L := 6400 +PS_INC_1L := 32 + + +# +# --- General build definitions ------------------------------------------------ +# + +TEST_SRC_PATH := . +TEST_OBJ_PATH := . + +# Gather all local object files. +TEST_OBJS := $(sort $(patsubst $(TEST_SRC_PATH)/%.c, \ + $(TEST_OBJ_PATH)/%.o, \ + $(wildcard $(TEST_SRC_PATH)/*.c))) + +# Override the value of CINCFLAGS so that the value of CFLAGS returned by +# get-frame-cflags-for() is not cluttered up with include paths needed only +# while building BLIS. +CINCFLAGS := -I$(INC_PATH) + +# Use the "framework" CFLAGS for the configuration family. +CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) + +# Add local header paths to CFLAGS. +CFLAGS += -I$(TEST_SRC_PATH) + +# Locate the libblis library to which we will link. +LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) + +# Define a set of CFLAGS for use with C++ and Eigen. +CXXFLAGS := $(subst -std=c99,-std=c++11,$(CFLAGS)) +CXXFLAGS += -I$(EIGEN_INC) + +# Create a copy of CXXFLAGS without -fopenmp in order to disable multithreading. +CXXFLAGS_ST := -march=native $(subst -fopenmp,,$(CXXFLAGS)) +CXXFLAGS_MT := -march=native $(CXXFLAGS) + +# Single or multithreaded string. +STR_ST := -DTHR_STR=\"st\" +STR_MT := -DTHR_STR=\"mt\" + +# Number of trials per problem size. +N_TRIALS := -DN_TRIALS=3 + +# Problem size specification. +PDEF_ST_1L := -DP_BEGIN=$(PS_BEGIN_1L) -DP_MAX=$(PS_MAX_1L) -DP_INC=$(PS_INC_1L) +PDEF_ST_2L := -DP_BEGIN=$(PS_BEGIN_2L) -DP_MAX=$(PS_MAX_2L) -DP_INC=$(PS_INC_2L) +PDEF_ST_3L := -DP_BEGIN=$(PS_BEGIN_3L) -DP_MAX=$(PS_MAX_3L) -DP_INC=$(PS_INC_3L) + +ifeq ($(E),1) +ERRCHK := -DERROR_CHECK +else +ERRCHK := -DNO_ERROR_CHECK +endif + +# Enumerate possible datatypes and computation precisions. +#dts := s d c z +DTS := d + +TRANS := n_n \ + n_t \ + t_n \ + t_t + +# While BLIS supports all combinations of row and column storage for matrices +# C, A, and B, the alternatives mostly only support CBLAS APIs, which inherently +# support only "all row-storage" or "all column-storage". Thus, we disable the +# building of those other drivers so that compilation/linking completes sooner. +#STORS := r_r_r \ +# r_r_c \ +# r_c_r \ +# r_c_c \ +# c_r_r \ +# c_r_c \ +# c_c_r \ +# c_c_c +STORS := r_r_r \ + c_c_c + + +SHAPES := l_l_s \ + l_s_l \ + s_l_l \ + s_s_l \ + s_l_s \ + l_s_s \ + l_l_l + +SMS := 6 +SNS := 8 +SKS := 4 + + +# +# --- Function definitions ----------------------------------------------------- +# + +# A function to strip the underscores from a list of strings. +stripu = $(subst _,,$(1)) + +# Various functions that help us construct the datatype combinations and then +# extract the needed datatype strings and C preprocessor define flags. +get-1of2 = $(word 1,$(subst _, ,$(1))) +get-2of2 = $(word 2,$(subst _, ,$(1))) + +get-1of3 = $(word 1,$(subst _, ,$(1))) +get-2of3 = $(word 2,$(subst _, ,$(1))) +get-3of3 = $(word 3,$(subst _, ,$(1))) + +# A function to return the correct PDEFS_ST variable given the shape string. +get-pdefs = $(strip $(subst l_l_l,$(PDEF_ST_3L), \ + $(subst l_l_s,$(PDEF_ST_2L), \ + $(subst l_s_l,$(PDEF_ST_2L), \ + $(subst s_l_l,$(PDEF_ST_2L), \ + $(subst s_s_l,$(PDEF_ST_1L), \ + $(subst s_l_s,$(PDEF_ST_1L), \ + $(subst l_s_s,$(PDEF_ST_1L),$(1))))))))) + +# Datatype defs. +get-dt-cpp = $(strip \ + $(if $(findstring s,$(1)),-DDT=BLIS_FLOAT -DIS_FLOAT,\ + $(if $(findstring d,$(1)),-DDT=BLIS_DOUBLE -DIS_DOUBLE,\ + $(if $(findstring c,$(1)),-DDT=BLIS_SCOMPLEX -DIS_SCOMPLEX,\ + -DDT=BLIS_DCOMPLEX -DIS_DCOMPLEX)))) + +# Transpose defs. +get-tra-defs-a = $(strip $(subst n,-DTRANSA=BLIS_NO_TRANSPOSE -DA_NOTRANS, \ + $(subst t,-DTRANSA=BLIS_TRANSPOSE -DA_TRANS,$(call get-1of2,$(1))))) +get-tra-defs-b = $(strip $(subst n,-DTRANSB=BLIS_NO_TRANSPOSE -DB_NOTRANS, \ + $(subst t,-DTRANSB=BLIS_TRANSPOSE -DB_TRANS,$(call get-2of2,$(1))))) +get-tra-defs = $(call get-tra-defs-a,$(1)) $(call get-tra-defs-b,$(1)) + +# Storage defs. +get-sto-uch-a = $(strip $(subst r,R, \ + $(subst c,C,$(call get-1of3,$(1))))) +get-sto-uch-b = $(strip $(subst r,R, \ + $(subst c,C,$(call get-2of3,$(1))))) +get-sto-uch-c = $(strip $(subst r,R, \ + $(subst c,C,$(call get-3of3,$(1))))) +get-sto-defs = $(strip \ + -DSTOR3=BLIS_$(call get-sto-uch-a,$(1))$(call get-sto-uch-b,$(1))$(call get-sto-uch-c,$(1)) \ + -DA_STOR_$(call get-sto-uch-a,$(1)) \ + -DB_STOR_$(call get-sto-uch-b,$(1)) \ + -DC_STOR_$(call get-sto-uch-c,$(1))) + +# Dimension defs. +get-shape-defs-cm = $(if $(findstring l,$(1)),-DM_DIM=-1,-DM_DIM=$(2)) +get-shape-defs-cn = $(if $(findstring l,$(1)),-DN_DIM=-1,-DN_DIM=$(2)) +get-shape-defs-ck = $(if $(findstring l,$(1)),-DK_DIM=-1,-DK_DIM=$(2)) +get-shape-defs-m = $(call get-shape-defs-cm,$(call get-1of3,$(1)),$(2)) +get-shape-defs-n = $(call get-shape-defs-cn,$(call get-2of3,$(1)),$(2)) +get-shape-defs-k = $(call get-shape-defs-ck,$(call get-3of3,$(1)),$(2)) + +# arguments: 1: shape (w/ underscores) 2: smallm 3: smalln 4: smallk +get-shape-defs = $(strip $(call get-shape-defs-m,$(1),$(2)) \ + $(call get-shape-defs-n,$(1),$(3)) \ + $(call get-shape-defs-k,$(1),$(4))) + +#$(error l_l_s 6 8 4 = $(call get-shape-defs,l_l_s,6,8,4)) + +# Shape-dimension string. +get-shape-str-ch = $(if $(findstring l,$(1)),p,$(2)) +get-shape-str-m = $(call get-shape-str-ch,$(call get-1of3,$(1)),$(2)) +get-shape-str-n = $(call get-shape-str-ch,$(call get-2of3,$(1)),$(2)) +get-shape-str-k = $(call get-shape-str-ch,$(call get-3of3,$(1)),$(2)) + +# arguments: 1: shape (w/ underscores) 2: smallm 3: smalln 4: smallk +get-shape-dim-str = m$(call get-shape-str-m,$(1),$(2))n$(call get-shape-str-n,$(1),$(3))k$(call get-shape-str-k,$(1),$(4)) + +# Implementation defs. +# Define a function to return the appropriate -DSTR= and -D[BLIS|BLAS] flags. +get-imp-defs = $(strip $(subst blissup,-DSTR=\"$(1)\" -DBLIS -DSUP, \ + $(subst blislpab,-DSTR=\"$(1)\" -DBLIS, \ + $(subst eigen,-DSTR=\"$(1)\" -DEIGEN, \ + $(subst openblas,-DSTR=\"$(1)\" -DCBLAS, \ + $(subst blasfeo,-DSTR=\"$(1)\" -DCBLAS, \ + $(subst libxsmm,-DSTR=\"$(1)\" -DBLAS -DXSMM, \ + $(subst vendor,-DSTR=\"$(1)\" -DCBLAS,$(1))))))))) + +TRANS0 = $(call stripu,$(TRANS)) +STORS0 = $(call stripu,$(STORS)) + +# Limit BLAS and Eigen to only using all row-stored, or all column-stored matrices. +# Also, limit libxsmm to using all column-stored matrices since it does not offer +# CBLAS interfaces. +BSTORS0 = rrr ccc +ESTORS0 = rrr ccc +XSTORS0 = ccc + + +# +# --- Object and binary file definitons ---------------------------------------- +# + +get-st-objs = $(foreach dt,$(1),$(foreach tr,$(2),$(foreach st,$(3),$(foreach sh,$(4),$(foreach sm,$(5),$(foreach sn,$(6),$(foreach sk,$(7),test_$(dt)gemm_$(tr)_$(st)_$(call get-shape-dim-str,$(sh),$(sm),$(sn),$(sk))_$(8)_st.o))))))) + +# Build a list of object files and binaries for each single-threaded +# implementation using the get-st-objs() function defined above. +BLISSUP_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(STORS0),$(SHAPES),$(SMS),$(SNS),$(SKS),blissup) +BLISSUP_ST_BINS := $(patsubst %.o,%.x,$(BLISSUP_ST_OBJS)) + +BLISLPAB_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(STORS0),$(SHAPES),$(SMS),$(SNS),$(SKS),blislpab) +BLISLPAB_ST_BINS := $(patsubst %.o,%.x,$(BLISLPAB_ST_OBJS)) + +EIGEN_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(ESTORS0),$(SHAPES),$(SMS),$(SNS),$(SKS),eigen) +EIGEN_ST_BINS := $(patsubst %.o,%.x,$(EIGEN_ST_OBJS)) + +OPENBLAS_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(SMS),$(SNS),$(SKS),openblas) +OPENBLAS_ST_BINS := $(patsubst %.o,%.x,$(OPENBLAS_ST_OBJS)) + +BLASFEO_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(SMS),$(SNS),$(SKS),blasfeo) +BLASFEO_ST_BINS := $(patsubst %.o,%.x,$(BLASFEO_ST_OBJS)) + +LIBXSMM_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(XSTORS0),$(SHAPES),$(SMS),$(SNS),$(SKS),libxsmm) +LIBXSMM_ST_BINS := $(patsubst %.o,%.x,$(LIBXSMM_ST_OBJS)) + +VENDOR_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(SMS),$(SNS),$(SKS),vendor) +VENDOR_ST_BINS := $(patsubst %.o,%.x,$(VENDOR_ST_OBJS)) + +#$(error "objs = $(EIGEN_ST_BINS)" ) + +# Mark the object files as intermediate so that make will remove them +# automatically after building the binaries on which they depend. +.INTERMEDIATE: $(BLISSUP_ST_OBJS) \ + $(BLISLPAB_ST_OBJS) \ + $(EIGEN_ST_OBJS) \ + $(OPENBLAS_ST_OBJS) \ + $(BLASFEO_ST_OBJS) \ + $(LIBXSMM_ST_OBJS) \ + $(VENDOR_ST_OBJS) + + +# +# --- Targets/rules ------------------------------------------------------------ +# + +all: st + +blissup: blissup-st +blislpab: blislpab-st +eigen: eigen-st +openblas: openblas-st +blasfeo: blasfeo-st +libxsmm: libxsmm-st +vendor: vendor-st + +st: blissup-st blislpab-st \ + eigen-st openblas-st blasfeo-st libxsmm-st vendor-st +blis: blissup-st blislpab-st + +blissup-st: $(BLISSUP_ST_BINS) +blislpab-st: $(BLISLPAB_ST_BINS) +eigen-st: $(EIGEN_ST_BINS) +openblas-st: $(OPENBLAS_ST_BINS) +blasfeo-st: $(BLASFEO_ST_BINS) +libxsmm-st: $(LIBXSMM_ST_BINS) +vendor-st: $(VENDOR_ST_BINS) + + +# --Object file rules -- + +# Define the implementations for which we will instantiate compilation rules. +BIMPLS := blissup blislpab openblas blasfeo libxsmm vendor +EIMPLS := eigen + +# 1 2 3 4 567 8 +# test_dgemm_nn_rrr_mpn6kp_blissup_st.x + +# Define the function that will be used to instantiate compilation rules +# for the various implementations. +define make-st-rule +test_$(1)gemm_$(call stripu,$(2))_$(call stripu,$(3))_$(call get-shape-dim-str,$(4),$(5),$(6),$(7))_$(8)_st.o: test_gemm.c Makefile + $(CC) $(CFLAGS) $(ERRCHK) $(N_TRIALS) $(call get-pdefs,$(4)) $(call get-dt-cpp,$(1)) $(call get-tra-defs,$(2)) $(call get-sto-defs,$(3)) $(call get-shape-defs,$(4),$(5),$(6),$(7)) $(call get-imp-defs,$(8)) $(STR_ST) -c $$< -o $$@ +endef + +# Instantiate the rule function make-st-rule() for each BLIS/BLAS/CBLAS +# implementation. +$(foreach dt,$(DTS), \ +$(foreach tr,$(TRANS), \ +$(foreach st,$(STORS), \ +$(foreach sh,$(SHAPES), \ +$(foreach sm,$(SMS), \ +$(foreach sn,$(SNS), \ +$(foreach sk,$(SKS), \ +$(foreach impl,$(BIMPLS), \ +$(eval $(call make-st-rule,$(dt),$(tr),$(st),$(sh),$(sm),$(sn),$(sk),$(impl))))))))))) + +# Define the function that will be used to instantiate compilation rules +# for the various implementations. +define make-eigst-rule +test_$(1)gemm_$(call stripu,$(2))_$(call stripu,$(3))_$(call get-shape-dim-str,$(4),$(5),$(6),$(7))_$(8)_st.o: test_gemm.c Makefile + $(CXX) $(CXXFLAGS_ST) $(ERRCHK) $(N_TRIALS) $(call get-pdefs,$(4)) $(call get-dt-cpp,$(1)) $(call get-tra-defs,$(2)) $(call get-sto-defs,$(3)) $(call get-shape-defs,$(4),$(5),$(6),$(7)) $(call get-imp-defs,$(8)) $(STR_ST) -c $$< -o $$@ +endef + +# Instantiate the rule function make-st-rule() for each Eigen implementation. +$(foreach dt,$(DTS), \ +$(foreach tr,$(TRANS), \ +$(foreach st,$(STORS), \ +$(foreach sh,$(SHAPES), \ +$(foreach sm,$(SMS), \ +$(foreach sn,$(SNS), \ +$(foreach sk,$(SKS), \ +$(foreach impl,$(EIMPLS), \ +$(eval $(call make-eigst-rule,$(dt),$(tr),$(st),$(sh),$(sm),$(sn),$(sk),$(impl))))))))))) + + +# -- Executable file rules -- + +# NOTE: For the BLAS test drivers, we place the BLAS libraries before BLIS +# on the link command line in case BLIS was configured with the BLAS +# compatibility layer. This prevents BLIS from inadvertently getting called +# for the BLAS routines we are trying to test with. + +test_%_blissup_st.x: test_%_blissup_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_blislpab_st.x: test_%_blislpab_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_eigen_st.x: test_%_eigen_st.o $(LIBBLIS_LINK) + $(CXX) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_openblas_st.x: test_%_openblas_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(OPENBLAS_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_blasfeo_st.x: test_%_blasfeo_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(BLASFEO_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_libxsmm_st.x: test_%_libxsmm_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBXSMM_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_vendor_st.x: test_%_vendor_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(VENDOR_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + + +# -- Clean rules -- + +clean: cleanx + +cleanx: + - $(RM_F) *.x *.o + diff --git a/test/sup/old/supst/octave/gen_opsupnames.m b/test/sup/old/supst/octave/gen_opsupnames.m new file mode 100644 index 0000000000..b70c8a12a7 --- /dev/null +++ b/test/sup/old/supst/octave/gen_opsupnames.m @@ -0,0 +1,36 @@ +function [ r_val1, r_val2 ] = gen_opsupnames( ops, stor, smalldims ) + +nops = size( ops, 1 ); + +smallm = smalldims( 1 ); +smalln = smalldims( 2 ); +smallk = smalldims( 3 ); + +i = 1; + +for io = 1:nops + + op = ops( io, : ); + + opsupnames( i+0, : ) = sprintf( '%s_%s_m%dnpkp', op, stor, smallm ); + opsupnames( i+1, : ) = sprintf( '%s_%s_mpn%dkp', op, stor, smalln ); + opsupnames( i+2, : ) = sprintf( '%s_%s_mpnpk%d', op, stor, smallk ); + opsupnames( i+3, : ) = sprintf( '%s_%s_mpn%dk%d', op, stor, smalln, smallk ); + opsupnames( i+4, : ) = sprintf( '%s_%s_m%dnpk%d', op, stor, smallm, smallk ); + opsupnames( i+5, : ) = sprintf( '%s_%s_m%dn%dkp', op, stor, smallm, smalln ); + opsupnames( i+6, : ) = sprintf( '%s_%s_mpnpkp', op, stor ); + + opnames( i+0, : ) = sprintf( '%s', op ); + opnames( i+1, : ) = sprintf( '%s', op ); + opnames( i+2, : ) = sprintf( '%s', op ); + opnames( i+3, : ) = sprintf( '%s', op ); + opnames( i+4, : ) = sprintf( '%s', op ); + opnames( i+5, : ) = sprintf( '%s', op ); + opnames( i+6, : ) = sprintf( '%s', op ); + + i = i + 7; +end + +r_val1 = opsupnames; +r_val2 = opnames; + diff --git a/test/sup/old/supst/octave/plot_l3sup_perf.m b/test/sup/old/supst/octave/plot_l3sup_perf.m new file mode 100644 index 0000000000..ebc5d30004 --- /dev/null +++ b/test/sup/old/supst/octave/plot_l3sup_perf.m @@ -0,0 +1,321 @@ +function r_val = plot_l3sup_perf( opname, ... + data_blissup, ... + data_blislpab, ... + data_eigen, ... + data_open, ... + data_bfeo, ... + data_xsmm, ... + data_vend, vend_str, ... + nth, ... + rows, cols, ... + cfreq, ... + dfps, ... + theid, impl ) + +%if ... %mod(theid-1,cols) == 2 || ... +% ... %mod(theid-1,cols) == 3 || ... +% ... %mod(theid-1,cols) == 4 || ... +% 0 == 1 ... %theid >= 19 +% show_plot = 0; +%else + show_plot = 1; +%end + +%legend_plot_id = 11; +legend_plot_id = 2*cols + 1*5; + +if 1 + ax1 = subplot( rows, cols, theid ); + hold( ax1, 'on' ); +end + +% Set line properties. +color_blissup = 'k'; lines_blissup = '-'; markr_blissup = ''; +color_blislpab = 'k'; lines_blislpab = ':'; markr_blislpab = ''; +color_eigen = 'm'; lines_eigen = '-.'; markr_eigen = 'o'; +color_open = 'r'; lines_open = '--'; markr_open = 'o'; +color_bfeo = 'c'; lines_bfeo = '-'; markr_bfeo = 'o'; +color_xsmm = 'g'; lines_xsmm = '-'; markr_xsmm = 'o'; +color_vend = 'b'; lines_vend = '-.'; markr_vend = '.'; + +% Compute the peak performance in terms of the number of double flops +% executable per cycle and the clock rate. +if opname(1) == 's' || opname(1) == 'c' + flopspercycle = dfps * 2; +else + flopspercycle = dfps; +end +max_perf_core = (flopspercycle * cfreq) * 1; + +% Escape underscores in the title. +title_opname = strrep( opname, '_', '\_' ); + +% Print the title to a string. +titlename = '%s'; +titlename = sprintf( titlename, title_opname ); + +% Set the legend strings. +blissup_legend = sprintf( 'BLIS sup' ); +blislpab_legend = sprintf( 'BLIS conv' ); +eigen_legend = sprintf( 'Eigen' ); +open_legend = sprintf( 'OpenBLAS' ); +bfeo_legend = sprintf( 'BLASFEO' ); +xsmm_legend = sprintf( 'libxsmm' ); +%vend_legend = sprintf( 'MKL' ); +%vend_legend = sprintf( 'ARMPL' ); +vend_legend = vend_str; + +% Set axes range values. +y_scale = 1.00; +x_begin = 0; +%x_end is set below. +y_begin = 0; +y_end = max_perf_core * y_scale; + +% Set axes names. +if nth == 1 + yaxisname = 'GFLOPS'; +else + yaxisname = 'GFLOPS/core'; +end + + +%flopscol = 4; +flopscol = size( data_blissup, 2 ); +msize = 5; +if 1 + fontsize = 12; +else + fontsize = 16; +end +linesize = 0.5; +legend_loc = 'southeast'; + +% -------------------------------------------------------------------- + +% Automatically detect a column with the increasing problem size. +% Then set the maximum x-axis value. +for psize_col = 1:3 + if data_blissup( 1, psize_col ) ~= data_blissup( 2, psize_col ) + break; + end +end +x_axis( :, 1 ) = data_blissup( :, psize_col ); + +% Compute the number of data points we have in the x-axis. Note that +% we only use half the data points for the m = n = k column of graphs. +if mod(theid-1,cols) == 6 + np = size( data_blissup, 1 ) / 2; +else + np = size( data_blissup, 1 ); +end + +has_xsmm = 1; +if data_xsmm( 1, flopscol ) == 0.0 + has_xsmm = 0; +end + +% Grab the last x-axis value. +x_end = data_blissup( np, psize_col ); + +%data_peak( 1, 1:2 ) = [ 0 max_perf_core ]; +%data_peak( 2, 1:2 ) = [ x_end max_perf_core ]; + +if show_plot == 1 +blissup_ln = line( x_axis( 1:np, 1 ), data_blissup( 1:np, flopscol ) / nth, ... + 'Color',color_blissup, 'LineStyle',lines_blissup, ... + 'LineWidth',linesize ); +blislpab_ln = line( x_axis( 1:np, 1 ), data_blislpab( 1:np, flopscol ) / nth, ... + 'Color',color_blislpab, 'LineStyle',lines_blislpab, ... + 'LineWidth',linesize ); +eigen_ln = line( x_axis( 1:np, 1 ), data_eigen( 1:np, flopscol ) / nth, ... + 'Color',color_eigen, 'LineStyle',lines_eigen, ... + 'LineWidth',linesize ); +open_ln = line( x_axis( 1:np, 1 ), data_open( 1:np, flopscol ) / nth, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +bfeo_ln = line( x_axis( 1:np, 1 ), data_bfeo( 1:np, flopscol ) / nth, ... + 'Color',color_bfeo, 'LineStyle',lines_bfeo, ... + 'LineWidth',linesize ); +if has_xsmm == 1 +xsmm_ln = line( x_axis( 1:np, 1 ), data_xsmm( 1:np, flopscol ) / nth, ... + 'Color',color_xsmm, 'LineStyle',lines_xsmm, ... + 'LineWidth',linesize ); +else +xsmm_ln = line( nan, nan, ... + 'Color',color_xsmm, 'LineStyle',lines_xsmm, ... + 'LineWidth',linesize ); +end +vend_ln = line( x_axis( 1:np, 1 ), data_vend( 1:np, flopscol ) / nth, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... + 'LineWidth',linesize ); +elseif theid == legend_plot_id +blissup_ln = line( nan, nan, ... + 'Color',color_blissup, 'LineStyle',lines_blissup, ... + 'LineWidth',linesize ); +blislpab_ln = line( nan, nan, ... + 'Color',color_blislpab, 'LineStyle',lines_blislpab, ... + 'LineWidth',linesize ); +eigen_ln = line( nan, nan, ... + 'Color',color_eigen, 'LineStyle',lines_eigen, ... + 'LineWidth',linesize ); +open_ln = line( nan, nan, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +bfeo_ln = line( nan, nan, ... + 'Color',color_bfeo, 'LineStyle',lines_bfeo, ... + 'LineWidth',linesize ); +xsmm_ln = line( nan, nan, ... + 'Color',color_xsmm, 'LineStyle',lines_xsmm, ... + 'LineWidth',linesize ); +vend_ln = line( nan, nan, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... + 'LineWidth',linesize ); +end + + +xlim( ax1, [x_begin x_end] ); +ylim( ax1, [y_begin y_end] ); + +if 6000 <= x_end && x_end < 10000 + x_tick2 = x_end - 2000; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 4000 <= x_end && x_end < 6000 + x_tick2 = x_end - 1000; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 2000 <= x_end && x_end < 3000 + x_tick2 = x_end - 400; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 500 <= x_end && x_end < 1000 + x_tick3 = x_end*(3/4); + x_tick2 = x_end*(2/4); + x_tick1 = x_end*(1/4); + xticks( ax1, [ x_tick1 x_tick2 x_tick3 ] ); +end + +if show_plot == 1 || theid == legend_plot_id + if nth == 1 && theid == legend_plot_id + if has_xsmm == 1 + leg = legend( ... + [ ... + blissup_ln ... + blislpab_ln ... + eigen_ln ... + open_ln ... + bfeo_ln ... + xsmm_ln ... + vend_ln ... + ], ... + blissup_legend, ... + blislpab_legend, ... + eigen_legend, ... + open_legend, ... + bfeo_legend, ... + xsmm_legend, ... + vend_legend, ... + 'Location', legend_loc ); + set( leg,'Box','off' ); + set( leg,'Color','none' ); + set( leg,'Units','inches' ); + if impl == 'octave' + set( leg,'FontSize',fontsize ); + set( leg,'Position',[15.40 4.75 1.9 1.20] ); % (1,4tl) + else + set( leg,'FontSize',fontsize-3 ); + set( leg,'Position',[18.20 10.20 1.15 0.7 ] ); % (1,4tl) + end + else + leg = legend( ... + [ ... + blissup_ln ... + blislpab_ln ... + eigen_ln ... + open_ln ... + bfeo_ln ... + vend_ln ... + ], ... + blissup_legend, ... + blislpab_legend, ... + eigen_legend, ... + open_legend, ... + bfeo_legend, ... + vend_legend, ... + 'Location', legend_loc ); + set( leg,'Box','off' ); + set( leg,'Color','none' ); + set( leg,'Units','inches' ); + if impl == 'octave' + set( leg,'FontSize',fontsize ); + set( leg,'Position',[15.40 7.65 1.9 1.10] ); % (1,4tl) + else + set( leg,'FontSize',fontsize-1 ); + set( leg,'Position',[18.24 10.15 1.15 0.7] ); % (1,4tl) + end + end + set( leg,'Box','off' ); + set( leg,'Color','none' ); + set( leg,'Units','inches' ); + % xpos ypos + %set( leg,'Position',[11.32 6.36 1.15 0.7 ] ); % (1,4tl) + elseif nth > 1 && theid == legend_plot_id + end +end + +set( ax1,'FontSize',fontsize ); +set( ax1,'TitleFontSizeMultiplier',1.0 ); % default is 1.1. +box( ax1, 'on' ); + +titl = title( titlename ); +set( titl, 'FontWeight', 'normal' ); % default font style is now 'bold'. + +% The default is to align the plot title across whole figure, not the box. +% This is a hack to nudge the title back to the center of the box. +if impl == 'octave' + tpos = get( titl, 'Position' ); + % For some reason, the titles in the graphs in the last column start + % off in a different relative position than the graphs in the other + % columns. Here, we manually account for that. + if mod(theid-1,cols) == 6 + tpos(1) = tpos(1) + -10; + else + tpos(1) = tpos(1) + -40; + end + set( titl, 'Position', tpos ); + set( titl, 'FontSize', fontsize ); +else % impl == 'matlab' + tpos = get( titl, 'Position' ); + tpos(1) = tpos(1) + 90; + set( titl, 'Position', tpos ); +end + +if theid > (rows-1)*cols + %xlab = xlabel( ax1,xaxisname ); + %tpos = get( xlab, 'Position' ) + %tpos(2) = tpos(2) + 10; + %set( xlab, 'Position', tpos ); + if theid == rows*cols - 6 + xlab = xlabel( ax1, 'm = 6; n = k' ); + elseif theid == rows*cols - 5 + xlab = xlabel( ax1, 'n = 8; m = k' ); + elseif theid == rows*cols - 4 + xlab = xlabel( ax1, 'k = 4; m = n' ); + elseif theid == rows*cols - 3 + xlab = xlabel( ax1, 'm; n = 8, k = 4' ); + elseif theid == rows*cols - 2 + xlab = xlabel( ax1, 'n; m = 6, k = 4' ); + elseif theid == rows*cols - 1 + xlab = xlabel( ax1, 'k; m = 6, n = 8' ); + elseif theid == rows*cols - 0 + xlab = xlabel( ax1, 'm = n = k' ); + end +end + +if mod(theid-1,cols) == 0 + ylab = ylabel( ax1,yaxisname ); +end + +r_val = 0; + diff --git a/test/sup/old/supst/octave/plot_panel_trxsh.m b/test/sup/old/supst/octave/plot_panel_trxsh.m new file mode 100644 index 0000000000..8ba709257b --- /dev/null +++ b/test/sup/old/supst/octave/plot_panel_trxsh.m @@ -0,0 +1,172 @@ +function r_val = plot_panel_trxsh ... + ( ... + cfreq, ... + dflopspercycle, ... + nth, ... + thr_str, ... + dt_ch, ... + stor_str, ... + smalldims, ... + dirpath, ... + arch_str, ... + vend_str, ... + impl ... + ) + +%cfreq = 1.8; +%dflopspercycle = 32; + +% Create filename "templates" for the files that contain the performance +% results. +filetemp_blissup = '%s/output_%s_%s_blissup.m'; +filetemp_blislpab = '%s/output_%s_%s_blislpab.m'; +filetemp_eigen = '%s/output_%s_%s_eigen.m'; +filetemp_open = '%s/output_%s_%s_openblas.m'; +filetemp_bfeo = '%s/output_%s_%s_blasfeo.m'; +filetemp_xsmm = '%s/output_%s_%s_libxsmm.m'; +filetemp_vend = '%s/output_%s_%s_vendor.m'; + +% Create a variable name "template" for the variables contained in the +% files outlined above. +vartemp = 'data_%s_%s_%s( :, : )'; + +% Define the datatypes and operations we will be plotting. +oproot = sprintf( '%cgemm', dt_ch ); +ops( 1, : ) = sprintf( '%s_nn', oproot ); +ops( 2, : ) = sprintf( '%s_nt', oproot ); +ops( 3, : ) = sprintf( '%s_tn', oproot ); +ops( 4, : ) = sprintf( '%s_tt', oproot ); + +% Generate datatype-specific operation names from the set of operations +% and datatypes. +[ opsupnames, opnames ] = gen_opsupnames( ops, stor_str, smalldims ); +n_opsupnames = size( opsupnames, 1 ); + +%opsupnames +%opnames +%return + +if 1 == 1 + %fig = figure('Position', [100, 100, 2400, 1500]); + fig = figure('Position', [100, 100, 2400, 1200]); + orient( fig, 'portrait' ); + set(gcf,'PaperUnits', 'inches'); + if impl == 'matlab' + set(gcf,'PaperSize', [11.5 20.4]); + set(gcf,'PaperPosition', [0 0 11.5 20.4]); + set(gcf,'PaperPositionMode','manual'); + else % impl == 'octave' % octave 4.x + set(gcf,'PaperSize', [12 22.0]); + set(gcf,'PaperPositionMode','auto'); + end + set(gcf,'PaperOrientation','landscape'); +end + + +% Iterate over the list of datatype-specific operation names. +for opi = 1:n_opsupnames +%for opi = 1:1 + + % Grab the current datatype combination. + opsupname = opsupnames( opi, : ); + opname = opnames( opi, : ); + + str = sprintf( 'Plotting %2d: %s', opi, opsupname ); disp(str); + + % Construct filenames for the data files from templates. + file_blissup = sprintf( filetemp_blissup, dirpath, thr_str, opsupname ); + file_blislpab = sprintf( filetemp_blislpab, dirpath, thr_str, opsupname ); + file_eigen = sprintf( filetemp_eigen, dirpath, thr_str, opsupname ); + file_open = sprintf( filetemp_open, dirpath, thr_str, opsupname ); + file_bfeo = sprintf( filetemp_bfeo, dirpath, thr_str, opsupname ); + file_vend = sprintf( filetemp_vend, dirpath, thr_str, opsupname ); + + % Load the data files. + %str = sprintf( ' Loading %s', file_blissup ); disp(str); + run( file_blissup ) + run( file_blislpab ) + run( file_eigen ) + run( file_open ) + run( file_bfeo ) + run( file_vend ) + + % Construct variable names for the variables in the data files. + var_blissup = sprintf( vartemp, thr_str, opname, 'blissup' ); + var_blislpab = sprintf( vartemp, thr_str, opname, 'blislpab' ); + var_eigen = sprintf( vartemp, thr_str, opname, 'eigen' ); + var_open = sprintf( vartemp, thr_str, opname, 'openblas' ); + var_bfeo = sprintf( vartemp, thr_str, opname, 'blasfeo' ); + var_vend = sprintf( vartemp, thr_str, opname, 'vendor' ); + + % Use eval() to instantiate the variable names constructed above, + % copying each to a simplified name. + data_blissup = eval( var_blissup ); % e.g. data_st_dgemm_blissup( :, : ); + data_blislpab = eval( var_blislpab ); % e.g. data_st_dgemm_blislpab( :, : ); + data_eigen = eval( var_eigen ); % e.g. data_st_dgemm_eigen( :, : ); + data_open = eval( var_open ); % e.g. data_st_dgemm_openblas( :, : ); + data_bfeo = eval( var_bfeo ); % e.g. data_st_dgemm_blasfeo( :, : ); + data_vend = eval( var_vend ); % e.g. data_st_dgemm_vendor( :, : ); + + if stor_str == 'ccc' + % Only read xsmm data for the column storage case, since that's the + % only format that libxsmm supports. + file_xsmm = sprintf( filetemp_xsmm, dirpath, thr_str, opsupname ); + run( file_xsmm ) + var_xsmm = sprintf( vartemp, thr_str, opname, 'libxsmm' ); + data_xsmm = eval( var_xsmm ); % e.g. data_st_dgemm_libxsmm( :, : ); + else + % Set the data variable to zeros using the same dimensions as the other + % variables. + data_xsmm = zeros( size( data_blissup, 1 ), ... + size( data_blissup, 2 ) ); + end + %str = sprintf( ' Reading %s', var_blissup ); disp(str); + %str = sprintf( ' Reading %s', var_blislpab ); disp(str); + %str = sprintf( ' Reading %s', var_eigen ); disp(str); + %str = sprintf( ' Reading %s', var_open ); disp(str); + %str = sprintf( ' Reading %s', var_bfeo ); disp(str); + %str = sprintf( ' Reading %s', var_xsmm ); disp(str); + %str = sprintf( ' Reading %s', var_vend ); disp(str); + + % Plot one result in an m x n grid of plots, via the subplot() + % function. + if 1 == 1 + plot_l3sup_perf( opsupname, ... + data_blissup, ... + data_blislpab, ... + data_eigen, ... + data_open, ... + data_bfeo, ... + data_xsmm, ... + data_vend, vend_str, ... + nth, ... + 4, 7, ... + cfreq, ... + dflopspercycle, ... + opi, impl ); + + clear data_st_*gemm_*; + clear data_blissup; + clear data_blislpab; + clear data_eigen; + clear data_open; + clear data_bfeo; + clear data_xsmm; + clear data_vend; + + end + +end + +% Construct the name of the file to which we will output the graph. +outfile = sprintf( 'l3sup_%s_%s_%s_nt%d.pdf', oproot, stor_str, arch_str, nth ); + +% Output the graph to pdf format. +%print(gcf, 'gemm_md','-fillpage','-dpdf'); +%print(gcf, outfile,'-bestfit','-dpdf'); +if impl == 'octave' + print(gcf, outfile); +else % if impl == 'matlab' + print(gcf, outfile,'-bestfit','-dpdf'); +end + diff --git a/test/sup/old/supst/octave/runthese.m b/test/sup/old/supst/octave/runthese.m new file mode 100644 index 0000000000..30e3865d1e --- /dev/null +++ b/test/sup/old/supst/octave/runthese.m @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + +% kabylake +plot_panel_trxsh(3.80,16,1,'st','d','rrr',[ 6 8 4 ],'../results/kabylake/20200302/mnkt100000_st','kbl','MKL','octave'); close; clear all; + +% haswell +plot_panel_trxsh(3.5,16,1,'st','d','rrr',[ 6 8 4 ],'../results/haswell/20200302/mnkt100000_st','has','MKL','octave'); close; clear all; + +% epyc +plot_panel_trxsh(3.00, 8,1,'st','d','rrr',[ 6 8 4 ],'../results/epyc/20200302/mnkt100000_st','epyc','MKL','octave'); close; clear all; diff --git a/test/sup/old/supst/runme.sh b/test/sup/old/supst/runme.sh new file mode 100755 index 0000000000..48dacfa3a6 --- /dev/null +++ b/test/sup/old/supst/runme.sh @@ -0,0 +1,137 @@ +#!/bin/bash + +# File pefixes. +exec_root="test" +out_root="output" + +# Placeholder until we add multithreading. +nt=1 + +# Delay between test cases. +delay=0.02 + +# Threadedness to test. +threads="st" + +# Datatypes to test. +#dts="d s" +dts="d" + +# Operations to test. +ops="gemm" + +# Transpose combintions to test. +trans="nn nt tn tt" + +# Storage combinations to test. +#stors="rrr rrc rcr rcc crr crc ccr ccc" +stors="rrr ccc" + +# Problem shapes to test. +shapes="sll lsl lls lss sls ssl lll" + +# FGVZ: figure out how to probe what's in the directory and +# execute everything that's there? +sms="6" +sns="8" +sks="4" + +# Implementations to test. +impls="vendor blissup blislpab openblas eigen libxsmm blasfeo" +#impls="vendor" +#impls="blissup" +#impls="blislpab" +#impls="openblas" +#impls="eigen" +#impls="libxsmm" +#impls="blasfeo" + +# Example: test_dgemm_nn_rrc_m6npkp_blissup_st.x + +for th in ${threads}; do + + for dt in ${dts}; do + + for op in ${ops}; do + + for tr in ${trans}; do + + for st in ${stors}; do + + for sh in ${shapes}; do + + for sm in ${sms}; do + + for sn in ${sns}; do + + for sk in ${sks}; do + + for im in ${impls}; do + + # Limit execution of non-BLIS implementations to + # rrr/ccc storage cases. + if [ "${im:0:4}" != "blis" ] && \ + [ "${st}" != "rrr" ] && \ + [ "${st}" != "ccc" ]; then + continue; + fi + + # Further limit execution of libxsmm to + # ccc storage cases. + if [ "${im:0:7}" = "libxsmm" ] && \ + [ "${st}" != "ccc" ]; then + continue; + fi + + # Extract the shape chars for m, n, k. + chm=${sh:0:1} + chn=${sh:1:1} + chk=${sh:2:1} + + # Construct the shape substring (e.g. m6npkp) + shstr="" + + if [ ${chm} = "s" ]; then + shstr="${shstr}m${sm}" + else + shstr="${shstr}mp" + fi + + if [ ${chn} = "s" ]; then + shstr="${shstr}n${sn}" + else + shstr="${shstr}np" + fi + + if [ ${chk} = "s" ]; then + shstr="${shstr}k${sk}" + else + shstr="${shstr}kp" + fi + + # Ex: test_dgemm_nn_rrc_m6npkp_blissup_st.x + + # Construct the name of the test executable. + exec_name="${exec_root}_${dt}${op}_${tr}_${st}_${shstr}_${im}_${th}.x" + + # Construct the name of the output file. + out_file="${out_root}_${th}_${dt}${op}_${tr}_${st}_${shstr}_${im}.m" + + echo "Running (nt = ${nt}) ./${exec_name} > ${out_file}" + + # Run executable. + ./${exec_name} > ${out_file} + + sleep ${delay} + + done + done + done + done + done + done + done + done + done +done + diff --git a/test/sup/old/supst/test_gemm.c b/test/sup/old/supst/test_gemm.c new file mode 100644 index 0000000000..7f611b554d --- /dev/null +++ b/test/sup/old/supst/test_gemm.c @@ -0,0 +1,583 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#ifdef EIGEN + #define BLIS_DISABLE_BLAS_DEFS + #include "blis.h" + #include + //#include + using namespace Eigen; +#else + #include "blis.h" +#endif + +//#define PRINT + +int main( int argc, char** argv ) +{ + + bli_init(); + +#ifndef ERROR_CHECK + bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); +#endif + + + dim_t n_trials = N_TRIALS; + + num_t dt = DT; + +#if 1 + dim_t p_begin = P_BEGIN; + dim_t p_max = P_MAX; + dim_t p_inc = P_INC; +#else + dim_t p_begin = 4; + dim_t p_max = 40; + dim_t p_inc = 4; +#endif + +#if 1 + dim_t m_input = M_DIM; + dim_t n_input = N_DIM; + dim_t k_input = K_DIM; +#else + p_begin = p_inc = 32; + dim_t m_input = 6; + dim_t n_input = -1; + dim_t k_input = -1; +#endif + +#if 1 + trans_t transa = TRANSA; + trans_t transb = TRANSB; +#else + trans_t transa = BLIS_NO_TRANSPOSE; + trans_t transb = BLIS_NO_TRANSPOSE; +#endif + +#if 1 + stor3_t sc = STOR3; +#else + stor3_t sc = BLIS_RRR; +#endif + + + inc_t rs_c, cs_c; + inc_t rs_a, cs_a; + inc_t rs_b, cs_b; + + if ( sc == BLIS_RRR ) { rs_c = cs_c = -1; rs_a = cs_a = -1; rs_b = cs_b = -1; } + else if ( sc == BLIS_RRC ) { rs_c = cs_c = -1; rs_a = cs_a = -1; rs_b = cs_b = 0; } + else if ( sc == BLIS_RCR ) { rs_c = cs_c = -1; rs_a = cs_a = 0; rs_b = cs_b = -1; } + else if ( sc == BLIS_RCC ) { rs_c = cs_c = -1; rs_a = cs_a = 0; rs_b = cs_b = 0; } + else if ( sc == BLIS_CRR ) { rs_c = cs_c = 0; rs_a = cs_a = -1; rs_b = cs_b = -1; } + else if ( sc == BLIS_CRC ) { rs_c = cs_c = 0; rs_a = cs_a = -1; rs_b = cs_b = 0; } + else if ( sc == BLIS_CCR ) { rs_c = cs_c = 0; rs_a = cs_a = 0; rs_b = cs_b = -1; } + else if ( sc == BLIS_CCC ) { rs_c = cs_c = 0; rs_a = cs_a = 0; rs_b = cs_b = 0; } + else { bli_abort(); } + + f77_int cbla_storage; + + if ( sc == BLIS_RRR ) cbla_storage = CblasRowMajor; + else if ( sc == BLIS_CCC ) cbla_storage = CblasColMajor; + else cbla_storage = -1; + + ( void )cbla_storage; + + + char dt_ch; + + // Choose the char corresponding to the requested datatype. + if ( bli_is_float( dt ) ) dt_ch = 's'; + else if ( bli_is_double( dt ) ) dt_ch = 'd'; + else if ( bli_is_scomplex( dt ) ) dt_ch = 'c'; + else dt_ch = 'z'; + + f77_char f77_transa; + f77_char f77_transb; + char transal, transbl; + + bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + + transal = tolower( f77_transa ); + transbl = tolower( f77_transb ); + + f77_int cbla_transa = ( transal == 'n' ? CblasNoTrans : CblasTrans ); + f77_int cbla_transb = ( transbl == 'n' ? CblasNoTrans : CblasTrans ); + + ( void )cbla_transa; + ( void )cbla_transb; + + dim_t p; + + // Begin with initializing the last entry to zero so that + // matlab allocates space for the entire array once up-front. + for ( p = p_begin; p + p_inc <= p_max; p += p_inc ) ; + + printf( "data_%s_%cgemm_%c%c_%s", THR_STR, dt_ch, + transal, transbl, STR ); + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )0, + ( unsigned long )0, + ( unsigned long )0, 0.0 ); + + + //for ( p = p_begin; p <= p_max; p += p_inc ) + for ( p = p_max; p_begin <= p; p -= p_inc ) + { + obj_t a, b, c; + obj_t c_save; + obj_t alpha, beta; + dim_t m, n, k; + + if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); + else m = ( dim_t ) m_input; + if ( n_input < 0 ) n = p / ( dim_t )abs(n_input); + else n = ( dim_t ) n_input; + if ( k_input < 0 ) k = p / ( dim_t )abs(k_input); + else k = ( dim_t ) k_input; + + bli_obj_create( dt, 1, 1, 0, 0, &alpha ); + bli_obj_create( dt, 1, 1, 0, 0, &beta ); + + bli_obj_create( dt, m, n, rs_c, cs_c, &c ); + bli_obj_create( dt, m, n, rs_c, cs_c, &c_save ); + + if ( bli_does_notrans( transa ) ) + bli_obj_create( dt, m, k, rs_a, cs_a, &a ); + else + bli_obj_create( dt, k, m, rs_a, cs_a, &a ); + + if ( bli_does_notrans( transb ) ) + bli_obj_create( dt, k, n, rs_b, cs_b, &b ); + else + bli_obj_create( dt, n, k, rs_b, cs_b, &b ); + + bli_randm( &a ); + bli_randm( &b ); + bli_randm( &c ); + + bli_obj_set_conjtrans( transa, &a ); + bli_obj_set_conjtrans( transb, &b ); + + bli_setsc( (1.0/1.0), 0.0, &alpha ); + bli_setsc( (1.0/1.0), 0.0, &beta ); + + bli_copym( &c, &c_save ); + +#ifdef EIGEN + double alpha_r, alpha_i; + + bli_getsc( &alpha, &alpha_r, &alpha_i ); + + void* ap = bli_obj_buffer_at_off( &a ); + void* bp = bli_obj_buffer_at_off( &b ); + void* cp = bli_obj_buffer_at_off( &c ); + + const int os_a = ( bli_obj_is_col_stored( &a ) ? bli_obj_col_stride( &a ) + : bli_obj_row_stride( &a ) ); + const int os_b = ( bli_obj_is_col_stored( &b ) ? bli_obj_col_stride( &b ) + : bli_obj_row_stride( &b ) ); + const int os_c = ( bli_obj_is_col_stored( &c ) ? bli_obj_col_stride( &c ) + : bli_obj_row_stride( &c ) ); + + Stride stride_a( os_a, 1 ); + Stride stride_b( os_b, 1 ); + Stride stride_c( os_c, 1 ); + + #if defined(IS_FLOAT) + #elif defined (IS_DOUBLE) + #ifdef A_STOR_R + typedef Matrix MatrixXd_A; + #else + typedef Matrix MatrixXd_A; + #endif + #ifdef B_STOR_R + typedef Matrix MatrixXd_B; + #else + typedef Matrix MatrixXd_B; + #endif + #ifdef C_STOR_R + typedef Matrix MatrixXd_C; + #else + typedef Matrix MatrixXd_C; + #endif + + #ifdef A_NOTRANS // A is not transposed + Map > A( ( double* )ap, m, k, stride_a ); + #else // A is transposed + Map > A( ( double* )ap, k, m, stride_a ); + #endif + + #ifdef B_NOTRANS // B is not transposed + Map > B( ( double* )bp, k, n, stride_b ); + #else // B is transposed + Map > B( ( double* )bp, n, k, stride_b ); + #endif + + Map > C( ( double* )cp, m, n, stride_c ); + #endif +#endif + + + double dtime_save = DBL_MAX; + + for ( dim_t r = 0; r < n_trials; ++r ) + { + bli_copym( &c_save, &c ); + + + double dtime = bli_clock(); + + +#ifdef EIGEN + + #ifdef A_NOTRANS + #ifdef B_NOTRANS + C.noalias() += alpha_r * A * B; + #else // B_TRANS + C.noalias() += alpha_r * A * B.transpose(); + #endif + #else // A_TRANS + #ifdef B_NOTRANS + C.noalias() += alpha_r * A.transpose() * B; + #else // B_TRANS + C.noalias() += alpha_r * A.transpose() * B.transpose(); + #endif + #endif + +#endif +#ifdef BLIS + #ifdef SUP + // Allow sup. + bli_gemm( &alpha, + &a, + &b, + &beta, + &c ); + #else + // Disable sup and use the expert interface. + rntm_t rntm = BLIS_RNTM_INITIALIZER; + bli_rntm_disable_l3_sup( &rntm ); + + bli_gemm_ex( &alpha, + &a, + &b, + &beta, + &c, NULL, &rntm ); + #endif +#endif +#ifdef BLAS + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* bp = ( float* )bli_obj_buffer( &b ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); + + #ifdef XSMM + libxsmm_sgemm( &f77_transa, + #else + sgemm_( &f77_transa, + #endif + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* bp = ( double* )bli_obj_buffer( &b ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); + + #ifdef XSMM + libxsmm_dgemm( &f77_transa, + #else + dgemm_( &f77_transa, + #endif + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* bp = ( scomplex* )bli_obj_buffer( &b ); + scomplex* betap = ( scomplex* )bli_obj_buffer( &beta ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); + + #ifdef XSMM + libxsmm_cgemm( &f77_transa, + #else + cgemm_( &f77_transa, + #endif + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* bp = ( dcomplex* )bli_obj_buffer( &b ); + dcomplex* betap = ( dcomplex* )bli_obj_buffer( &beta ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); + + #ifdef XSMM + libxsmm_zgemm( &f77_transa, + #else + zgemm_( &f77_transa, + #endif + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } +#endif +#ifdef CBLAS + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* bp = bli_obj_buffer( &b ); + float* betap = bli_obj_buffer( &beta ); + float* cp = bli_obj_buffer( &c ); + + cblas_sgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + double* alphap = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + double* bp = bli_obj_buffer( &b ); + double* betap = bli_obj_buffer( &beta ); + double* cp = bli_obj_buffer( &c ); + + cblas_dgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + scomplex* bp = bli_obj_buffer( &b ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* cp = bli_obj_buffer( &c ); + + cblas_cgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* bp = bli_obj_buffer( &b ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* cp = bli_obj_buffer( &c ); + + cblas_zgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc ); + } +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + double gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 4.0; + + printf( "data_%s_%cgemm_%c%c_%s", THR_STR, dt_ch, + transal, transbl, STR ); + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )m, + ( unsigned long )n, + ( unsigned long )k, gflops ); + + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + bli_obj_free( &c_save ); + } + + //bli_finalize(); + + return 0; +} + diff --git a/test/sup/runme.sh b/test/sup/runme.sh new file mode 100755 index 0000000000..d8fd07dcd3 --- /dev/null +++ b/test/sup/runme.sh @@ -0,0 +1,349 @@ +#!/bin/bash + +# File pefixes. +exec_root="test" +out_root="output" + +#sys="blis" +#sys="lonestar5" +#sys="ul252" +#sys="ul264" +sys="ul2128" + +if [ ${sys} = "blis" ]; then + + export GOMP_CPU_AFFINITY="0-3" + + numactl="" + nt=4 + +elif [ ${sys} = "lonestar5" ]; then + + export GOMP_CPU_AFFINITY="0-23" + + numactl="" + nt=12 + +elif [ ${sys} = "ul252" ]; then + + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/field/intel/mkl/lib/intel64" + export GOMP_CPU_AFFINITY="0-51" + + numactl="numactl --interleave=all" + nt=26 + +elif [ ${sys} = "ul264" ]; then + + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/field/intel/mkl/lib/intel64" + export GOMP_CPU_AFFINITY="0-63" + + numactl="numactl --interleave=all" + nt=32 + +elif [ ${sys} = "ul2128" ]; then + + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/field/intel/mkl/lib/intel64" + export GOMP_CPU_AFFINITY="0-127" + + numactl="numactl --interleave=all" + nt=32 + +fi + +# Delay between test cases. +delay=0.02 + +# Threadedness to test. +#threads="st mt" +threads="st" + +# Datatypes to test. +dts="s d" + +# Operations to test. +ops="gemm" + +# Transpose combintions to test. +trans="nn nt tn tt" + +# Storage combinations to test. +# NOTE: mixed storage cases are not yet implemented in test_gemm.c. +#stors="rrr rrc rcr rcc crr crc ccr ccc" +stors="rrr ccc" + +# Problem shapes to test. +shapes="sll lsl lls lss sls ssl lll" + +# Small problem dimensions to use. +# FGVZ: figure out how to probe what's in the directory and +# execute everything that's there? +# st, single real +sms_st_s="6" +sns_st_s="16" +sks_st_s="4" +# st, double real +sms_st_d="6" +sns_st_d="8" +sks_st_d="4" +# mt, single real +sms_mt_s="6" +sns_mt_s="16" +sks_mt_s="10" +# mt, double real +sms_mt_d="6" +sns_mt_d="8" +sks_mt_d="10" + +# Leading dimensions to use (small or large). +# When a leading dimension is large, it is constant and set to the largest +# problem size that will be run. +#ldims="s l" +ldims="s" + +# Packing combinations for blissup. The first char encodes the packing status +# of matrix A and the second char encodes the packing status of matrix B. +# NOTE: This string must always contain 'uu' if other implementations are also +# being tested at the same time. +#pcombos="uu up pu pp" +pcombos="uu" + +# Implementations to test. +impls="vendor blissup blisconv openblas eigen blasfeo libxsmm" +#impls="vendor blissup blisconv openblas eigen" +#impls="vendor" +#impls="blissup" +#impls="blisconv" +#impls="openblas" +#impls="eigen" +#impls="blasfeo" + +# Save a copy of GOMP_CPU_AFFINITY so that if we have to unset it, we can +# restore the value. +GOMP_CPU_AFFINITYsave=${GOMP_CPU_AFFINITY} + +# Example: test_dgemm_nn_rrc_m6npkp_blissup_st.x + +for th in ${threads}; do + + for dt in ${dts}; do + + # Choose the small m, n, and k values based on the threadedness and + # datatype currently being executed. + if [ ${th} = "st" ]; then + if [ ${dt} = "s" ]; then + sms=${sms_st_s} + sns=${sns_st_s} + sks=${sks_st_s} + elif [ ${dt} = "d" ]; then + sms=${sms_st_d} + sns=${sns_st_d} + sks=${sks_st_d} + else + exit 1 + fi + elif [ ${th} = "mt" ]; then + if [ ${dt} = "s" ]; then + sms=${sms_mt_s} + sns=${sns_mt_s} + sks=${sks_mt_s} + elif [ ${dt} = "d" ]; then + sms=${sms_mt_d} + sns=${sns_mt_d} + sks=${sks_mt_d} + else + exit 1 + fi + fi + + for op in ${ops}; do + + for tr in ${trans}; do + + for st in ${stors}; do + + for sh in ${shapes}; do + + for sm in ${sms}; do + + for sn in ${sns}; do + + for sk in ${sks}; do + + for ld in ${ldims}; do + + for im in ${impls}; do + + for pc in ${pcombos}; do + + if [ "${th}" = "mt" ]; then + + # Prohibit attempts to run blasfeo or libxsmm as + # multithreaded. + if [ "${im}" = "blasfeo" ] || \ + [ "${im}" = "libxsmm" ]; then + continue; + fi + + # Specify the multithreading depending on which + # implementation is about to be tested. + if [ "${im:0:4}" = "blis" ]; then + unset OMP_NUM_THREADS + export BLIS_NUM_THREADS=${nt} + elif [ "${im}" = "openblas" ]; then + unset OMP_NUM_THREADS + export OPENBLAS_NUM_THREADS=${nt} + elif [ "${im}" = "eigen" ]; then + export OMP_NUM_THREADS=${nt} + elif [ "${im}" = "vendor" ]; then + unset OMP_NUM_THREADS + export MKL_NUM_THREADS=${nt} + fi + export nt_use=${nt} + + else # if [ "${th}" = "st" ]; + + # Use single-threaded execution. + export OMP_NUM_THREADS=1 + export BLIS_NUM_THREADS=1 + export OPENBLAS_NUM_THREADS=1 + export MKL_NUM_THREADS=1 + export nt_use=1 + fi + + # Isolate the individual chars in the current pcombo + # string. + packa=${pc:0:1} + packb=${pc:1:1} + + # For blissup implementations, set the BLIS_PACK_A and + # BLIS_PACK_B environment variables according to the + # chars in the current pcombo string. + if [ "${im:0:7}" = "blissup" ]; then + + # Set BLIS_PACK_A if the pcombo char is 'p'; otherwise + # unset the variable altogether. + if [ ${packa} = "p" ]; then + export BLIS_PACK_A=1 + else + unset BLIS_PACK_A + fi + + # Set BLIS_PACK_B if the pcombo char is 'p'; otherwise + # unset the variable altogether. + if [ ${packb} = "p" ]; then + export BLIS_PACK_B=1 + else + unset BLIS_PACK_B + fi + else + + # Unset the variables for non-blissup implementations, + # just to be paranoid-safe. + unset BLIS_PACK_A + unset BLIS_PACK_B + fi + + # Limit execution of non-blissup implementations to the + # 'uu' packing combination. (Those implementations don't + # use the pcombos string, but since we iterate over its + # words for all implementations, we have to designate one + # of them as a placeholder to allow those implementations + # to execute. The 'uu' string was chosen over the 'pp' + # string because it's more likely that this script will be + # used to run blissup on unpacked matrices, and so the + # sorting for the output files is nicer if the non-blissup + # implementations use the 'uu' string, even if it's more + # likely that those implementations use packing. Think of + # 'uu' as encoding the idea that explicit packing was not + # requested.) + if [ "${im:0:7}" != "blissup" ] && \ + [ "${pc}" != "uu" ]; then + continue; + fi + + # Multithreaded OpenBLAS seems to have a problem + # running properly if GOMP_CPU_AFFINITY is set. + # So we temporarily unset it here if we are about + # to execute OpenBLAS, but otherwise restore it. + if [ ${im} = "openblas" ]; then + unset GOMP_CPU_AFFINITY + else + export GOMP_CPU_AFFINITY="${GOMP_CPU_AFFINITYsave}" + fi + + # Limit execution of non-BLIS implementations to + # rrr/ccc storage cases. + if [ "${im:0:4}" != "blis" ] && \ + [ "${st}" != "rrr" ] && \ + [ "${st}" != "ccc" ]; then + continue; + fi + + # Further limit execution of libxsmm to + # ccc storage cases. + if [ "${im:0:7}" = "libxsmm" ] && \ + [ "${st}" != "ccc" ]; then + continue; + fi + + # Extract the shape chars for m, n, k. + chm=${sh:0:1} + chn=${sh:1:1} + chk=${sh:2:1} + + # Construct the shape substring (e.g. m6npkp) + shstr="" + + if [ ${chm} = "s" ]; then + shstr="${shstr}m${sm}" + else + shstr="${shstr}mp" + fi + + if [ ${chn} = "s" ]; then + shstr="${shstr}n${sn}" + else + shstr="${shstr}np" + fi + + if [ ${chk} = "s" ]; then + shstr="${shstr}k${sk}" + else + shstr="${shstr}kp" + fi + + # Construct the ldim substring (e.g. lds or ldl) + ldstr="ld${ld}" + + # Construct the pack substring (e.g. uaub, uapb, paub, or papb) + packstr="${packa}a${packb}b" + + # Ex: test_dgemm_nn_rrc_m6npkp_blissup_st.x + # Ex: test_dgemm_nt_rrr_m6npkp_ldl_blissup_st.x + + # Construct the name of the test executable. + exec_name="${exec_root}_${dt}${op}_${tr}_${st}_${shstr}_${ldstr}_${im}_${th}.x" + + # Construct the name of the output file. + out_file="${out_root}_${th}_${dt}${op}_${tr}_${st}_${shstr}_${ldstr}_${packstr}_${im}.m" + + echo "Running (nt = ${nt_use}) ${numactl} ./${exec_name} > ${out_file}" + + # Run executable. + ${numactl} ./${exec_name} > ${out_file} + + sleep ${delay} + + done + done + done + done + done + done + done + done + done + done + done +done + diff --git a/test/sup/test_gemm.c b/test/sup/test_gemm.c new file mode 100644 index 0000000000..d191957907 --- /dev/null +++ b/test/sup/test_gemm.c @@ -0,0 +1,680 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#ifdef EIGEN + #define BLIS_DISABLE_BLAS_DEFS + #include "blis.h" + #include + //#include + using namespace Eigen; +#else + #include "blis.h" +#endif + +//#define PRINT + +int main( int argc, char** argv ) +{ + rntm_t rntm_g; + + bli_init(); + + // Copy the global rntm_t object so that we can use it later when disabling + // sup. Starting with a copy of the global rntm_t is actually necessary; + // if we start off with a locally-initialized rntm_t, it will not contain + // the ways of parallelism that were conveyed via environment variables, + // which is necessary when running this driver with multiple BLIS threads. + bli_rntm_init_from_global( &rntm_g ); + +#ifndef ERROR_CHECK + bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); +#endif + + + dim_t n_trials = N_TRIALS; + + num_t dt = DT; + +#if 1 + dim_t p_begin = P_BEGIN; + dim_t p_max = P_MAX; + dim_t p_inc = P_INC; +#else + dim_t p_begin = 4; + dim_t p_max = 40; + dim_t p_inc = 4; +#endif + +#if 1 + dim_t m_input = M_DIM; + dim_t n_input = N_DIM; + dim_t k_input = K_DIM; +#else + p_begin = p_inc = 32; + dim_t m_input = 6; + dim_t n_input = -1; + dim_t k_input = -1; +#endif + +#if 1 + trans_t transa = TRANSA; + trans_t transb = TRANSB; +#else + trans_t transa = BLIS_NO_TRANSPOSE; + trans_t transb = BLIS_NO_TRANSPOSE; +#endif + +#if 1 + stor3_t sc = STOR3; +#else + stor3_t sc = BLIS_RRR; +#endif + + + inc_t rs_c, cs_c; + inc_t rs_a, cs_a; + inc_t rs_b, cs_b; + + f77_int cbla_storage; + + if ( sc == BLIS_RRR ) cbla_storage = CblasRowMajor; + else if ( sc == BLIS_CCC ) cbla_storage = CblasColMajor; + else cbla_storage = -1; + + ( void )cbla_storage; + + + char dt_ch; + + // Choose the char corresponding to the requested datatype. + if ( bli_is_float( dt ) ) dt_ch = 's'; + else if ( bli_is_double( dt ) ) dt_ch = 'd'; + else if ( bli_is_scomplex( dt ) ) dt_ch = 'c'; + else dt_ch = 'z'; + + f77_char f77_transa; + f77_char f77_transb; + char transal, transbl; + + bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + + transal = tolower( f77_transa ); + transbl = tolower( f77_transb ); + + f77_int cbla_transa = ( transal == 'n' ? CblasNoTrans : CblasTrans ); + f77_int cbla_transb = ( transbl == 'n' ? CblasNoTrans : CblasTrans ); + + ( void )cbla_transa; + ( void )cbla_transb; + + dim_t p; + + // Begin with initializing the last entry to zero so that + // matlab allocates space for the entire array once up-front. + for ( p = p_begin; p + p_inc <= p_max; p += p_inc ) ; + + printf( "data_%s_%cgemm_%c%c_%s", THR_STR, dt_ch, + transal, transbl, STR ); + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )0, + ( unsigned long )0, + ( unsigned long )0, 0.0 ); + + // Flush the initialization-inducing statement above so there is + // at least one line of output even if the implementation crashes. + fflush( stdout ); + + + //for ( p = p_begin; p <= p_max; p += p_inc ) + for ( p = p_max; p_begin <= p; p -= p_inc ) + { + obj_t a, b, c; + obj_t c_save; + obj_t alpha, beta; + dim_t m, n, k; + + if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); + else m = ( dim_t ) m_input; + if ( n_input < 0 ) n = p / ( dim_t )abs(n_input); + else n = ( dim_t ) n_input; + if ( k_input < 0 ) k = p / ( dim_t )abs(k_input); + else k = ( dim_t ) k_input; + +#ifdef LDIM_SMALL + + // Setting the row and column strides of a matrix to '0' is a shorthand + // request for column storage; using '-1' is shorthand for row storage. + //else if ( sc == BLIS_RRC ) { rs_c = cs_c = -1; rs_a = cs_a = -1; rs_b = cs_b = 0; } + //else if ( sc == BLIS_RCR ) { rs_c = cs_c = -1; rs_a = cs_a = 0; rs_b = cs_b = -1; } + //else if ( sc == BLIS_RCC ) { rs_c = cs_c = -1; rs_a = cs_a = 0; rs_b = cs_b = 0; } + //else if ( sc == BLIS_CRR ) { rs_c = cs_c = 0; rs_a = cs_a = -1; rs_b = cs_b = -1; } + //else if ( sc == BLIS_CRC ) { rs_c = cs_c = 0; rs_a = cs_a = -1; rs_b = cs_b = 0; } + //else if ( sc == BLIS_CCR ) { rs_c = cs_c = 0; rs_a = cs_a = 0; rs_b = cs_b = -1; } + + if ( sc == BLIS_RRR ) { rs_c = -1; rs_a = -1; rs_b = -1; + cs_c = -1; cs_a = -1; cs_b = -1; } + else if ( sc == BLIS_RRC ) { rs_c = -1; rs_a = -1; rs_b = 0; + cs_c = -1; cs_a = -1; cs_b = 0; } + else if ( sc == BLIS_RCR ) { rs_c = -1; rs_a = 0; rs_b = -1; + cs_c = -1; cs_a = 0; cs_b = -1; } + else if ( sc == BLIS_RCC ) { rs_c = -1; rs_a = 0; rs_b = 0; + cs_c = -1; cs_a = 0; cs_b = 0; } + else if ( sc == BLIS_CRR ) { rs_c = 0; rs_a = -1; rs_b = -1; + cs_c = 0; cs_a = -1; cs_b = -1; } + else if ( sc == BLIS_CRC ) { rs_c = 0; rs_a = -1; rs_b = 0; + cs_c = 0; cs_a = -1; cs_b = 0; } + else if ( sc == BLIS_CCR ) { rs_c = 0; rs_a = 0; rs_b = -1; + cs_c = 0; cs_a = 0; cs_b = -1; } + else if ( sc == BLIS_CCC ) { rs_c = 0; rs_a = 0; rs_b = 0; + cs_c = 0; cs_a = 0; cs_b = 0; } + else { bli_abort(); } + +#else // LDIM_LARGE + + #if 0 + const dim_t m_large = m; + const dim_t n_large = n; + const dim_t k_large = k; + #else + const dim_t m_large = p_max; + const dim_t n_large = p_max; + const dim_t k_large = p_max; + #endif + + if ( sc == BLIS_RRR ) { rs_c = n_large; rs_a = k_large; rs_b = n_large; + cs_c = 1; cs_a = 1; cs_b = 1; } + else if ( sc == BLIS_RRC ) { rs_c = n_large; rs_a = k_large; rs_b = 1; + cs_c = 1; cs_a = 1; cs_b = k_large; } + else if ( sc == BLIS_RCR ) { rs_c = n_large; rs_a = 1; rs_b = n_large; + cs_c = 1; cs_a = m_large; cs_b = 1; } + else if ( sc == BLIS_RCC ) { rs_c = n_large; rs_a = 1; rs_b = 1; + cs_c = 1; cs_a = m_large; cs_b = k_large; } + else if ( sc == BLIS_CRR ) { rs_c = 1; rs_a = k_large; rs_b = n_large; + cs_c = m_large; cs_a = 1; cs_b = 1; } + else if ( sc == BLIS_CRC ) { rs_c = 1; rs_a = k_large; rs_b = 1; + cs_c = m_large; cs_a = 1; cs_b = k_large; } + else if ( sc == BLIS_CCR ) { rs_c = 1; rs_a = 1; rs_b = n_large; + cs_c = m_large; cs_a = m_large; cs_b = 1; } + else if ( sc == BLIS_CCC ) { rs_c = 1; rs_a = 1; rs_b = 1; + cs_c = m_large; cs_a = m_large; cs_b = k_large; } + else { bli_abort(); } +#endif + + bli_obj_create( dt, 1, 1, 0, 0, &alpha ); + bli_obj_create( dt, 1, 1, 0, 0, &beta ); + + bli_obj_create( dt, m, n, rs_c, cs_c, &c ); + bli_obj_create( dt, m, n, rs_c, cs_c, &c_save ); + + if ( bli_does_notrans( transa ) ) + bli_obj_create( dt, m, k, rs_a, cs_a, &a ); + else + bli_obj_create( dt, k, m, cs_a, rs_a, &a ); + + if ( bli_does_notrans( transb ) ) + bli_obj_create( dt, k, n, rs_b, cs_b, &b ); + else + bli_obj_create( dt, n, k, cs_b, rs_b, &b ); + + bli_randm( &a ); + bli_randm( &b ); + bli_randm( &c ); + + bli_obj_set_conjtrans( transa, &a ); + bli_obj_set_conjtrans( transb, &b ); + + bli_setsc( (1.0/1.0), 0.0, &alpha ); + bli_setsc( (1.0/1.0), 0.0, &beta ); + + bli_copym( &c, &c_save ); + +#ifdef EIGEN + double alpha_r, alpha_i; + + bli_getsc( &alpha, &alpha_r, &alpha_i ); + + void* ap = bli_obj_buffer_at_off( &a ); + void* bp = bli_obj_buffer_at_off( &b ); + void* cp = bli_obj_buffer_at_off( &c ); + + const int os_a = ( bli_obj_is_col_stored( &a ) ? bli_obj_col_stride( &a ) + : bli_obj_row_stride( &a ) ); + const int os_b = ( bli_obj_is_col_stored( &b ) ? bli_obj_col_stride( &b ) + : bli_obj_row_stride( &b ) ); + const int os_c = ( bli_obj_is_col_stored( &c ) ? bli_obj_col_stride( &c ) + : bli_obj_row_stride( &c ) ); + + Stride stride_a( os_a, 1 ); + Stride stride_b( os_b, 1 ); + Stride stride_c( os_c, 1 ); + + #if defined(IS_FLOAT) + #ifdef A_STOR_R + typedef Matrix MatrixXs_A; + #else + typedef Matrix MatrixXs_A; + #endif + #ifdef B_STOR_R + typedef Matrix MatrixXs_B; + #else + typedef Matrix MatrixXs_B; + #endif + #ifdef C_STOR_R + typedef Matrix MatrixXs_C; + #else + typedef Matrix MatrixXs_C; + #endif + + #ifdef A_NOTRANS // A is not transposed + Map > A( ( float* )ap, m, k, stride_a ); + #else // A is transposed + Map > A( ( float* )ap, k, m, stride_a ); + #endif + + #ifdef B_NOTRANS // B is not transposed + Map > B( ( float* )bp, k, n, stride_b ); + #else // B is transposed + Map > B( ( float* )bp, n, k, stride_b ); + #endif + + Map > C( ( float* )cp, m, n, stride_c ); + #elif defined (IS_DOUBLE) + #ifdef A_STOR_R + typedef Matrix MatrixXd_A; + #else + typedef Matrix MatrixXd_A; + #endif + #ifdef B_STOR_R + typedef Matrix MatrixXd_B; + #else + typedef Matrix MatrixXd_B; + #endif + #ifdef C_STOR_R + typedef Matrix MatrixXd_C; + #else + typedef Matrix MatrixXd_C; + #endif + + #ifdef A_NOTRANS // A is not transposed + Map > A( ( double* )ap, m, k, stride_a ); + #else // A is transposed + Map > A( ( double* )ap, k, m, stride_a ); + #endif + + #ifdef B_NOTRANS // B is not transposed + Map > B( ( double* )bp, k, n, stride_b ); + #else // B is transposed + Map > B( ( double* )bp, n, k, stride_b ); + #endif + + Map > C( ( double* )cp, m, n, stride_c ); + #endif +#endif + + + double dtime_save = DBL_MAX; + + for ( dim_t r = 0; r < n_trials; ++r ) + { + bli_copym( &c_save, &c ); + + + double dtime = bli_clock(); + + +#ifdef EIGEN + + #ifdef A_NOTRANS + #ifdef B_NOTRANS + C.noalias() += alpha_r * A * B; + #else // B_TRANS + C.noalias() += alpha_r * A * B.transpose(); + #endif + #else // A_TRANS + #ifdef B_NOTRANS + C.noalias() += alpha_r * A.transpose() * B; + #else // B_TRANS + C.noalias() += alpha_r * A.transpose() * B.transpose(); + #endif + #endif + +#endif +#ifdef BLIS + #ifdef SUP + // Allow sup. + bli_gemm( &alpha, + &a, + &b, + &beta, + &c ); + #else + // NOTE: We can't use the static initializer and must instead + // initialize the rntm_t with the copy from the global rntm_t we + // made at the beginning of main(). Please see the comment there + // for more info on why BLIS_RNTM_INITIALIZER doesn't work here. + //rntm_t rntm = BLIS_RNTM_INITIALIZER; + rntm_t rntm = rntm_g; + + // Disable sup and use the expert interface. + bli_rntm_disable_l3_sup( &rntm ); + + bli_gemm_ex( &alpha, + &a, + &b, + &beta, + &c, NULL, &rntm ); + #endif +#endif +#ifdef BLAS + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* bp = ( float* )bli_obj_buffer( &b ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); + + #ifdef XSMM + libxsmm_sgemm( &f77_transa, + #else + sgemm_( &f77_transa, + #endif + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* bp = ( double* )bli_obj_buffer( &b ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); + + #ifdef XSMM + libxsmm_dgemm( &f77_transa, + #else + dgemm_( &f77_transa, + #endif + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* bp = ( scomplex* )bli_obj_buffer( &b ); + scomplex* betap = ( scomplex* )bli_obj_buffer( &beta ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); + + #ifdef XSMM + libxsmm_cgemm( &f77_transa, + #else + cgemm_( &f77_transa, + #endif + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* bp = ( dcomplex* )bli_obj_buffer( &b ); + dcomplex* betap = ( dcomplex* )bli_obj_buffer( &beta ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); + + #ifdef XSMM + libxsmm_zgemm( &f77_transa, + #else + zgemm_( &f77_transa, + #endif + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } +#endif +#ifdef CBLAS + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* bp = bli_obj_buffer( &b ); + float* betap = bli_obj_buffer( &beta ); + float* cp = bli_obj_buffer( &c ); + + cblas_sgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + double* alphap = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + double* bp = bli_obj_buffer( &b ); + double* betap = bli_obj_buffer( &beta ); + double* cp = bli_obj_buffer( &c ); + + cblas_dgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + scomplex* bp = bli_obj_buffer( &b ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* cp = bli_obj_buffer( &c ); + + cblas_cgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* bp = bli_obj_buffer( &b ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* cp = bli_obj_buffer( &c ); + + cblas_zgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc ); + } +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + double gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 4.0; + + printf( "data_%s_%cgemm_%c%c_%s", THR_STR, dt_ch, + transal, transbl, STR ); + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )m, + ( unsigned long )n, + ( unsigned long )k, gflops ); + + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + bli_obj_free( &c_save ); + } + + //bli_finalize(); + + return 0; +} + diff --git a/test/syrk_diagonal/complex_math.hpp b/test/syrk_diagonal/complex_math.hpp new file mode 100644 index 0000000000..9c68e730aa --- /dev/null +++ b/test/syrk_diagonal/complex_math.hpp @@ -0,0 +1,267 @@ +#include +#include +#include + +#include "blis.h" + +template +struct is_complex : std::false_type {}; + +template <> +struct is_complex : std::true_type {}; + +template <> +struct is_complex : std::true_type {}; + +template +struct is_real : std::integral_constant::value> {}; + +template struct make_complex; + +template <> struct make_complex { using type = scomplex; }; +template <> struct make_complex { using type = dcomplex; }; +template <> struct make_complex { using type = scomplex; }; +template <> struct make_complex { using type = dcomplex; }; + +template +using make_complex_t = typename make_complex::type; + +template struct make_real; + +template <> struct make_real { using type = float; }; +template <> struct make_real { using type = double; }; +template <> struct make_real { using type = float; }; +template <> struct make_real { using type = double; }; + +template +using make_real_t = typename make_real::type; + +template +struct make_complex_if : std::conditional,make_real_t> {}; + +template +using make_complex_if_t = typename make_complex_if::type; + +template +struct real_imag_part +{ + real_imag_part& operator=(T) { return *this; } + + operator T() const { return T(); } +}; + +template +std::enable_if_t::type>::value,T&> real(T& x) { return x; } + +template +std::enable_if_t::value,real_imag_part> imag(T x) { return {}; } + +inline float& real(scomplex& x) { return x.real; } + +inline float& imag(scomplex& x) { return x.imag; } + +inline double& real(dcomplex& x) { return x.real; } + +inline double& imag(dcomplex& x) { return x.imag; } + +inline const float& real(const scomplex& x) { return x.real; } + +inline const float& imag(const scomplex& x) { return x.imag; } + +inline const double& real(const dcomplex& x) { return x.real; } + +inline const double& imag(const dcomplex& x) { return x.imag; } + +template +std::enable_if_t::value,T> conj(T x) { return x; } + +template +std::enable_if_t::value,T> conj(const T& x) { return {x.real, -x.imag}; } + +template +struct convert_impl; + +template +struct convert_impl::value && is_real::value>> +{ + void operator()(T x, U& y) const { y = x; } +}; + +template +struct convert_impl::value && is_complex::value>> +{ + void operator()(T x, U& y) const { y.real = x; y.imag = 0; } +}; + +template +struct convert_impl::value && is_real::value>> +{ + void operator()(T x, U& y) const { y = x.real; } +}; + +template +struct convert_impl::value && is_complex::value>> +{ + void operator()(T x, U& y) const { y.real = x.real; y.imag = x.imag; } +}; + +template +U convert(T x) +{ + U y; + convert_impl{}(x,y); + return y; +} + +template +auto convert_prec(T x) -> make_complex_if_t::value> +{ + return convert::value>>(x); +} + +#define COMPLEX_MATH_OPS(rtype, ctype) \ +\ +inline bool operator==(rtype x, ctype y) \ +{ \ + return x == y.real && y.imag == 0; \ +} \ +\ +inline bool operator==(ctype x, rtype y) \ +{ \ + return y == x.real && x.imag == 0; \ +} \ +\ +inline bool operator==(ctype x, ctype y) \ +{ \ + return x.real == y.real && \ + x.imag == y.imag; \ + } \ + \ +inline ctype operator-(ctype x) \ +{ \ + return {-x.real, -x.imag}; \ +} \ +\ +inline ctype operator+(rtype x, ctype y) \ +{ \ + return {x+y.real, y.imag}; \ +} \ +\ +inline ctype operator+(ctype x, rtype y) \ +{ \ + return {y+x.real, x.imag}; \ +} \ +\ +inline ctype operator+(ctype x, ctype y) \ +{ \ + return {x.real+y.real, x.imag+y.imag}; \ +} \ +\ +inline ctype operator-(rtype x, ctype y) \ +{ \ + return {x-y.real, -y.imag}; \ +} \ +\ +inline ctype operator-(ctype x, rtype y) \ +{ \ + return {x.real-y, x.imag}; \ +} \ +\ +inline ctype operator-(ctype x, ctype y) \ +{ \ + return {x.real-y.real, x.imag-y.imag}; \ +} \ +\ +inline ctype operator*(rtype x, ctype y) \ +{ \ + return {x*y.real, x*y.imag}; \ +} \ +\ +inline ctype operator*(ctype x, rtype y) \ +{ \ + return {y*x.real, y*x.imag}; \ +} \ +\ +inline ctype operator*(ctype x, ctype y) \ +{ \ + return {x.real*y.real - x.imag*y.imag, \ + x.real*y.imag + x.imag*y.real}; \ +} \ +\ +inline ctype operator/(rtype x, ctype y) \ +{ \ + auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \ + auto n = std::ilogb(scale); \ + auto yrs = std::scalbn(y.real, -n); \ + auto yis = std::scalbn(y.imag, -n); \ + auto denom = y.real*yrs + y.imag*yis; \ + return {x*yrs/denom, -x*yis/denom}; \ +} \ +\ +inline ctype operator/(ctype x, rtype y) \ +{ \ + return {x.real/y, x.imag/y}; \ +} \ +\ +inline ctype operator/(ctype x, ctype y) \ +{ \ + auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \ + auto n = std::ilogb(scale); \ + auto yrs = std::scalbn(y.real, -n); \ + auto yis = std::scalbn(y.imag, -n); \ + auto denom = y.real*yrs + y.imag*yis; \ + return {(x.real*yrs + x.imag*yis)/denom, \ + (x.imag*yrs - x.real*yis)/denom}; \ +} \ +\ +inline ctype& operator+=(ctype& x, rtype y) \ +{ \ + x.real += y; \ + return x; \ +} \ +\ +inline ctype& operator+=(ctype& x, ctype y) \ +{ \ + x.real += y.real; x.imag += y.imag; \ + return x; \ +} \ +\ +inline ctype& operator-=(ctype& x, rtype y) \ +{ \ + x.real -= y; \ + return x; \ +} \ +\ +inline ctype& operator-=(ctype& x, ctype y) \ +{ \ + x.real -= y.real; x.imag -= y.imag; \ + return x; \ +} \ +\ +inline ctype& operator*=(ctype& x, rtype y) \ +{ \ + x.real *= y; x.imag *= y; \ + return x; \ +} \ +\ +inline ctype& operator*=(ctype& x, ctype y) \ +{ \ + x = x * y; \ + return x; \ +} \ +\ +inline ctype& operator/=(ctype& x, rtype y) \ +{ \ + x.real /= y; x.imag /= y; \ + return x; \ +} \ +\ +inline ctype& operator/=(ctype& x, ctype y) \ +{ \ + x = x / y; \ + return x; \ +} + +COMPLEX_MATH_OPS(float, scomplex); +COMPLEX_MATH_OPS(double, dcomplex); + diff --git a/test/syrk_diagonal/syrk_diagonal_example.c b/test/syrk_diagonal/syrk_diagonal_example.c new file mode 100644 index 0000000000..c2bfd8fa19 --- /dev/null +++ b/test/syrk_diagonal/syrk_diagonal_example.c @@ -0,0 +1,186 @@ +#include "syrk_diagonal_ref.h" + +/* + * Structure which includes all additional information beyond what is + * already stored in the obj_t structure. + * + * This structure is **read-only** during the operation! + */ +typedef struct packm_diag_params_t +{ + packm_blk_var1_params_t super; + void* d; + inc_t incd; +} packm_diag_params_t; + +/* + * Declare the pack kernel type and set up and array of + * packing kernels, one for each data type. + */ +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +void PASTEMAC(ch,op) \ + ( \ + struc_t struca, \ + diag_t diaga, \ + uplo_t uploa, \ + conj_t conja, \ + pack_t schema, \ + bool invdiag, \ + dim_t panel_dim, \ + dim_t panel_len, \ + dim_t panel_dim_max, \ + dim_t panel_len_max, \ + dim_t panel_dim_off, \ + dim_t panel_len_off, \ + void* restrict kappa, \ + void* restrict a, inc_t inca, inc_t lda, \ + void* restrict p, inc_t ldp, \ + inc_t is_p, \ + cntx_t* cntx, \ + void* params \ + ) \ +{ \ + packm_diag_params_t* params_cast = params; \ + ctype* restrict a_cast = a; \ + ctype* restrict p_cast = p; \ + ctype* restrict d_cast = params_cast->d; \ + inc_t incd = params_cast->incd; \ + ctype kappa_cast = *( ctype* )kappa; \ +\ + if ( schema != BLIS_PACKED_ROW_PANELS && \ + schema != BLIS_PACKED_COL_PANELS ) \ + bli_abort(); \ +\ + /* Apply the offset */ \ + d_cast += panel_len_off * incd; \ +\ + if ( conja ) \ + { \ + for ( dim_t j = 0; j < panel_len; j++ ) \ + { \ + ctype kappa_d; \ + PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ +\ + for (dim_t i = 0;i < panel_dim;i++) \ + PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ +\ + for (dim_t i = panel_dim;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + } \ + } \ + else \ + { \ + for ( dim_t j = 0; j < panel_len; j++ ) \ + { \ + ctype kappa_d; \ + PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ +\ + for (dim_t i = 0;i < panel_dim;i++) \ + PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ +\ + for (dim_t i = panel_dim;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + } \ + } \ +\ + for (dim_t j = panel_len;j < panel_len_max;j++) \ + for (dim_t i = 0;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ +} + +INSERT_GENTFUNC_BASIC0(packm_diag_ukr); + +static packm_ker_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr ); + +/* + * Modify the object A to include information about the diagonal D, + * and imbue it with special function pointers which will take care + * of the actual work of forming (D * A^T) + */ +void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a ) +{ + memset( params, 0, sizeof(*params) ); + + // Assumes D is a column vector + params->d = bli_obj_buffer_at_off( d ); + params->incd = bli_obj_row_stride( d ); + + for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ ) + params->super.ukr_fn[i][i] = packm_diag_ukrs[i]; + + // Attach the parameters to the A object. + bli_obj_set_pack_params( params, a ); +} + +/* + * Implements C := alpha * A * D * A^T + beta * C + * + * where D is a diagonal matrix with elements taken from the "d" vector. + */ +void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c ) +{ + obj_t ad; // this is (D * A^T) + packm_diag_params_t params; + + bli_obj_alias_to( a, &ad ); + bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T + attach_diagonal_factor( ¶ms, d, &ad ); + + // Does C := alpha * A * B + beta * C using B = (D + A^T) + bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL ); +} + +int main( void ) +{ + obj_t a; + obj_t d; + obj_t c; + obj_t c_copy; + obj_t norm; + + dim_t m = 10; + dim_t k = 10; + + for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ ) + for ( int upper = 0; upper <= 1; upper++ ) + for ( int transa = 0; transa <= 1; transa++ ) + for ( int transc = 0; transc <= 1; transc++ ) + { + num_t dt = dt_; + uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER; + + bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a ); + bli_obj_create( dt, k, 1, 1, 1, &d ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy ); + bli_obj_set_uplo( uplo , &c ); + bli_obj_set_uplo( uplo , &c_copy ); + bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm ); + + bli_randm( &a ); + bli_randm( &d ); + bli_randm( &c ); + bli_copym( &c, &c_copy ); + + syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c ); + syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy ); + + bli_subm( &c_copy, &c ); + bli_normfm( &c, &norm ); + + double normr, normi; + bli_getsc( &norm, &normr, &normi ); + + printf( "dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n", + dt, upper, transa, transc, normr ); + + bli_obj_free( &a ); + bli_obj_free( &d ); + bli_obj_free( &c ); + bli_obj_free( &c_copy ); + bli_obj_free( &norm ); + } +} diff --git a/test/syrk_diagonal/syrk_diagonal_example.cxx b/test/syrk_diagonal/syrk_diagonal_example.cxx new file mode 100644 index 0000000000..1c269d5c48 --- /dev/null +++ b/test/syrk_diagonal/syrk_diagonal_example.cxx @@ -0,0 +1,220 @@ +#include "syrk_diagonal_ref.h" + +/* + * Forward-declare the pack kernel type and set up and array of + * packing kernels, one for each data type. + */ +template +void packm_diag_ukr + ( + struc_t /*struca*/, + diag_t /*diaga*/, + uplo_t /*uploa*/, + conj_t conja, + pack_t schema, + bool /*invdiag*/, + dim_t panel_dim, + dim_t panel_len, + dim_t panel_dim_max, + dim_t panel_len_max, + dim_t /*panel_dim_off*/, + dim_t panel_len_off, + void* restrict kappa, + void* restrict a, inc_t inca, inc_t lda, + void* restrict p, inc_t ldp, + inc_t /*is_p*/, + cntx_t* /*cntx*/, + void* params + ); + +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +static auto PASTEMAC(ch,op) = &packm_diag_ukr; + +INSERT_GENTFUNC_BASIC0(packm_diag_ukr); + +static packm_ker_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr ); + +/* + * Structure which includes all additional information beyond what is + * already stored in the obj_t structure. + * + * This structure is **read-only** during the operation! + */ +struct packm_diag_params_t : packm_blk_var1_params_t +{ + void* d; + inc_t incd; + + packm_diag_params_t() {} + + packm_diag_params_t( void* d, inc_t incd ) + : d(d), incd(incd) + { + for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ ) + ukr_fn[i][i] = packm_diag_ukrs[i]; + } +}; + +/* + * Selecting a different kernel based on the current architecture is + * currently not possible, but is something we plan to support. + */ +template +void packm_diag_ukr + ( + struc_t /*struca*/, + diag_t /*diaga*/, + uplo_t /*uploa*/, + conj_t conja, + pack_t schema, + bool /*invdiag*/, + dim_t panel_dim, + dim_t panel_len, + dim_t panel_dim_max, + dim_t panel_len_max, + dim_t /*panel_dim_off*/, + dim_t panel_len_off, + void* restrict kappa, + void* restrict a, inc_t inca, inc_t lda, + void* restrict p, inc_t ldp, + inc_t /*is_p*/, + cntx_t* /*cntx*/, + void* params + ) +{ + auto params_cast = ( packm_diag_params_t* )params; + T* restrict a_cast = ( T* )a; + T* restrict p_cast = ( T* )p; + T* restrict d_cast = ( T* )params_cast->d; + auto incd = params_cast->incd; + auto kappa_cast = *( T* )kappa; + + if ( schema != BLIS_PACKED_ROW_PANELS && + schema != BLIS_PACKED_COL_PANELS ) + bli_abort(); + + /* Apply the offset */ + d_cast += panel_len_off * incd; + + if ( conja ) + { + for ( dim_t j = 0; j < panel_len; j++ ) + { + auto kappa_d = kappa_cast * d_cast[ j*incd ]; + + for (dim_t i = 0;i < panel_dim;i++) + p_cast[ i + j*ldp ] = kappa_d * conj( a_cast[ i*inca + j*lda ] ); + + for (dim_t i = panel_dim;i < panel_dim_max;i++) + p_cast[ i + j*ldp ] = convert(0.0); + } + } + else + { + for ( dim_t j = 0; j < panel_len; j++ ) + { + auto kappa_d = kappa_cast * d_cast[ j*incd ]; + + for (dim_t i = 0;i < panel_dim;i++) + p_cast[ i + j*ldp ] = kappa_d * a_cast[ i*inca + j*lda ]; + + for (dim_t i = panel_dim;i < panel_dim_max;i++) + p_cast[ i + j*ldp ] = convert(0.0); + } + } + + for (dim_t j = panel_len;j < panel_len_max;j++) + for (dim_t i = 0;i < panel_dim_max;i++) + p_cast[ i + j*ldp ] = convert(0.0); +} + +/* + * Modify the object A to include information about the diagonal D, + * and imbue it with special function pointers which will take care + * of the actual work of forming (D * A^T) + */ +void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a ) +{ + // Assumes D is a column vector + new (params) packm_diag_params_t + ( + bli_obj_buffer_at_off( d ), + bli_obj_row_stride( d ) + ); + + // Attach the parameters to the A object. + bli_obj_set_pack_params( params, a ); +} + +/* + * Implements C := alpha * A * D * A^T + beta * C + * + * where D is a diagonal matrix with elements taken from the "d" vector. + */ +void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c ) +{ + obj_t ad; // this is (D * A^T) + packm_diag_params_t params; + + bli_obj_alias_to( a, &ad ); + bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T + attach_diagonal_factor( ¶ms, d, &ad ); + + // Does C := alpha * A * B + beta * C using B = (D + A^T) + bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL ); +} + +int main() +{ + obj_t a; + obj_t d; + obj_t c; + obj_t c_copy; + obj_t norm; + + auto m = 10; + auto k = 10; + + for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ ) + for ( int upper = 0; upper <= 1; upper++ ) + for ( int transa = 0; transa <= 1; transa++ ) + for ( int transc = 0; transc <= 1; transc++ ) + { + auto dt = ( num_t )dt_; + auto uplo = upper ? BLIS_UPPER : BLIS_LOWER; + + bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a ); + bli_obj_create( dt, k, 1, 1, 1, &d ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy ); + bli_obj_set_uplo( uplo , &c ); + bli_obj_set_uplo( uplo , &c_copy ); + bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm ); + + bli_randm( &a ); + bli_randm( &d ); + bli_randm( &c ); + bli_copym( &c, &c_copy ); + + syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c ); + syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy ); + + bli_subm( &c_copy, &c ); + bli_normfm( &c, &norm ); + + double normr, normi; + bli_getsc( &norm, &normr, &normi ); + + printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n", + dt, upper, transa, transc, normr); + + bli_obj_free( &a ); + bli_obj_free( &d ); + bli_obj_free( &c ); + bli_obj_free( &c_copy ); + bli_obj_free( &norm ); + } +} diff --git a/test/syrk_diagonal/syrk_diagonal_example2.c b/test/syrk_diagonal/syrk_diagonal_example2.c new file mode 100644 index 0000000000..92371f48b0 --- /dev/null +++ b/test/syrk_diagonal/syrk_diagonal_example2.c @@ -0,0 +1,354 @@ +#include "syrk_diagonal_ref.h" + +/* + * Structure which includes all additional information beyond what is + * already stored in the obj_t structure. + * + * This structure is **read-only** during the operation! + */ +typedef struct packm_diag_params_t +{ + void* d; + inc_t incd; +} packm_diag_params_t; + +typedef void (*packm_diag_ukr_vft) + ( + bool conja, + dim_t panel_dim, + dim_t panel_len, + dim_t panel_dim_max, + dim_t panel_len_max, + void* restrict kappa, + void* restrict d, inc_t incd, + void* restrict a, inc_t inca, inc_t lda, + void* restrict p, inc_t ldp + ); + +/* + * Declare the pack kernel type and set up and array of + * packing kernels, one for each data type. + */ +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +void PASTEMAC(ch,op) \ + ( \ + bool conja, \ + dim_t panel_dim, \ + dim_t panel_len, \ + dim_t panel_dim_max, \ + dim_t panel_len_max, \ + void* restrict kappa, \ + void* restrict d, inc_t incd, \ + void* restrict a, inc_t inca, inc_t lda, \ + void* restrict p, inc_t ldp \ + ) \ +{ \ + ctype* restrict a_cast = a; \ + ctype* restrict p_cast = p; \ + ctype* restrict d_cast = d; \ + ctype kappa_cast = *( ctype* )kappa; \ +\ + if ( conja ) \ + { \ + for ( dim_t j = 0; j < panel_len; j++ ) \ + { \ + ctype kappa_d; \ + PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ +\ + for (dim_t i = 0;i < panel_dim;i++) \ + PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ +\ + for (dim_t i = panel_dim;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + } \ + } \ + else \ + { \ + for ( dim_t j = 0; j < panel_len; j++ ) \ + { \ + ctype kappa_d; \ + PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ +\ + for (dim_t i = 0;i < panel_dim;i++) \ + PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ +\ + for (dim_t i = panel_dim;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + } \ + } \ +\ + for (dim_t j = panel_len;j < panel_len_max;j++) \ + for (dim_t i = 0;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ +} + +INSERT_GENTFUNC_BASIC0(packm_diag_ukr); + +static packm_diag_ukr_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr ); + +void packm_diag + ( + obj_t* a, + obj_t* p, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ +#if 1 + + // We begin by copying the fields of A. + bli_obj_alias_to( a, p ); + + // Get information about data types. + num_t dt = bli_obj_dt( a ); + num_t dt_tar = bli_obj_target_dt( a ); + num_t dt_scalar = bli_obj_scalar_dt( a ); + dim_t dt_size = bli_dt_size( dt ); + + if ( dt_scalar != dt || dt_tar != dt ) + bli_abort(); + + // Extract various fields from the control tree. + bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl ); + bszid_t bmult_id_n = bli_cntl_packm_params_bmid_n( cntl ); + pack_t schema = bli_cntl_packm_params_pack_schema( cntl ); + dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx ); + dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx ); + dim_t bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx ); + + if ( schema != BLIS_PACKED_ROW_PANELS && + schema != BLIS_PACKED_COL_PANELS ) + bli_abort(); + + // Store the pack schema to the object. + bli_obj_set_pack_schema( schema, p ); + + // Clear the conjugation field from the object since matrix packing + // in BLIS is deemed to take care of all conjugation necessary. + bli_obj_set_conj( BLIS_NO_CONJUGATE, p ); + + // If we are packing micropanels, mark P as dense. + bli_obj_set_uplo( BLIS_DENSE, p ); + + // Reset the view offsets to (0,0). + bli_obj_set_offs( 0, 0, p ); + + // Compute the dimensions padded by the dimension multiples. These + // dimensions will be the dimensions of the packed matrices, including + // zero-padding, and will be used by the macro- and micro-kernels. + // We compute them by starting with the effective dimensions of A (now + // in P) and aligning them to the dimension multiples (typically equal + // to register blocksizes). This does waste a little bit of space for + // level-2 operations, but that's okay with us. + dim_t m_p = bli_obj_length( p ); + dim_t n_p = bli_obj_width( p ); + dim_t m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def ); + dim_t n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def ); + + // Save the padded dimensions into the packed object. It is important + // to save these dimensions since they represent the actual dimensions + // of the zero-padded matrix. + bli_obj_set_padded_dims( m_p_pad, n_p_pad, p ); + + // The "panel stride" of a micropanel packed object is interpreted as + // the distance between the (0,0) element of panel k and the (0,0) + // element of panel k+1. We use the padded width computed above to + // allow for zero-padding (if necessary/desired) along the far end + // of each micropanel (ie: the right edge of the matrix). Zero-padding + // can also occur along the long edge of the last micropanel if the m + // dimension of the matrix is not a whole multiple of MR. + inc_t ps_p = bmult_m_pack * n_p_pad; + + /* Compute the total number of iterations we'll need. */ + dim_t n_iter = m_p_pad / bmult_m_def; + + // Store the strides and panel dimension in P. + bli_obj_set_strides( 1, bmult_m_pack, p ); + bli_obj_set_imag_stride( 1, p ); + bli_obj_set_panel_dim( bmult_m_def, p ); + bli_obj_set_panel_stride( ps_p, p ); + bli_obj_set_panel_length( bmult_m_def, p ); + bli_obj_set_panel_width( n_p, p ); + + // Compute the size of the packed buffer. + siz_t size_p = ps_p * n_iter * dt_size; + if ( size_p == 0 ) return; + + // Update the buffer address in p to point to the buffer associated + // with the mem_t entry acquired from the memory broker (now cached in + // the control tree node). + char* p_cast = (char*)bli_packm_alloc( size_p, rntm, cntl, thread ); + bli_obj_set_buffer( p_cast, p ); + +#else + + // Every thread initializes p and determines the size of memory + // block needed (which gets embedded into the otherwise "blank" mem_t + // entry in the control tree node). Return early if no packing is required. + if ( !bli_packm_init( a, p, cntx, rntm, cntl, thread ) ) + return; + + num_t dt = bli_obj_dt( a ); + dim_t dt_size = bli_dt_size( dt ); + + bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl ); + dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt, bmult_id_m, cntx ); + dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt, bmult_id_m, cntx ); + + dim_t m_p = bli_obj_length( p ); + dim_t n_p = bli_obj_width( p ); + dim_t m_p_pad = bli_obj_padded_length( p ); + dim_t n_p_pad = bli_obj_padded_width( p ); + dim_t n_iter = m_p_pad / bmult_m_def; + + char* p_cast = bli_obj_buffer( p ); + inc_t ps_p = bli_obj_panel_stride( p ); + +#endif + + char* a_cast = bli_obj_buffer_at_off( a ); + inc_t inca = bli_obj_row_stride( a ); + inc_t lda = bli_obj_col_stride( a ); + dim_t panel_len_off = bli_obj_col_off( a ); + conj_t conja = bli_obj_conj_status( a ); + + packm_diag_params_t* params = bli_obj_pack_params( a ); + char* d_cast = params->d; + inc_t incd = params->incd; + + obj_t kappa_local; + char* kappa_cast = bli_packm_scalar( &kappa_local, p ); + + packm_diag_ukr_vft packm_ker_cast = packm_diag_ukrs[ dt ]; + + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ + const dim_t nt = bli_thread_n_way( thread ); + const dim_t tid = bli_thread_work_id( thread ); + + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ + dim_t it_start, it_end, it_inc; + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); + + /* Iterate over every logical micropanel in the source matrix. */ + for ( dim_t it = 0; it < n_iter; it += 1 ) + { + dim_t panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def ); + + char* d_begin = d_cast + panel_len_off*incd*dt_size; + char* a_begin = a_cast + it* bmult_m_def*inca*dt_size; + char* p_begin = p_cast + it* ps_p*dt_size; + + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) + { + packm_ker_cast + ( + conja, + panel_dim_i, + n_p, + bmult_m_def, + n_p_pad, + kappa_cast, + d_begin, incd, + a_begin, inca, lda, + p_begin, bmult_m_pack + ); + } + } +} + +/* + * Modify the object A to include information about the diagonal D, + * and imbue it with special function pointers which will take care + * of the actual work of forming (D * A^T) + */ +void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a ) +{ + // Assumes D is a column vector + params->d = bli_obj_buffer_at_off( d ); + params->incd = bli_obj_row_stride( d ); + + // Set the custom pack function. + bli_obj_set_pack_fn( packm_diag, a ); + + // Attach the parameters to the A object. + bli_obj_set_pack_params( params, a ); +} + +/* + * Implements C := alpha * A * D * A^T + beta * C + * + * where D is a diagonal matrix with elements taken from the "d" vector. + */ +void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c ) +{ + obj_t ad; // this is (D * A^T) + packm_diag_params_t params; + + bli_obj_alias_to( a, &ad ); + bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T + attach_diagonal_factor( ¶ms, d, &ad ); + + // Does C := alpha * A * B + beta * C using B = (D + A^T) + bli_gemmt( alpha, a, &ad, beta, c ); +} + +int main( void ) +{ + obj_t a; + obj_t d; + obj_t c; + obj_t c_copy; + obj_t norm; + + dim_t m = 10; + dim_t k = 10; + + for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ ) + for ( int upper = 0; upper <= 1; upper++ ) + for ( int transa = 0; transa <= 1; transa++ ) + for ( int transc = 0; transc <= 1; transc++ ) + { + num_t dt = dt_; + uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER; + + bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a ); + bli_obj_create( dt, k, 1, 1, 1, &d ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy ); + bli_obj_set_uplo( uplo , &c ); + bli_obj_set_uplo( uplo , &c_copy ); + bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm ); + + bli_randm( &a ); + bli_randm( &d ); + bli_randm( &c ); + bli_copym( &c, &c_copy ); + + syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c ); + syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy ); + + bli_subm( &c_copy, &c ); + bli_normfm( &c, &norm ); + + double normr, normi; + bli_getsc( &norm, &normr, &normi ); + + printf( "dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n", + dt, upper, transa, transc, normr ); + + bli_obj_free( &a ); + bli_obj_free( &d ); + bli_obj_free( &c ); + bli_obj_free( &c_copy ); + bli_obj_free( &norm ); + } +} diff --git a/test/syrk_diagonal/syrk_diagonal_example2.cxx b/test/syrk_diagonal/syrk_diagonal_example2.cxx new file mode 100644 index 0000000000..8312a07ee8 --- /dev/null +++ b/test/syrk_diagonal/syrk_diagonal_example2.cxx @@ -0,0 +1,338 @@ +#include "syrk_diagonal_ref.h" + +/* + * Forward-declare the pack kernel type and set up and array of + * packing kernels, one for each data type. + */ +template +void packm_diag_ukr + ( + bool conja, + dim_t panel_dim, + dim_t panel_len, + dim_t panel_dim_max, + dim_t panel_len_max, + void* restrict kappa, + void* restrict d, inc_t incd, + void* restrict a, inc_t inca, inc_t lda, + void* restrict p, inc_t ldp + ); + +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +static auto PASTEMAC(ch,op) = &packm_diag_ukr; + +INSERT_GENTFUNC_BASIC0(packm_diag_ukr); + +using packm_diag_ukr_vft = decltype(&packm_diag_ukr); +static packm_diag_ukr_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr ); + +/* + * Structure which includes all additional information beyond what is + * already stored in the obj_t structure. + * + * This structure is **read-only** during the operation! + */ +struct packm_diag_params_t +{ + void* d; + inc_t incd; + + packm_diag_params_t() {} + + packm_diag_params_t( void* d, inc_t incd ) + : d(d), incd(incd) {} +}; + +/* + * Selecting a different kernel based on the current architecture is + * currently not possible, but is something we plan to support. + */ +template +void packm_diag_ukr + ( + bool conja, + dim_t panel_dim, + dim_t panel_len, + dim_t panel_dim_max, + dim_t panel_len_max, + void* restrict kappa, + void* restrict d, inc_t incd, + void* restrict a, inc_t inca, inc_t lda, + void* restrict p, inc_t ldp + ) +{ + T* restrict a_cast = ( T* )a; + T* restrict p_cast = ( T* )p; + T* restrict d_cast = ( T* )d; + auto kappa_cast = *( T* )kappa; + + if ( conja ) + { + for ( dim_t j = 0; j < panel_len; j++ ) + { + auto kappa_d = kappa_cast * d_cast[ j*incd ]; + + for (dim_t i = 0;i < panel_dim;i++) + p_cast[ i + j*ldp ] = kappa_d * conj( a_cast[ i*inca + j*lda ] ); + + for (dim_t i = panel_dim;i < panel_dim_max;i++) + p_cast[ i + j*ldp ] = convert(0.0); + } + } + else + { + for ( dim_t j = 0; j < panel_len; j++ ) + { + auto kappa_d = kappa_cast * d_cast[ j*incd ]; + + for (dim_t i = 0;i < panel_dim;i++) + p_cast[ i + j*ldp ] = kappa_d * a_cast[ i*inca + j*lda ]; + + for (dim_t i = panel_dim;i < panel_dim_max;i++) + p_cast[ i + j*ldp ] = convert(0.0); + } + } + + for (dim_t j = panel_len;j < panel_len_max;j++) + for (dim_t i = 0;i < panel_dim_max;i++) + p_cast[ i + j*ldp ] = convert(0.0); +} + +void packm_diag + ( + obj_t* a, + obj_t* p, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + // We begin by copying the fields of A. + bli_obj_alias_to( a, p ); + + // Get information about data types. + num_t dt = bli_obj_dt( a ); + num_t dt_tar = bli_obj_target_dt( a ); + num_t dt_scalar = bli_obj_scalar_dt( a ); + dim_t dt_size = bli_dt_size( dt ); + + if ( dt_scalar != dt || dt_tar != dt ) + bli_abort(); + + // Extract various fields from the control tree. + bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl ); + bszid_t bmult_id_n = bli_cntl_packm_params_bmid_n( cntl ); + pack_t schema = bli_cntl_packm_params_pack_schema( cntl ); + dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx ); + dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx ); + dim_t bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx ); + + if ( schema != BLIS_PACKED_ROW_PANELS && + schema != BLIS_PACKED_COL_PANELS ) + bli_abort(); + + // Store the pack schema to the object. + bli_obj_set_pack_schema( schema, p ); + + // Clear the conjugation field from the object since matrix packing + // in BLIS is deemed to take care of all conjugation necessary. + bli_obj_set_conj( BLIS_NO_CONJUGATE, p ); + + // If we are packing micropanels, mark P as dense. + bli_obj_set_uplo( BLIS_DENSE, p ); + + // Reset the view offsets to (0,0). + bli_obj_set_offs( 0, 0, p ); + + // Compute the dimensions padded by the dimension multiples. These + // dimensions will be the dimensions of the packed matrices, including + // zero-padding, and will be used by the macro- and micro-kernels. + // We compute them by starting with the effective dimensions of A (now + // in P) and aligning them to the dimension multiples (typically equal + // to register blocksizes). This does waste a little bit of space for + // level-2 operations, but that's okay with us. + dim_t m_p = bli_obj_length( p ); + dim_t n_p = bli_obj_width( p ); + dim_t m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def ); + dim_t n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def ); + + // Save the padded dimensions into the packed object. It is important + // to save these dimensions since they represent the actual dimensions + // of the zero-padded matrix. + bli_obj_set_padded_dims( m_p_pad, n_p_pad, p ); + + // The "panel stride" of a micropanel packed object is interpreted as + // the distance between the (0,0) element of panel k and the (0,0) + // element of panel k+1. We use the padded width computed above to + // allow for zero-padding (if necessary/desired) along the far end + // of each micropanel (ie: the right edge of the matrix). Zero-padding + // can also occur along the long edge of the last micropanel if the m + // dimension of the matrix is not a whole multiple of MR. + inc_t ps_p = bmult_m_pack * n_p_pad; + + /* Compute the total number of iterations we'll need. */ + dim_t n_iter = m_p_pad / bmult_m_def; + + // Store the strides and panel dimension in P. + bli_obj_set_strides( 1, bmult_m_pack, p ); + bli_obj_set_imag_stride( 1, p ); + bli_obj_set_panel_dim( bmult_m_def, p ); + bli_obj_set_panel_stride( ps_p, p ); + bli_obj_set_panel_length( bmult_m_def, p ); + bli_obj_set_panel_width( n_p, p ); + + // Compute the size of the packed buffer. + siz_t size_p = ps_p * n_iter * dt_size; + if ( size_p == 0 ) return; + + // Update the buffer address in p to point to the buffer associated + // with the mem_t entry acquired from the memory broker (now cached in + // the control tree node). + char* p_cast = (char*)bli_packm_alloc( size_p, rntm, cntl, thread ); + bli_obj_set_buffer( p_cast, p ); + + char* a_cast = (char*)bli_obj_buffer_at_off( a ); + inc_t inca = bli_obj_row_stride( a ); + inc_t lda = bli_obj_col_stride( a ); + dim_t panel_len_off = bli_obj_col_off( a ); + conj_t conja = bli_obj_conj_status( a ); + + auto params = (packm_diag_params_t*)bli_obj_pack_params( a ); + char* d_cast = (char*)params->d; + inc_t incd = params->incd; + + obj_t kappa_local; + char* kappa_cast = (char*)bli_packm_scalar( &kappa_local, p ); + + auto packm_ker_cast = packm_diag_ukrs[ dt ]; + + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ + const dim_t nt = bli_thread_n_way( thread ); + const dim_t tid = bli_thread_work_id( thread ); + + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ + dim_t it_start, it_end, it_inc; + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); + + /* Iterate over every logical micropanel in the source matrix. */ + for ( dim_t it = 0; it < n_iter; it += 1 ) + { + dim_t panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def ); + + char* d_begin = d_cast + panel_len_off*incd*dt_size; + char* a_begin = a_cast + it* bmult_m_def*inca*dt_size; + char* p_begin = p_cast + it* ps_p*dt_size; + + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) + { + packm_ker_cast( conja, + panel_dim_i, + n_p, + bmult_m_def, + n_p_pad, + kappa_cast, + d_begin, incd, + a_begin, inca, lda, + p_begin, bmult_m_pack ); + } + } +} + +/* + * Modify the object A to include information about the diagonal D, + * and imbue it with special function pointers which will take care + * of the actual work of forming (D * A^T) + */ +void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a ) +{ + // Assumes D is a column vector + new (params) packm_diag_params_t + ( + bli_obj_buffer_at_off( d ), + bli_obj_row_stride( d ) + ); + + // Set the custom pack function. + bli_obj_set_pack_fn( packm_diag, a ); + + // Attach the parameters to the A object. + bli_obj_set_pack_params( params, a ); +} + +/* + * Implements C := alpha * A * D * A^T + beta * C + * + * where D is a diagonal matrix with elements taken from the "d" vector. + */ +void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c ) +{ + obj_t ad; // this is (D * A^T) + packm_diag_params_t params; + + bli_obj_alias_to( a, &ad ); + bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T + attach_diagonal_factor( ¶ms, d, &ad ); + + // Does C := alpha * A * B + beta * C using B = (D + A^T) + bli_gemmt( alpha, a, &ad, beta, c ); +} + +int main() +{ + obj_t a; + obj_t d; + obj_t c; + obj_t c_copy; + obj_t norm; + + auto m = 10; + auto k = 10; + + for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ ) + for ( int upper = 0; upper <= 1; upper++ ) + for ( int transa = 0; transa <= 1; transa++ ) + for ( int transc = 0; transc <= 1; transc++ ) + { + auto dt = ( num_t )dt_; + auto uplo = upper ? BLIS_UPPER : BLIS_LOWER; + + bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a ); + bli_obj_create( dt, k, 1, 1, 1, &d ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy ); + bli_obj_set_uplo( uplo , &c ); + bli_obj_set_uplo( uplo , &c_copy ); + bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm ); + + bli_randm( &a ); + bli_randm( &d ); + bli_randm( &c ); + bli_copym( &c, &c_copy ); + + syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c ); + syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy ); + + bli_subm( &c_copy, &c ); + bli_normfm( &c, &norm ); + + double normr, normi; + bli_getsc( &norm, &normr, &normi ); + + printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n", + dt, upper, transa, transc, normr); + + bli_obj_free( &a ); + bli_obj_free( &d ); + bli_obj_free( &c ); + bli_obj_free( &c_copy ); + bli_obj_free( &norm ); + } +} diff --git a/test/syrk_diagonal/syrk_diagonal_ref.cxx b/test/syrk_diagonal/syrk_diagonal_ref.cxx new file mode 100644 index 0000000000..1d7c5d96e5 --- /dev/null +++ b/test/syrk_diagonal/syrk_diagonal_ref.cxx @@ -0,0 +1,102 @@ +#include "syrk_diagonal_ref.h" +#include "complex_math.hpp" + +typedef void (*syrk_diag_ref_vft) + ( + uplo_t uplo, + dim_t m, + dim_t k, + void* alpha, + void* a, inc_t rs_a, inc_t cs_a, + void* d, inc_t incd, + void* beta, + void* c, inc_t rs_c, inc_t cs_c + ); + +template +void syrk_diag_ref + ( + uplo_t uplo, + dim_t m, + dim_t k, + void* alpha, + void* a, inc_t rs_a, inc_t cs_a, + void* d, inc_t incd, + void* beta, + void* c, inc_t rs_c, inc_t cs_c + ) +{ + auto alpha_cast = *( T* )alpha; + auto beta_cast = *( T* )beta; + auto a_cast = ( T* )a; + auto d_cast = ( T* )d; + auto c_cast = ( T* )c; + + for ( dim_t i = 0; i < m; i++ ) + { + dim_t j_min = uplo == BLIS_UPPER ? i : 0; + dim_t j_max = uplo == BLIS_UPPER ? m : i+1; + + for ( dim_t j = j_min; j < j_max; j++ ) + { + auto ada = convert(0.0); + + for ( dim_t p = 0; p < k; p++ ) + { + ada += a_cast[ i*rs_a + p*cs_a ] * + d_cast[ p*incd ] * + a_cast[ j*rs_a + p*cs_a ]; + } + + if ( beta_cast == convert(0.0) ) + { + c_cast[ i*rs_c + j*cs_c ] = alpha_cast * ada; + } + else + { + c_cast[ i*rs_c + j*cs_c ] = alpha_cast * ada + + beta_cast * c_cast[ i*rs_c + j*cs_c ]; + } + } + } +} + +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +static auto PASTEMAC(ch,op) = &syrk_diag_ref; + +INSERT_GENTFUNC_BASIC0(syrk_diag_ref); + +static syrk_diag_ref_vft GENARRAY( syrk_diag_ref_impl, syrk_diag_ref ); + +void syrk_diag_ref( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c ) +{ + num_t dt = bli_obj_dt( a ); + + dim_t m = bli_obj_length_after_trans( a ); + dim_t k = bli_obj_width_after_trans( a ); + + inc_t rs_a = bli_obj_row_stride( a ); + inc_t cs_a = bli_obj_col_stride( a ); + inc_t rs_c = bli_obj_row_stride( c ); + inc_t cs_c = bli_obj_col_stride( c ); + inc_t incd = bli_obj_row_stride( d ); + + if ( bli_obj_has_trans( a ) ) + bli_swap_incs( &rs_a, &cs_a ); + + if ( bli_obj_has_trans( c ) ) + bli_swap_incs( &rs_c, &cs_c ); + + syrk_diag_ref_impl[ dt ] + ( + bli_obj_uplo( c ), + m, k, + bli_obj_buffer_for_1x1( dt, alpha ), + bli_obj_buffer_at_off( a ), rs_a, cs_a, + bli_obj_buffer_at_off( d ), incd, + bli_obj_buffer_for_1x1( dt, beta ), + bli_obj_buffer_at_off( c ), rs_c, cs_c + ); +} + diff --git a/test/syrk_diagonal/syrk_diagonal_ref.h b/test/syrk_diagonal/syrk_diagonal_ref.h new file mode 100644 index 0000000000..a6864caec8 --- /dev/null +++ b/test/syrk_diagonal/syrk_diagonal_ref.h @@ -0,0 +1,8 @@ +#include "blis.h" + +#ifdef __cplusplus +#include "complex_math.hpp" +extern "C" +#endif +void syrk_diag_ref( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c ); + diff --git a/test/tensor_contraction/complex_math.hpp b/test/tensor_contraction/complex_math.hpp new file mode 100644 index 0000000000..9c68e730aa --- /dev/null +++ b/test/tensor_contraction/complex_math.hpp @@ -0,0 +1,267 @@ +#include +#include +#include + +#include "blis.h" + +template +struct is_complex : std::false_type {}; + +template <> +struct is_complex : std::true_type {}; + +template <> +struct is_complex : std::true_type {}; + +template +struct is_real : std::integral_constant::value> {}; + +template struct make_complex; + +template <> struct make_complex { using type = scomplex; }; +template <> struct make_complex { using type = dcomplex; }; +template <> struct make_complex { using type = scomplex; }; +template <> struct make_complex { using type = dcomplex; }; + +template +using make_complex_t = typename make_complex::type; + +template struct make_real; + +template <> struct make_real { using type = float; }; +template <> struct make_real { using type = double; }; +template <> struct make_real { using type = float; }; +template <> struct make_real { using type = double; }; + +template +using make_real_t = typename make_real::type; + +template +struct make_complex_if : std::conditional,make_real_t> {}; + +template +using make_complex_if_t = typename make_complex_if::type; + +template +struct real_imag_part +{ + real_imag_part& operator=(T) { return *this; } + + operator T() const { return T(); } +}; + +template +std::enable_if_t::type>::value,T&> real(T& x) { return x; } + +template +std::enable_if_t::value,real_imag_part> imag(T x) { return {}; } + +inline float& real(scomplex& x) { return x.real; } + +inline float& imag(scomplex& x) { return x.imag; } + +inline double& real(dcomplex& x) { return x.real; } + +inline double& imag(dcomplex& x) { return x.imag; } + +inline const float& real(const scomplex& x) { return x.real; } + +inline const float& imag(const scomplex& x) { return x.imag; } + +inline const double& real(const dcomplex& x) { return x.real; } + +inline const double& imag(const dcomplex& x) { return x.imag; } + +template +std::enable_if_t::value,T> conj(T x) { return x; } + +template +std::enable_if_t::value,T> conj(const T& x) { return {x.real, -x.imag}; } + +template +struct convert_impl; + +template +struct convert_impl::value && is_real::value>> +{ + void operator()(T x, U& y) const { y = x; } +}; + +template +struct convert_impl::value && is_complex::value>> +{ + void operator()(T x, U& y) const { y.real = x; y.imag = 0; } +}; + +template +struct convert_impl::value && is_real::value>> +{ + void operator()(T x, U& y) const { y = x.real; } +}; + +template +struct convert_impl::value && is_complex::value>> +{ + void operator()(T x, U& y) const { y.real = x.real; y.imag = x.imag; } +}; + +template +U convert(T x) +{ + U y; + convert_impl{}(x,y); + return y; +} + +template +auto convert_prec(T x) -> make_complex_if_t::value> +{ + return convert::value>>(x); +} + +#define COMPLEX_MATH_OPS(rtype, ctype) \ +\ +inline bool operator==(rtype x, ctype y) \ +{ \ + return x == y.real && y.imag == 0; \ +} \ +\ +inline bool operator==(ctype x, rtype y) \ +{ \ + return y == x.real && x.imag == 0; \ +} \ +\ +inline bool operator==(ctype x, ctype y) \ +{ \ + return x.real == y.real && \ + x.imag == y.imag; \ + } \ + \ +inline ctype operator-(ctype x) \ +{ \ + return {-x.real, -x.imag}; \ +} \ +\ +inline ctype operator+(rtype x, ctype y) \ +{ \ + return {x+y.real, y.imag}; \ +} \ +\ +inline ctype operator+(ctype x, rtype y) \ +{ \ + return {y+x.real, x.imag}; \ +} \ +\ +inline ctype operator+(ctype x, ctype y) \ +{ \ + return {x.real+y.real, x.imag+y.imag}; \ +} \ +\ +inline ctype operator-(rtype x, ctype y) \ +{ \ + return {x-y.real, -y.imag}; \ +} \ +\ +inline ctype operator-(ctype x, rtype y) \ +{ \ + return {x.real-y, x.imag}; \ +} \ +\ +inline ctype operator-(ctype x, ctype y) \ +{ \ + return {x.real-y.real, x.imag-y.imag}; \ +} \ +\ +inline ctype operator*(rtype x, ctype y) \ +{ \ + return {x*y.real, x*y.imag}; \ +} \ +\ +inline ctype operator*(ctype x, rtype y) \ +{ \ + return {y*x.real, y*x.imag}; \ +} \ +\ +inline ctype operator*(ctype x, ctype y) \ +{ \ + return {x.real*y.real - x.imag*y.imag, \ + x.real*y.imag + x.imag*y.real}; \ +} \ +\ +inline ctype operator/(rtype x, ctype y) \ +{ \ + auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \ + auto n = std::ilogb(scale); \ + auto yrs = std::scalbn(y.real, -n); \ + auto yis = std::scalbn(y.imag, -n); \ + auto denom = y.real*yrs + y.imag*yis; \ + return {x*yrs/denom, -x*yis/denom}; \ +} \ +\ +inline ctype operator/(ctype x, rtype y) \ +{ \ + return {x.real/y, x.imag/y}; \ +} \ +\ +inline ctype operator/(ctype x, ctype y) \ +{ \ + auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \ + auto n = std::ilogb(scale); \ + auto yrs = std::scalbn(y.real, -n); \ + auto yis = std::scalbn(y.imag, -n); \ + auto denom = y.real*yrs + y.imag*yis; \ + return {(x.real*yrs + x.imag*yis)/denom, \ + (x.imag*yrs - x.real*yis)/denom}; \ +} \ +\ +inline ctype& operator+=(ctype& x, rtype y) \ +{ \ + x.real += y; \ + return x; \ +} \ +\ +inline ctype& operator+=(ctype& x, ctype y) \ +{ \ + x.real += y.real; x.imag += y.imag; \ + return x; \ +} \ +\ +inline ctype& operator-=(ctype& x, rtype y) \ +{ \ + x.real -= y; \ + return x; \ +} \ +\ +inline ctype& operator-=(ctype& x, ctype y) \ +{ \ + x.real -= y.real; x.imag -= y.imag; \ + return x; \ +} \ +\ +inline ctype& operator*=(ctype& x, rtype y) \ +{ \ + x.real *= y; x.imag *= y; \ + return x; \ +} \ +\ +inline ctype& operator*=(ctype& x, ctype y) \ +{ \ + x = x * y; \ + return x; \ +} \ +\ +inline ctype& operator/=(ctype& x, rtype y) \ +{ \ + x.real /= y; x.imag /= y; \ + return x; \ +} \ +\ +inline ctype& operator/=(ctype& x, ctype y) \ +{ \ + x = x / y; \ + return x; \ +} + +COMPLEX_MATH_OPS(float, scomplex); +COMPLEX_MATH_OPS(double, dcomplex); + diff --git a/test/tensor_contraction/tcontract_example.cxx b/test/tensor_contraction/tcontract_example.cxx new file mode 100644 index 0000000000..0b935c54d4 --- /dev/null +++ b/test/tensor_contraction/tcontract_example.cxx @@ -0,0 +1,988 @@ + +#include "tcontract_ref.hpp" + +#include +#include + +static constexpr dim_t BS_K = 8; + +struct packm_tensor_params_t +{ + gint_t ndim_m, ndim_n; + const dim_t *len_m, *len_n; + const inc_t *stride_m, *stride_n; + + packm_tensor_params_t() {} + + packm_tensor_params_t( gint_t ndim_m, const dim_t* len_m, const inc_t* stride_m, + gint_t ndim_n, const dim_t* len_n, const inc_t* stride_n ) + : ndim_m(ndim_m), ndim_n(ndim_n), + len_m(len_m), len_n(len_n), + stride_m(stride_m), stride_n(stride_n) {} +}; + +using gemm_tensor_params_t = packm_tensor_params_t; + +template +void packm_ckx_nb + ( + bool conja, + dim_t panel_dim, + dim_t panel_len, + dim_t panel_dim_max, + dim_t panel_len_max, + void* kappa, + void* a, inc_t inca, inc_t* bsa, inc_t* scata, + void* p, inc_t ldp + ) +{ + T* restrict a_cast = ( T* )a; + T* restrict p_cast = ( T* )p; + auto kappa_cast = *( T* )kappa; + + if ( conja ) + { + for ( auto j0 = 0; j0 < panel_len; j0 += BS_K, bsa += BS_K, scata += BS_K ) + { + auto lda = *bsa; + auto panel_len_j = std::min( panel_len-j0, BS_K ); + + if ( lda ) + { + T* restrict aj = a_cast + *scata; + + for ( auto j = 0; j < panel_len_j; j++ ) + { + for ( auto i = 0; i < panel_dim; i++ ) + p_cast[ i ] = kappa_cast * conj( aj[ i*inca + j*lda ] ); + + for ( auto i = panel_dim; i < panel_dim_max; i++ ) + p_cast[ i ] = convert(0.0); + + p_cast += ldp; + } + } + else + { + for ( auto j = 0; j < panel_len_j; j++) + { + for ( auto i = 0; i < panel_dim; i++) + p_cast[ i ] = kappa_cast * conj( a_cast[ i*inca + scata[j] ] ); + + for ( auto i = panel_dim; i < panel_dim_max; i++) + p_cast[ i ] = convert(0.0); + + p_cast += ldp; + } + } + } + } + else + { + for ( auto j0 = 0; j0 < panel_len; j0 += BS_K, bsa += BS_K, scata += BS_K ) + { + auto lda = *bsa; + auto panel_len_j = std::min( panel_len-j0, BS_K ); + + if ( lda ) + { + T* restrict aj = a_cast + *scata; + + for ( auto j = 0; j < panel_len_j; j++ ) + { + for ( auto i = 0; i < panel_dim; i++ ) + p_cast[ i ] = kappa_cast * aj[ i*inca + j*lda ]; + + for ( auto i = panel_dim; i < panel_dim_max; i++ ) + p_cast[ i ] = convert(0.0); + + p_cast += ldp; + } + } + else + { + for ( auto j = 0; j < panel_len_j; j++ ) + { + for ( auto i = 0; i < panel_dim; i++ ) + p_cast[ i ] = kappa_cast * a_cast[ i*inca + scata[j] ]; + + for ( auto i = panel_dim; i < panel_dim_max; i++ ) + p_cast[ i ] = convert(0.0); + + p_cast += ldp; + } + } + } + } + + for ( auto j = panel_len; j < panel_len_max; j++) + { + for ( auto i = 0; i < panel_dim_max; i++) + p_cast[ i ] = convert(0.0); + + p_cast += ldp; + } +} + +template +void packm_ckx_ss + ( + bool conja, + dim_t panel_dim, + dim_t panel_len, + dim_t panel_dim_max, + dim_t panel_len_max, + void* kappa, + void* a, inc_t* inca, inc_t* scata, + void* p, inc_t ldp + ) +{ + T* restrict a_cast = ( T* )a; + T* restrict p_cast = ( T* )p; + auto kappa_cast = *( T* )kappa; + + if ( conja ) + { + for (dim_t j = 0;j < panel_len;j++) + { + for (dim_t i = 0;i < panel_dim;i++) + p_cast[ i ] = kappa_cast * conj( a_cast[ inca[i] + scata[j] ] ); + + for (dim_t i = panel_dim;i < panel_dim_max;i++) + p_cast[ i ] = convert(0.0); + + p_cast += ldp; + } + } + else + { + for (dim_t j = 0;j < panel_len;j++) + { + for (dim_t i = 0;i < panel_dim;i++) + p_cast[ i ] = kappa_cast * a_cast[ inca[i] + scata[j] ]; + + for (dim_t i = panel_dim;i < panel_dim_max;i++) + p_cast[ i ] = convert(0.0); + + p_cast += ldp; + } + } + + for (dim_t j = panel_len;j < panel_len_max;j++) + { + for (dim_t i = 0;i < panel_dim_max;i++) + p_cast[ i ] = convert(0.0); + + p_cast += ldp; + } +} + +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +static auto PASTEMAC(ch,op) = &packm_ckx_nb; + +INSERT_GENTFUNC_BASIC0(packm_ckx_nb); + +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +static auto PASTEMAC(ch,op) = &packm_ckx_ss; + +INSERT_GENTFUNC_BASIC0(packm_ckx_ss); + +static decltype(&packm_ckx_nb) GENARRAY( packm_ckx_nb_ukrs, packm_ckx_nb ); +static decltype(&packm_ckx_ss) GENARRAY( packm_ckx_ss_ukrs, packm_ckx_ss ); + +static void fill_scatter + ( + gint_t ndim, + const dim_t* restrict len, + const inc_t* restrict stride, + dim_t BS, + inc_t off, + dim_t size, + inc_t* restrict scat, + inc_t* restrict bs + ) +{ + if ( size == 0 ) return; + + if ( ndim == 0 ) + { + *scat = 0; + *bs = 0; + return; + } + + if ( ndim == 1 ) + { + auto l = *len; + auto s = *stride; + for ( auto i = 0; i < l; i++ ) + { + scat[i] = i*s; + bs[i] = s; + } + } + + dim_t tot_len = 1; + for ( auto i = 0; i < ndim; i++ ) + tot_len *= len[i]; + + assert(off >= 0); + assert(size >= 0); + assert(off+size <= tot_len); + + auto len0 = len[0]; + auto stride0 = stride[0]; + auto off0 = off % len0; + auto off1 = off / len0; + auto size1 = ( size + off0 + len0 - 1) / len0; + + inc_t pos1 = 0; + inc_t idx = 0; + for_each( ndim-1, len+1, off1, size1, pos1, stride+1, + [&] + { + auto pos = pos1 + off0 * stride0; + auto len_i = std::min( len0-off0, size-idx ); + for ( auto i = 0; i < len_i; i++ ) + { + scat[idx++] = pos; + pos += stride0; + } + off0 = 0; + }); + assert(idx == size); + + for ( idx = 0; idx < size; idx += BS ) + { + auto len_i = std::min( BS, size-idx ); + auto s = stride0; + + for ( auto i = idx; i < idx+len_i-1; i++) + { + if (scat[i+1]-scat[i] != s) + { + s = 0; + break; + } + } + + bs[idx] = s; + } +} + +void packm_tensor + ( + obj_t* a, + obj_t* p, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + // We begin by copying the fields of A. + bli_obj_alias_to( a, p ); + + // Get information about data types. + auto dt = bli_obj_dt( a ); + auto dt_tar = bli_obj_target_dt( a ); + auto dt_scalar = bli_obj_scalar_dt( a ); + auto dt_size = bli_dt_size( dt ); + + if ( dt_scalar != dt || dt_tar != dt ) + bli_abort(); + + // Extract various fields from the control tree. + auto bmult_id_m = bli_cntl_packm_params_bmid_m( cntl ); + auto bmult_id_n = bli_cntl_packm_params_bmid_n( cntl ); + auto schema = bli_cntl_packm_params_pack_schema( cntl ); + auto bmult_m_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx ); + auto bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx ); + auto bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx ); + + if ( schema != BLIS_PACKED_ROW_PANELS && + schema != BLIS_PACKED_COL_PANELS ) + bli_abort(); + + // Store the pack schema to the object. + bli_obj_set_pack_schema( schema, p ); + + // Clear the conjugation field from the object since matrix packing + // in BLIS is deemed to take care of all conjugation necessary. + bli_obj_set_conj( BLIS_NO_CONJUGATE, p ); + + // If we are packing micropanels, mark P as dense. + bli_obj_set_uplo( BLIS_DENSE, p ); + + // Reset the view offsets to (0,0). + bli_obj_set_offs( 0, 0, p ); + + // Compute the dimensions padded by the dimension multiples. These + // dimensions will be the dimensions of the packed matrices, including + // zero-padding, and will be used by the macro- and micro-kernels. + // We compute them by starting with the effective dimensions of A (now + // in P) and aligning them to the dimension multiples (typically equal + // to register blocksizes). This does waste a little bit of space for + // level-2 operations, but that's okay with us. + auto m_p = bli_obj_length( p ); + auto n_p = bli_obj_width( p ); + auto m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def ); + auto n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def ); + + // Save the padded dimensions into the packed object. It is important + // to save these dimensions since they represent the actual dimensions + // of the zero-padded matrix. + bli_obj_set_padded_dims( m_p_pad, n_p_pad, p ); + + // The "panel stride" of a micropanel packed object is interpreted as + // the distance between the (0,0) element of panel k and the (0,0) + // element of panel k+1. We use the padded width computed above to + // allow for zero-padding (if necessary/desired) along the far end + // of each micropanel (ie: the right edge of the matrix). Zero-padding + // can also occur along the long edge of the last micropanel if the m + // dimension of the matrix is not a whole multiple of MR. + auto ps_p = bmult_m_pack * n_p_pad; + + /* Compute the total number of iterations we'll need. */ + auto n_iter = m_p_pad / bmult_m_def; + + // Store the strides and panel dimension in P. + bli_obj_set_strides( 1, bmult_m_pack, p ); + bli_obj_set_imag_stride( 1, p ); + bli_obj_set_panel_dim( bmult_m_def, p ); + bli_obj_set_panel_stride( ps_p, p ); + bli_obj_set_panel_length( bmult_m_def, p ); + bli_obj_set_panel_width( n_p, p ); + + // Compute the size of the packed buffer. + auto size_p = ps_p * n_iter * dt_size; + if ( size_p == 0 ) return; + + // Compute the size of the scatter and block-scatter vectors to the total. + // It is never necessary to add padding for alignment because: + // 1) ps_p is always even + // 2) dt_size is a power of two >= 4 + // 3) the alignment of the scatter vectors is at most 8 + auto scat_size = 2 * (m_p + n_p) * sizeof(inc_t); + + // Update the buffer address in p to point to the buffer associated + // with the mem_t entry acquired from the memory broker (now cached in + // the control tree node). + auto p_cast = (char*)bli_packm_alloc( size_p + scat_size, rntm, cntl, thread ); + bli_obj_set_buffer( p_cast, p ); + + // Get the addresses of the scatter and block-scatter vectors. These are + // placed directly after the packed matrix buffer. + auto rscat = (inc_t*)(p_cast + size_p); + auto rbs = rscat + m_p; + auto cscat = rbs + m_p; + auto cbs = cscat + n_p; + + auto a_cast = (char*)bli_obj_buffer_at_off( a ); + auto panel_dim_off = bli_obj_row_off( a ); + auto panel_len_off = bli_obj_col_off( a ); + auto conja = bli_obj_conj_status( a ); + + auto params = (packm_tensor_params_t*)bli_obj_pack_params( a ); + auto ndim_m = params->ndim_m; + auto ndim_n = params->ndim_n; + auto len_m = params->len_m; + auto len_n = params->len_n; + auto stride_m = params->stride_m; + auto stride_n = params->stride_n; + + obj_t kappa_local; + auto kappa_cast = (char*)bli_packm_scalar( &kappa_local, p ); + + auto packm_nb_ker = packm_ckx_nb_ukrs[ dt ]; + auto packm_ss_ker = packm_ckx_ss_ukrs[ dt ]; + + a_cast -= ( panel_dim_off * stride_m[0] + + panel_len_off * stride_n[0] ) * dt_size; + + /* Fill in the scatter and block-scatter vectors. This is done single-threaded for now. */ + if ( bli_thread_am_ochief( thread ) ) + { + fill_scatter + ( + ndim_m, + len_m, + stride_m, + bmult_m_def, + panel_dim_off, + m_p, + rscat, + rbs + ); + + fill_scatter + ( + ndim_n, + len_n, + stride_n, + BS_K, + panel_len_off, + n_p, + cscat, + cbs + ); + } + + /* Wait for the scatter vectors to be done. */ + bli_thread_barrier( thread ); + + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ + auto nt = bli_thread_n_way( thread ); + auto tid = bli_thread_work_id( thread ); + + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ + dim_t it_start, it_end, it_inc; + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); + + /* Iterate over every logical micropanel in the source matrix. */ + for ( auto it = 0; it < n_iter; it += 1 ) + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) + { + auto panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def ); + + auto p_begin = p_cast + it*ps_p*dt_size; + auto inca = rbs[ it*bmult_m_def ]; + + if ( inca ) + { + auto a_begin = a_cast + rscat[ it*bmult_m_def ]*dt_size; + + packm_nb_ker( conja, + panel_dim_i, + n_p, + bmult_m_def, + n_p_pad, + kappa_cast, + a_begin, inca, cbs, cscat, + p_begin, bmult_m_pack ); + } + else + { + auto a_begin = a_cast; + auto rscat_use = rscat + it*bmult_m_def; + + packm_ss_ker( conja, + panel_dim_i, + n_p, + bmult_m_def, + n_p_pad, + kappa_cast, + a_begin, rscat_use, cscat, + p_begin, bmult_m_pack ); + } + } +} + +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +void PASTEMAC(ch,op) \ + ( \ + dim_t m, \ + dim_t n, \ + void* x, inc_t rs_x, inc_t cs_x, \ + void* b, \ + void* y, inc_t* rs_y, inc_t* cs_y \ + ) \ +{ \ + ctype* restrict x_cast = (ctype*)x; \ + ctype b_cast = *(ctype*)b; \ + ctype* restrict y_cast = (ctype*)y; \ +\ + if ( PASTEMAC(ch,eq0)( b_cast ) ) \ + { \ + for ( auto i = 0; i < m; i++ ) \ + for ( auto j = 0; j < n; j++ ) \ + PASTEMAC(ch,copys)( x_cast[ i*rs_x + j*cs_x ], y_cast[ rs_y[i] + cs_y[j] ] ); \ + } \ + else \ + { \ + for ( auto i = 0; i < m; i++ ) \ + for ( auto j = 0; j < n; j++ ) \ + PASTEMAC(ch,xpbys)( x_cast[ i*rs_x + j*cs_x ], b_cast, y_cast[ rs_y[i] + cs_y[j] ] ); \ + } \ +} + +INSERT_GENTFUNC_BASIC0(scatter_mxn); + +static decltype(&bli_sscatter_mxn) GENARRAY(scatter_mxn, scatter_mxn); + +void gemm_tensor + ( + obj_t* a, + obj_t* b, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + auto dt = bli_obj_dt( c ); + auto dt_size = bli_dt_size( dt ); + + auto m = bli_obj_length( c ); + auto n = bli_obj_width( c ); + auto k = bli_obj_width( a ); + + auto a_cast = (char*)bli_obj_buffer_at_off( a ); + auto pd_a = bli_obj_panel_dim( a ); + auto ps_a = bli_obj_panel_stride( a ); + + auto b_cast = (char*)bli_obj_buffer_at_off( b ); + auto pd_b = bli_obj_panel_dim( b ); + auto ps_b = bli_obj_panel_stride( b ); + + auto c_cast = (char*)bli_obj_buffer_at_off( c ); + auto rs_c0 = bli_obj_row_stride( c ); + auto cs_c0 = bli_obj_col_stride( c ); + auto off_m = bli_obj_row_off( c ); + auto off_n = bli_obj_col_off( c ); + + auto params = (gemm_tensor_params_t*)bli_obj_ker_params( c ); + auto ndim_m = params->ndim_m; + auto ndim_n = params->ndim_n; + auto len_m = params->len_m; + auto len_n = params->len_n; + auto stride_m = params->stride_m; + auto stride_n = params->stride_n; + + if ( rs_c0 != stride_m[0] || cs_c0 != stride_n[0] ) + { + std::swap( ndim_m, ndim_n ); + std::swap( len_m, len_n ); + std::swap( stride_m, stride_n ); + } + + /* If any dimension is zero, return immediately. */ + if ( bli_zero_dim3( m, n, k ) ) return; + + c_cast -= ( off_m * stride_m[0] + + off_n * stride_n[0] ) * dt_size; + + // Detach and multiply the scalars attached to A and B. + // NOTE: We know that the internal scalars of A and B are already of the + // target datatypes because the necessary typecasting would have already + // taken place during bli_packm_init(). + obj_t scalar_a; + obj_t scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + // NOTE: We know that scalar_b is of type dt due to the above code + // that casts the scalars of A and B to dt via scalar_a and scalar_b, + // and we know that the internal scalar in C is already of the type dt + // due to the casting in the implementation of bli_obj_scalar_attach(). + auto alpha_cast = (char*)bli_obj_internal_scalar_buffer( &scalar_b ); + auto beta_cast = (char*)bli_obj_internal_scalar_buffer( c ); + + /* Alias some constants to simpler names. */ + auto MR = pd_a; + auto NR = pd_b; + + /* Query the context for the micro-kernel address and cast it to its + function pointer type. */ + auto gemm_ukr = (gemm_ukr_vft)bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); + + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ + char ct[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); + auto col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); + auto rs_ct = ( col_pref ? 1 : NR ); + auto cs_ct = ( col_pref ? MR : 1 ); + auto zero = (char*)bli_obj_buffer_for_const( dt, &BLIS_ZERO ); + + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ + + auto scat_size = 2 * (m + n) * sizeof(inc_t); + auto rscat_c = (inc_t*)bli_packm_alloc_ex( scat_size, BLIS_BUFFER_FOR_GEN_USE, rntm, cntl, thread ); + auto rbs_c = rscat_c + m; + auto cscat_c = rbs_c + m; + auto cbs_c = cscat_c + n; + + /* Fill in the scatter and block-scatter vectors. This is done single-threaded for now. */ + if ( bli_thread_am_ochief( thread ) ) + { + fill_scatter + ( + ndim_m, + len_m, + stride_m, + MR, + off_m, + m, + rscat_c, + rbs_c + ); + + fill_scatter + ( + ndim_n, + len_n, + stride_n, + NR, + off_n, + n, + cscat_c, + cbs_c + ); + } + + /* Wait for the scatter vectors to be done. */ + bli_thread_barrier( thread ); + + /* Compute number of primary and leftover components of the m and n + dimensions. */ + auto n_iter = n / NR; + auto n_left = n % NR; + + auto m_iter = m / MR; + auto m_left = m % MR; + + if ( n_left ) ++n_iter; + if ( m_left ) ++m_iter; + + /* Determine some increments used to step through A, B, and C. */ + auto rstep_a = ps_a * dt_size; + auto cstep_b = ps_b * dt_size; + + /* Save the virtual microkernel address and the params. */ + auxinfo_t aux; + bli_auxinfo_set_ukr( (void*)gemm_ukr, &aux ); + bli_auxinfo_set_params( params, &aux ); + + /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + loop around the microkernel. Here we query the thrinfo_t node for the + 1st (ir) loop around the microkernel. */ + auto caucus = bli_thrinfo_sub_node( thread ); + + /* Query the number of threads and thread ids for each loop. */ + auto jr_nt = bli_thread_n_way( thread ); + auto jr_tid = bli_thread_work_id( thread ); + auto ir_nt = bli_thread_n_way( caucus ); + auto ir_tid = bli_thread_work_id( caucus ); + + /* Determine the thread range and increment for the 2nd and 1st loops. + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. */ + dim_t jr_start, jr_end; + dim_t ir_start, ir_end; + dim_t jr_inc, ir_inc; + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); + bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); + + /* Loop over the n dimension (NR columns at a time). */ + for ( auto j = jr_start; j < jr_end; j += jr_inc ) + { + auto b1 = b_cast + j * cstep_b; + + auto n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + + /* Initialize our next panel of B to be the current panel of B. */ + auto b2 = b1; + + /* Loop over the m dimension (MR rows at a time). */ + for ( auto i = ir_start; i < ir_end; i += ir_inc ) + { + auto a1 = a_cast + i * rstep_a; + auto rscat_c1 = rscat_c + i * MR; + auto rbs_c1 = rbs_c + i * MR; + auto cscat_c1 = cscat_c + j * NR; + auto cbs_c1 = cbs_c + j * NR; + + auto m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + + /* Compute the addresses of the next panels of A and B. */ + auto a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) + { + a2 = a_cast; + b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) + b2 = b_cast; + } + + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ + bli_auxinfo_set_next_a( a2, &aux ); + bli_auxinfo_set_next_b( b2, &aux ); + + auto rs_c = *rbs_c1; + auto cs_c = *cbs_c1; + + if ( rs_c && cs_c ) + { + auto c11 = c_cast + ( *rscat_c1 + *cscat_c1 ) * dt_size; + + /* Invoke the gemm micro-kernel. */ + gemm_ukr + ( + m_cur, + n_cur, + k, + alpha_cast, + a1, + b1, + beta_cast, + c11, rs_c, cs_c, + &aux, + cntx + ); + } + else + { + /* Invoke the gemm micro-kernel. */ + gemm_ukr + ( + MR, + NR, + k, + alpha_cast, + a1, + b1, + zero, + &ct, rs_ct, cs_ct, + &aux, + cntx + ); + + /* Scatter to C. */ + scatter_mxn[ dt ] + ( + m_cur, n_cur, + &ct, rs_ct, cs_ct, + beta_cast, + c_cast, rscat_c1, cscat_c1 + ); + } + } + } +} + +static bool has_unit_stride( const std::vector& stride ) +{ + for ( auto s : stride ) + if ( s == 1 ) + return true; + return false; +} + +void tcontract( num_t dt, const std::vector& m, const std::vector& n, const std::vector& k, + const void* alpha, const void* a, std::vector rs_a, std::vector cs_a, + const void* b, std::vector rs_b, std::vector cs_b, + const void* beta, void* c, std::vector rs_c, std::vector cs_c ) +{ + if ( rs_a.size() != m.size() || + rs_b.size() != k.size() || + rs_c.size() != m.size() ) + bli_check_error_code( BLIS_INVALID_ROW_STRIDE ); + + if ( cs_a.size() != k.size() || + cs_b.size() != n.size() || + cs_c.size() != n.size() ) + bli_check_error_code( BLIS_INVALID_COL_STRIDE ); + + dim_t m_mat = 1; + dim_t n_mat = 1; + dim_t k_mat = 1; + for ( auto& i : m ) m_mat *= i; + for ( auto& i : n ) n_mat *= i; + for ( auto& i : k ) k_mat *= i; + + auto& stride_m = has_unit_stride( rs_c ) ? rs_c : rs_a; + for ( int i = 1;i < m.size(); i++ ) + for ( int j = 0;j < m.size()-i; j++ ) + if ( stride_m[j] > stride_m[j+1] ) + { + std::swap( rs_a[j], rs_a[j+1] ); + std::swap( rs_c[j], rs_c[j+1] ); + } + + auto& stride_n = has_unit_stride( cs_c ) ? cs_c : cs_b; + for ( int i = 1;i < n.size(); i++ ) + for ( int j = 0;j < n.size()-i; j++ ) + if ( stride_n[j] > stride_n[j+1] ) + { + std::swap( cs_b[j], cs_b[j+1] ); + std::swap( cs_c[j], cs_c[j+1] ); + } + + auto& stride_k = has_unit_stride( cs_a ) ? cs_a : rs_b; + for ( int i = 1;i < k.size(); i++ ) + for ( int j = 0;j < k.size()-i; j++ ) + if ( stride_k[j] > stride_k[j+1] ) + { + std::swap( cs_a[j], cs_a[j+1] ); + std::swap( rs_b[j], rs_b[j+1] ); + } + + if ( rs_a.empty() ) rs_a.push_back( 1 ); + if ( cs_a.empty() ) cs_a.push_back( 1 ); + if ( rs_b.empty() ) rs_b.push_back( 1 ); + if ( cs_b.empty() ) cs_b.push_back( 1 ); + if ( rs_c.empty() ) rs_c.push_back( 1 ); + if ( cs_c.empty() ) cs_c.push_back( 1 ); + + obj_t a_o, b_o, c_o; + bli_obj_create_with_attached_buffer( dt, m_mat, k_mat, const_cast(a), rs_a[0], cs_a[0], &a_o ); + bli_obj_create_with_attached_buffer( dt, k_mat, n_mat, const_cast(b), rs_b[0], cs_b[0], &b_o ); + bli_obj_create_with_attached_buffer( dt, m_mat, n_mat, c , rs_c[0], cs_c[0], &c_o ); + + packm_tensor_params_t params_a( m.size(), m.data(), rs_a.data(), + k.size(), k.data(), cs_a.data() ); + packm_tensor_params_t params_b( n.size(), n.data(), cs_b.data(), + k.size(), k.data(), rs_b.data() ); + gemm_tensor_params_t params_c( m.size(), m.data(), rs_c.data(), + n.size(), n.data(), cs_c.data() ); + + bli_obj_set_pack_fn( packm_tensor, &a_o ); + bli_obj_set_pack_fn( packm_tensor, &b_o ); + bli_obj_set_ker_fn( gemm_tensor, &c_o ); + bli_obj_set_pack_params( ¶ms_a, &a_o ); + bli_obj_set_pack_params( ¶ms_b, &b_o ); + bli_obj_set_ker_params( ¶ms_c, &c_o ); + + obj_t alpha_o, beta_o; + bli_obj_create_1x1_with_attached_buffer( dt, const_cast(alpha), &alpha_o ); + bli_obj_create_1x1_with_attached_buffer( dt, const_cast(beta), &beta_o ); + + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + bli_rntm_disable_l3_sup( &rntm ); + + bli_gemm_ex( &alpha_o, &a_o, &b_o, &beta_o, &c_o, NULL, &rntm ); +} + +int main() +{ + auto N = 5; + + gint_t ndim_a = 4; + gint_t ndim_b = 4; + gint_t ndim_c = 4; + + std::vector len_a(ndim_a, N); + std::vector len_b(ndim_b, N); + std::vector len_c(ndim_c, N); + + std::vector stride_a(ndim_a, 1); + std::vector stride_b(ndim_b, 1); + std::vector stride_c(ndim_c, 1); + for ( gint_t i = 1; i < ndim_a; i++ ) + stride_a[i] = stride_a[i-1] * len_a[i - 1]; + for ( gint_t i = 1; i < ndim_b; i++ ) + stride_b[i] = stride_b[i-1] * len_b[i - 1]; + for ( gint_t i = 1; i < ndim_c; i++ ) + stride_c[i] = stride_c[i-1] * len_c[i - 1]; + + std::vector dim_a(ndim_a); + std::vector dim_b(ndim_b); + std::vector dim_c(ndim_c); + std::iota(dim_a.begin(), dim_a.end(), 0); + std::iota(dim_b.begin(), dim_b.end(), 0); + std::iota(dim_c.begin(), dim_c.end(), 0); + + for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ ) + do + do + do + { + auto dt = ( num_t )dt_; + + auto ndim_m = (ndim_a + ndim_c - ndim_b)/2; + auto ndim_k = (ndim_a + ndim_b - ndim_c)/2; + + std::vector m(len_a.begin(), len_a.begin()+ndim_m); + std::vector n(len_b.begin()+ndim_k, len_b.end()); + std::vector k(len_b.begin(), len_b.begin()+ndim_k); + + std::vector rs_a(stride_a.begin(), stride_a.begin()+ndim_m); + std::vector cs_a(stride_a.begin()+ndim_m, stride_a.end()); + std::vector rs_b(stride_b.begin(), stride_b.begin()+ndim_k); + std::vector cs_b(stride_b.begin()+ndim_k, stride_b.end()); + std::vector rs_c(stride_c.begin(), stride_c.begin()+ndim_m); + std::vector cs_c(stride_c.begin()+ndim_m, stride_c.end()); + + dim_t m_tot = 1; + dim_t n_tot = 1; + dim_t k_tot = 1; + for ( auto i : m ) m_tot *= i; + for ( auto i : n ) n_tot *= i; + for ( auto i : k ) k_tot *= i; + + obj_t a, b, c, c_ref, norm; + + bli_obj_create( dt, m_tot*k_tot, 1, 1, 1, &a ); + bli_obj_create( dt, k_tot*n_tot, 1, 1, 1, &b ); + bli_obj_create( dt, m_tot*n_tot, 1, 1, 1, &c ); + bli_obj_create( dt, m_tot*n_tot, 1, 1, 1, &c_ref ); + bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm ); + + bli_randv( &a ); + bli_randv( &b ); + bli_randv( &c ); + bli_copyv( &c, &c_ref ); + + tcontract( dt, m, n, k, + bli_obj_buffer_for_const( dt, &BLIS_ONE ), + bli_obj_buffer( &a ), rs_a, cs_a, + bli_obj_buffer( &b ), rs_b, cs_b, + bli_obj_buffer_for_const( dt, &BLIS_ZERO ), + bli_obj_buffer( &c ), rs_c, cs_c ); + + tcontract_ref( dt, m, n, k, + bli_obj_buffer_for_const( dt, &BLIS_ONE ), + bli_obj_buffer( &a ), rs_a, cs_a, + bli_obj_buffer( &b ), rs_b, cs_b, + bli_obj_buffer_for_const( dt, &BLIS_ZERO ), + bli_obj_buffer( &c_ref ), rs_c, cs_c ); + + bli_subv( &c_ref, &c ); + bli_normfv( &c, &norm ); + + double normr, normi; + bli_getsc( &norm, &normr, &normi ); + + printf("dt: %d, dim_a: [%d,%d,%d,%d], dim_b: [%d,%d,%d,%d], dim_c: [%d,%d,%d,%d], norm: %g\n", + dt, dim_a[0], dim_a[1], dim_a[2], dim_a[3], + dim_b[0], dim_b[1], dim_b[2], dim_b[3], + dim_c[0], dim_c[1], dim_c[2], dim_c[3], + normr / std::sqrt( bli_obj_vector_dim( &c ) ) ); + + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + bli_obj_free( &c_ref ); + } + while (std::next_permutation(dim_a.begin(), dim_a.end())); + while (std::next_permutation(dim_b.begin(), dim_b.end())); + while (std::next_permutation(dim_c.begin(), dim_c.end())); +} + diff --git a/test/tensor_contraction/tcontract_ref.cxx b/test/tensor_contraction/tcontract_ref.cxx new file mode 100644 index 0000000000..b4cd07f903 --- /dev/null +++ b/test/tensor_contraction/tcontract_ref.cxx @@ -0,0 +1,67 @@ +#include "tcontract_ref.hpp" + +template +void tcontract_ref( const std::vector& m, const std::vector& n, const std::vector& k, + const void* alpha, const void* a, const std::vector& rs_a, const std::vector& cs_a, + const void* b, const std::vector& rs_b, const std::vector& cs_b, + const void* beta, void* c, const std::vector& rs_c, const std::vector& cs_c ) +{ + auto alpha_cast = *( T* )alpha; + auto beta_cast = *( T* )beta; + auto a_cast = ( T* )a; + auto b_cast = ( T* )b; + auto c_cast = ( T* )c; + + for_each(m.size(), m.data(), a_cast, rs_a.data(), c_cast, rs_c.data(), + [&] + { + for_each(n.size(), n.data(), b_cast, cs_b.data(), c_cast, cs_c.data(), + [&] + { + auto ab = convert(0.0); + + for_each(k.size(), k.data(), a_cast, cs_a.data(), b_cast, rs_b.data(), + [&] + { + ab += (*a_cast) * (*b_cast); + }); + + if ( beta_cast == convert(0.0) ) + { + *c_cast = alpha_cast * ab; + } + else + { + *c_cast = alpha_cast * ab + beta_cast * (*c_cast); + } + }); + + assert(b_cast == b); + }); + + assert(a_cast == a); + assert(c_cast == c); +} + +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +static auto PASTEMAC(ch,op) = &tcontract_ref; + +INSERT_GENTFUNC_BASIC0(tcontract_ref); + +static decltype(&tcontract_ref) GENARRAY( tcontract_ref_impl, tcontract_ref ); + +void tcontract_ref( num_t dt, const std::vector& m, const std::vector& n, const std::vector& k, + const void* alpha, const void* a, const std::vector& rs_a, const std::vector& cs_a, + const void* b, const std::vector& rs_b, const std::vector& cs_b, + const void* beta, void* c, const std::vector& rs_c, const std::vector& cs_c ) +{ + tcontract_ref_impl[ dt ] + ( + m, n, k, + alpha, a, rs_a, cs_a, + b, rs_b, cs_b, + beta, c, rs_c, cs_c + ); +} + diff --git a/test/tensor_contraction/tcontract_ref.hpp b/test/tensor_contraction/tcontract_ref.hpp new file mode 100644 index 0000000000..99d4380dce --- /dev/null +++ b/test/tensor_contraction/tcontract_ref.hpp @@ -0,0 +1,100 @@ +#include "blis.h" +#include "complex_math.hpp" + +#include +#include +#include + +inline void increment(inc_t, gint_t) {} + +template +void increment(inc_t n, gint_t i, T& off, const inc_t* s, Args&... args) +{ + off += s[i]*n; + increment(n, i, args...); +} + +template +void for_each_impl(gint_t ndim, const dim_t* n, + dim_t off, dim_t len, + Body& body, + Args&... args) +{ + std::array i = {}; + assert( ndim <= i.size() ); + + if ( off ) + { + for ( gint_t k = 0; k < ndim; k++ ) + { + i[k] = off % n[k]; + off /= n[k]; + increment(i[k], k, args...); + } + } + + for ( dim_t pos = 0; pos < len; pos++ ) + { + body(); + + for ( gint_t k = 0; k < ndim; k++ ) + { + if ( i[k] == n[k]-1 ) + { + increment(-i[k], k, args...); + i[k] = 0; + } + else + { + increment(1, k, args...); + i[k]++; + break; + } + } + } +} + +template +void for_each(gint_t ndim, const dim_t* n, + dim_t off, dim_t len, + T& a, const inc_t* s_a, + Body&& body) +{ + for_each_impl( ndim, n, off, len, body, a, s_a ); +} + +template +void for_each(gint_t ndim, const dim_t* n, + dim_t off, dim_t len, + T& a, const inc_t* s_a, + T& b, const inc_t* s_b, + Body&& body) +{ + for_each_impl( ndim, n, off, len, body, a, s_a, b, s_b ); +} + +template +void for_each(gint_t ndim, const dim_t* n, + T& a, const inc_t* s_a, + Body&& body) +{ + dim_t len = 1; + for ( gint_t i = 0;i < ndim;i++ ) len *= n[i]; + for_each_impl( ndim, n, 0, len, body, a, s_a ); +} + +template +void for_each(gint_t ndim, const dim_t* n, + T& a, const inc_t* s_a, + T& b, const inc_t* s_b, + Body&& body) +{ + dim_t len = 1; + for ( gint_t i = 0;i < ndim;i++ ) len *= n[i]; + for_each_impl( ndim, n, 0, len, body, a, s_a, b, s_b ); +} + +void tcontract_ref( num_t dt, const std::vector& m, const std::vector& n, const std::vector& k, + const void* alpha, const void* a, const std::vector& rs_a, const std::vector& cs_a, + const void* b, const std::vector& rs_b, const std::vector& cs_b, + const void* beta, void* c, const std::vector& rs_c, const std::vector& cs_c ); diff --git a/test/test_axpbyv.c b/test/test_axpbyv.c new file mode 100644 index 0000000000..28be2542cb --- /dev/null +++ b/test/test_axpbyv.c @@ -0,0 +1,293 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + +//#define PRINT +#ifdef BLIS_ENABLE_CBLAS +//#define CHECK_CBLAS +#endif + +#ifdef CHECK_CBLAS +#include "cblas.h" +#endif + +/* + * BLIS interface API will be called by default. + * To call BLAS API, modify line 159 to '#if 0'. + * To call cblas API, modify line 159 to '#if 0'and define the + * macro 'CHECK_CBLAS' in line 44 + * + *Sample prototype for BLAS interface API is as follows: + * n alpha x incx beta y incy + *void daxpbyv_( int*, double*, double*, int*, double*, double*, int* ); + */ + +int main( int argc, char** argv ) +{ + obj_t x, y; + obj_t y_save; + obj_t alpha, beta; + dim_t n; + dim_t p; + dim_t p_begin, p_end, p_inc; + int n_input; + num_t dt_x, dt_y; + num_t dt_alpha, dt_beta; + int r, n_repeats; + num_t dt; + + double dtime; + double dtime_save; + double gflops; + + bli_init(); + + n_repeats = 3; + +#ifndef PRINT + p_begin = 40; + p_end = 4000; + p_inc = 40; + + n_input = -1; +#else + p_begin = 16; + p_end = 16; + p_inc = 1; + + n_input = 15; +#endif + +#if 1 + dt = BLIS_FLOAT; + //dt = BLIS_DOUBLE; +#else + //dt = BLIS_SCOMPLEX; + dt = BLIS_DCOMPLEX; +#endif + + + dt_x = dt_y = dt_alpha = dt_beta = dt; + + // Begin with initializing the last entry to zero so that + // matlab allocates space for the entire array once up-front. + for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ; +#ifdef BLIS + printf( "data_axpbyv_blis" ); +#else + printf( "data_axpbyv_%s", BLAS ); +#endif + printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )0, 0.0 ); + + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) + { + + if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); + else n = ( dim_t ) n_input; + + bli_obj_create( dt_alpha, 1, 1, 0, 0, &alpha ); + bli_obj_create( dt_beta, 1, 1, 0, 0, &beta ); + + bli_obj_create( dt_x, n, 1, 0, 0, &x ); + bli_obj_create( dt_y, n, 1, 0, 0, &y ); + bli_obj_create( dt_y, n, 1, 0, 0, &y_save ); + + bli_randm( &x ); + bli_randm( &y ); + + bli_setsc( (0.9/1.0), 0.2, &alpha ); + bli_setsc( -(1.1/1.0), 0.3, &beta ); + + bli_copym( &y, &y_save ); + + dtime_save = 1.0e9; + + for ( r = 0; r < n_repeats; ++r ) + { + bli_copym( &y_save, &y ); + + dtime = bli_clock(); + +#ifdef PRINT + bli_printm( "alpha", &alpha, "%4.1f", "" ); + bli_printm( "beta" , &beta, "%4.1f", "" ); + + bli_printm( "x", &x, "%4.1f", "" ); + bli_printm( "y", &y, "%4.1f", "" ); +#endif + +#ifdef BLIS + + bli_axpbyv( &alpha, + &x, + &beta, + &y ); +#else + if ( bli_is_float( dt ) ) + { + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + float alphap = *(( float * )bli_obj_buffer( &alpha )); + float betap = *(( float * )bli_obj_buffer( &beta )); + float* xp = bli_obj_buffer( &x ); + float* yp = bli_obj_buffer( &y ); +#ifdef CHECK_CBLAS + cblas_saxpby( nn, + alphap, + xp, incx, + betap, + yp, incy ); +#else + saxpby_( &nn, + &alphap, + xp, &incx, + &betap, + yp, &incy ); + +#endif + } + else if ( bli_is_double( dt ) ) + { + + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + double alphap = *(( double * )bli_obj_buffer( &alpha )); + double betap = *(( double * )bli_obj_buffer( &beta )); + double* xp = bli_obj_buffer( &x ); + double* yp = bli_obj_buffer( &y ); +#ifdef CHECK_CBLAS + cblas_daxpby( nn, + alphap, + xp, incx, + betap, + yp, incy ); +#else + daxpby_( &nn, + &alphap, + xp, &incx, + &betap, + yp, &incy ); +#endif + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + void* alphap = bli_obj_buffer( &alpha ); + void* betap = bli_obj_buffer( &beta ); + void* xp = bli_obj_buffer( &x ); + void* yp = bli_obj_buffer( &y ); +#ifdef CHECK_CBLAS + cblas_caxpby( nn, + alphap, + xp, incx, + betap, + yp, incy ); +#else + caxpby_( &nn, + ( scomplex* )alphap, + ( scomplex* )xp, &incx, + ( scomplex* )betap, + ( scomplex* )yp, &incy ); +#endif + } + else if ( bli_is_dcomplex( dt )) + { + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + void* alphap = bli_obj_buffer( &alpha ); + void* betap = bli_obj_buffer( &beta ); + void* xp = bli_obj_buffer( &x ); + void* yp = bli_obj_buffer( &y ); +#ifdef CHECK_CBLAS + cblas_zaxpby( nn, + alphap, + xp, incx, + betap, + yp, incy ); +#else + zaxpby_( &nn, + ( dcomplex* )alphap, + ( dcomplex* )xp, &incx, + ( dcomplex* )betap, + ( dcomplex* )yp, &incy ); +#endif + } +#endif + +#ifdef PRINT + bli_printm( "y after", &y, "%4.1f", "" ); + exit(1); +#endif + + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + gflops = ( 3.0 * n ) / ( dtime_save * 1.0e9 ); + +#ifdef BLIS + printf( "data_axpbyv_blis" ); +#else + printf( "data_axpbyv_%s", BLAS ); +#endif + printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )n, gflops ); + + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &x ); + bli_obj_free( &y ); + bli_obj_free( &y_save ); + } + + bli_finalize(); + + return 0; +} diff --git a/test/test_axpyv.c b/test/test_axpyv.c index 268e3ea0de..44a0d2d746 100644 --- a/test/test_axpyv.c +++ b/test/test_axpyv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,7 +33,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" // n alpha x incx y incy @@ -96,10 +100,11 @@ int main( int argc, char** argv ) printf( "data_axpyv_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); @@ -188,7 +193,7 @@ int main( int argc, char** argv ) printf( "data_axpyv_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )n, gflops ); bli_obj_free( &alpha ); diff --git a/test/test_copyv.c b/test/test_copyv.c new file mode 100644 index 0000000000..a85004f12d --- /dev/null +++ b/test/test_copyv.c @@ -0,0 +1,218 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" + +//#define BLIS_ACCURACY_TEST +#ifdef BLIS_ACCURACY_TEST + +bool scompare_result( int n, float *x, int incx, float *y, int incy ) +{ + for ( int i = 0; i < n; i++ ) + { + if ( (*x) != (*y) ) + { + printf( "%4f != %4f at location %d\n", *x, *y, i ); + return FALSE; + } + x += incx; + y += incy; + } + return TRUE; +} + +bool dcompare_result( int n, double *x, int incx, double *y, int incy ) +{ + for ( int i = 0; i < n; i++ ) + { + if ( (*x) != (*y) ) + { + printf( "%4f != %4f at location %d\n", *x, *y, i ); + return FALSE; + } + x += incx; + y += incy; + } + return TRUE; +} + +#endif + + +int main( int argc, char** argv ) +{ + obj_t x, y; + dim_t n; + dim_t p; + dim_t p_begin, p_end, p_inc; + int n_input, sizeof_dt; + int r, n_repeats; + num_t dt; + + double dtime; + double dtime_save; + double gbps; + + //bli_init(); + + n_repeats = 100000; + +#ifndef PRINT + p_begin = 200; + p_end = 100000; + p_inc = 200; + + n_input = -1; +#else + p_begin = 16; + p_end = 16; + p_inc = 1; + + n_input = 16; +#endif + +#if 1 + // dt = BLIS_FLOAT; + dt = BLIS_DOUBLE; +#else + //dt = BLIS_SCOMPLEX; + dt = BLIS_DCOMPLEX; +#endif + + if ( dt == BLIS_FLOAT ) sizeof_dt = sizeof( float ); + else if ( dt == BLIS_DOUBLE ) sizeof_dt = sizeof( double ); + + printf( "executable\t n\t GBs per sec\n" ); + + for ( p = p_begin; p <= p_end; p += p_inc ) + { + + if ( n_input < 0 ) n = p * ( dim_t )abs( n_input ); + else n = ( dim_t ) n_input; + + bli_obj_create( dt, n, 1, 0, 0, &x ); + bli_obj_create( dt, n, 1, 0, 0, &y ); + + bli_randm( &x ); + + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + + dtime = bli_clock(); + +#ifdef BLIS + bli_copyv( &x, + &y ); +#else + if ( bli_is_float( dt ) ) + { + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + float* xp = bli_obj_buffer( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + float* yp = bli_obj_buffer( &y ); + + scopy_( &nn, + xp, &incx, + yp, &incy ); + + } + else if ( bli_is_double( dt ) ) + { + + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + double* xp = bli_obj_buffer( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + double* yp = bli_obj_buffer( &y ); + + dcopy_( &nn, + xp, &incx, + yp, &incy ); + } +#endif + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + +#ifdef BLIS_ACCURACY_TEST + if ( dt == BLIS_FLOAT ) + { + int nn = bli_obj_length( &x ); + int incx = bli_obj_vector_inc( &x ); + float* xp = bli_obj_buffer( &x ); + int incy = bli_obj_vector_inc( &y ); + float* yp = bli_obj_buffer( &y ); + if ( scompare_result( nn, xp, incx, yp, incy ) ) + printf( "Copy Successful\n" ); + else + printf( "ALERT!!! Copy Failed\n" ); + } + if ( dt == BLIS_DOUBLE ) + { + int nn = bli_obj_length( &x ); + int incx = bli_obj_vector_inc( &x ); + double* xp = bli_obj_buffer( &x ); + int incy = bli_obj_vector_inc( &y ); + double* yp = bli_obj_buffer( &y ); + if ( dcompare_result( nn, xp, incx, yp, incy ) ) + printf( "Copy Successful\n" ); + else + printf( "ALERT!!! Copy Failed\n" ); + } +#endif + } + + // Size of the vectors are incrementd by 1000, to test wide range of inputs. + if ( p >= 1000 ) p_inc = 1000; + if ( p >= 10000 ) p_inc = 10000; + gbps = ( n * sizeof_dt ) / ( dtime_save * 1.0e9 ); + +#ifdef BLIS + printf( "data_copyv_blis\t" ); +#else + printf( "data_copyv_%s\t", BLAS ); +#endif + printf( "%4lu\t %7.2f\n", + ( unsigned long )n, gbps ); + + bli_obj_free( &x ); + bli_obj_free( &y ); + } + + //bli_finalize(); + + return 0; +} diff --git a/test/test_dotv.c b/test/test_dotv.c index ea0f7e4c58..dfba9ef0fe 100644 --- a/test/test_dotv.c +++ b/test/test_dotv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,7 +33,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" // res n x incx y incy @@ -93,10 +97,11 @@ int main( int argc, char** argv ) printf( "data_dotv_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); @@ -172,7 +177,7 @@ int main( int argc, char** argv ) printf( "data_dotv_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )n, gflops ); bli_obj_free( &x ); diff --git a/test/test_gemm.c b/test/test_gemm.c index 5d6b6aa9af..9ba97a6395 100644 --- a/test/test_gemm.c +++ b/test/test_gemm.c @@ -32,7 +32,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" @@ -105,12 +109,13 @@ int main( int argc, char** argv ) printf( "data_gemm_%s", BLAS ); #endif printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); else m = ( dim_t ) m_input; @@ -287,7 +292,7 @@ int main( int argc, char** argv ) printf( "data_gemm_%s", BLAS ); #endif printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )k, ( unsigned long )n, gflops ); diff --git a/test/test_gemm3m.c b/test/test_gemm3m.c new file mode 100644 index 0000000000..8e70429013 --- /dev/null +++ b/test/test_gemm3m.c @@ -0,0 +1,352 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" +#include "cblas.h" + +#define CBLAS +//#define FILE_IN_OUT +//#define PRINT +#define MATRIX_INITIALISATION + +int main( int argc, char** argv ) +{ + obj_t a, b, c; + obj_t c_save; + obj_t alpha, beta; + dim_t m, n, k; + dim_t p; + dim_t p_begin, p_end, p_inc; + int m_input, n_input, k_input; + num_t dt; + int r, n_repeats; + trans_t transa; + trans_t transb; + f77_char f77_transa; + f77_char f77_transb; + + double dtime; + double dtime_save; + double gflops; +#ifdef FILE_IN_OUT + FILE* fin = NULL; + FILE* fout = NULL; +#endif + //bli_init(); + //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); + + n_repeats = 3; + +#ifndef PRINT + p_begin = 200; + p_end = 2000; + p_inc = 100; + + m_input = -1; + n_input = -1; + k_input = -1; +#else + p_begin = 16; + p_end = 16; + p_inc = 1; + + m_input = 5; + k_input = 6; + n_input = 4; +#endif + + dt = BLIS_SCOMPLEX; + //dt = BLIS_DCOMPLEX; + + transa = BLIS_NO_TRANSPOSE; + transb = BLIS_NO_TRANSPOSE; + + bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + + // printf("BLIS Library version is : %s\n", bli_info_get_version_str()); + +#ifdef FILE_IN_OUT + if ( argc < 3 ) + { + printf( "Usage: ./test_gemm_XX.x input.csv output.csv\n" ); + exit(1); + } + fin = fopen( argv[1], "r" ); + if ( fin == NULL ) + { + printf( "Error opening the file %s\n", argv[1] ); + exit(1); + } + fout = fopen( argv[2], "w" ); + if ( fout == NULL ) + { + printf( "Error opening output file %s\n", argv[2] ); + exit(1); + } + + fprintf( fout, "m\t k\t n\t cs_a\t cs_b\t cs_c\t gflops\t GEMM_Algo\n" ); + printf( "~~~~~~~~~~_BLAS\t m\t k\t n\t cs_a\t cs_b\t cs_c \t gflops\t GEMM_Algo\n" ); + + inc_t cs_a; + inc_t cs_b; + inc_t cs_c; + + while ( fscanf(fin, "%lld %lld %lld %lld %lld %lld\n", &m, &k, &n, &cs_a, &cs_b, &cs_c) == 6 ) + { + if ( ( m > cs_a ) || + ( k > cs_b ) || + ( m > cs_c ) ) continue; // leading dimension should be greater than number of rows + + bli_obj_create( dt, 1, 1, 0, 0, &alpha); + bli_obj_create( dt, 1, 1, 0, 0, &beta ); + + bli_obj_create( dt, m, k, 1, cs_a, &a ); + bli_obj_create( dt, k, n, 1, cs_b, &b ); + bli_obj_create( dt, m, n, 1, cs_c, &c ); + bli_obj_create( dt, m, n, 1, cs_c, &c_save ); +#ifdef MATRIX_INITIALISATION + bli_randm( &a ); + bli_randm( &b ); + bli_randm( &c ); +#endif + bli_obj_set_conjtrans( transa, &a); + bli_obj_set_conjtrans( transb, &b); + + //bli_setsc( 0.0, -1, &alpha ); + //bli_setsc( 0.0, 1, &beta ); + + bli_setsc( -1, 0.0, &alpha ); + bli_setsc( 1, 0.0, &beta ); + +#else + for ( p = p_begin; p <= p_end; p += p_inc ) + { + if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); + else m = ( dim_t ) m_input; + if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); + else n = ( dim_t ) n_input; + if ( k_input < 0 ) k = p * ( dim_t )abs(k_input); + else k = ( dim_t ) k_input; + + bli_obj_create( dt, 1, 1, 0, 0, &alpha ); + bli_obj_create( dt, 1, 1, 0, 0, &beta ); + + bli_obj_create( dt, m, k, 0, 0, &a ); + bli_obj_create( dt, k, n, 0, 0, &b ); + bli_obj_create( dt, m, n, 0, 0, &c ); + bli_obj_create( dt, m, n, 0, 0, &c_save ); +#ifdef MATRIX_INITIALISATION + + bli_randm( &a ); + bli_randm( &b ); + bli_randm( &c ); +#endif + bli_obj_set_conjtrans( transa, &a ); + bli_obj_set_conjtrans( transb, &b ); + + bli_setsc( (0.9/1.0), 0.2, &alpha ); + bli_setsc( -(1.1/1.0), 0.3, &beta ); + +#endif + bli_copym( &c, &c_save ); + + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + bli_copym( &c_save, &c ); + + dtime = bli_clock(); + + +#ifdef PRINT + bli_printm( "a", &a, "%4.1f", "" ); + bli_printm( "b", &b, "%4.1f", "" ); + bli_printm( "c", &c, "%4.1f", "" ); +#endif + +#ifndef CBLAS + + if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + scomplex* bp = bli_obj_buffer( &b ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* cp = bli_obj_buffer( &c ); + + cgemm3m_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* bp = bli_obj_buffer( &b ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* cp = bli_obj_buffer( &c ); + + zgemm3m_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } +#else + if ( bli_is_scomplex( dt ) ) + { + scomplex* ap = bli_obj_buffer( &a ); + scomplex* bp = bli_obj_buffer( &b ); + scomplex* cp = bli_obj_buffer( &c ); + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* betap = bli_obj_buffer( &beta ); + cblas_cgemm3m( CblasColMajor, + CblasNoTrans, + CblasNoTrans, + m, + n, + k, + (const void*)alphap, + ap, m, + bp, k, + (const void*)betap, + cp, m ); + } + else if (bli_is_dcomplex(dt)) + { + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* bp = bli_obj_buffer( &b ); + dcomplex* cp = bli_obj_buffer( &c ); + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* betap = bli_obj_buffer( &beta ); + cblas_zgemm3m( CblasColMajor, + CblasNoTrans, + CblasNoTrans, + m, + n, + k, + (const void*)alphap, + ap, m, + bp, k, + (const void*)betap, + cp, m ); + } +#endif + +#ifdef PRINT + bli_printm( "c after", &c, "%4.6f", "" ); + exit(1); +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); + + gflops *= 4.0; //to represent complex ops in gflops + +#ifdef BLIS + printf( "data_gemm_blis" ); +#else + printf( "data_gemm_%s", BLAS ); +#endif + +#ifdef FILE_IN_OUT + + printf("%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", \ + ( unsigned long )m, + ( unsigned long )k, + ( unsigned long )n, (unsigned long)cs_a, (unsigned long)cs_b, (unsigned long)cs_c, gflops); + + + fprintf(fout, "%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f \n", \ + ( unsigned long )m, + ( unsigned long )k, + ( unsigned long )n, (unsigned long)cs_a, (unsigned long)cs_b, (unsigned long)cs_c, gflops); + fflush(fout); + +#else + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )m, + ( unsigned long )k, + ( unsigned long )n, gflops ); +#endif + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + bli_obj_free( &c_save ); + } + + //bli_finalize(); +#ifdef FILE_IN_OUT + fclose( fin ); + fclose( fout ); +#endif + return 0; +} diff --git a/test/test_gemm_batch.c b/test/test_gemm_batch.c new file mode 100644 index 0000000000..5660e4150e --- /dev/null +++ b/test/test_gemm_batch.c @@ -0,0 +1,584 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + +//#define CHECK_CBLAS +#ifdef CHECK_CBLAS +#include "cblas.h" +#endif + +/* Format for FILE input + * For each input set, first line contains 'storage scheme' + * and 'group count' seperated by space. + * Following 'group_count' number of lines contains all the parameters of + * each group separated by space in each line in the following order: + * tA tB m n k lda ldb ldc alpha_r alpha_i beta_r beta_i group_size + * + * Example: + * c 2 + * n n 4 8 4 4 4 4 1.1 0.0 0.9 0.0 2 + * n n 3 3 6 3 6 3 1.0 0.0 2.0 0.0 2 + * + */ + +//#define FILE_IN_OUT +#ifndef FILE_IN_OUT +#define GRP_COUNT 2 +#endif + +//#define PRINT + +int main( int argc, char** argv ) +{ + num_t dt; + + char stor_scheme; + dim_t i, j, idx; + dim_t r, n_repeats; + + double dtime; + double dtime_save; + double gflops; + + dim_t total_count = 0; + +#if 1 + dt = BLIS_FLOAT; + //dt = BLIS_DOUBLE; +#else + dt = BLIS_SCOMPLEX; + //dt = BLIS_DCOMPLEX; +#endif + + n_repeats = 1; + +#ifdef FILE_IN_OUT + FILE* fin = NULL; + FILE* fout = NULL; + + if(argc < 3) + { + printf("Usage: ./test_gemm_batch_XX.x input.csv output.csv\n"); + exit(1); + } + + fin = fopen(argv[1], "r"); + if( fin == NULL ) + { + printf("Error opening input file %s \n", argv[1]); + exit(1); + } + + fout = fopen(argv[2], "w"); + if(fout == NULL) + { + printf("Error opening output file %s\n",argv[2]); + exit(1); + } + + dim_t GRP_COUNT; + + fprintf(fout, "m\t n\t k\t lda\t ldb\t ldc\t transa\t transb\t grp_size\n"); + + while(fscanf(fin, "%c %ld\n", &stor_scheme, &GRP_COUNT) == 2) + { + char transa[GRP_COUNT]; + char transb[GRP_COUNT]; + + dim_t m[GRP_COUNT]; + dim_t n[GRP_COUNT]; + dim_t k[GRP_COUNT]; + + dim_t lda[GRP_COUNT]; + dim_t ldb[GRP_COUNT]; + dim_t ldc[GRP_COUNT]; + + double alpha_real[GRP_COUNT]; + double alpha_imag[GRP_COUNT]; + double beta_real[GRP_COUNT]; + double beta_imag[GRP_COUNT]; + + dim_t group_size[GRP_COUNT]; + obj_t alpha[GRP_COUNT], beta[GRP_COUNT]; + + total_count = 0; + for(i = 0; i < GRP_COUNT; i++) + { + fscanf(fin, "%c %c %ld %ld %ld %ld %ld %ld %lf %lf %lf %lf %ld\n", &transa[i], &transb[i], &m[i], &n[i], &k[i], &lda[i], &ldb[i], &ldc[i], &alpha_real[i], &alpha_imag[i], &beta_real[i], &beta_imag[i], &group_size[i]); + + total_count += group_size[i]; + } +#else + printf("m\t n\t k\t lda\t ldb\t ldc\t transa\t transb\t grp_size\n"); + + stor_scheme = 'c'; + + dim_t m[GRP_COUNT] = {4, 3}; + dim_t n[GRP_COUNT] = {8, 3}; + dim_t k[GRP_COUNT] = {4, 6}; + + dim_t lda[GRP_COUNT] = {4, 3}; + dim_t ldb[GRP_COUNT] = {4, 6}; + dim_t ldc[GRP_COUNT] = {4, 3}; + + char transa[GRP_COUNT] = {'N', 'N'}; + char transb[GRP_COUNT] = {'N', 'N'}; + + double alpha_real[GRP_COUNT] = {1.1, 1.0}; + double alpha_imag[GRP_COUNT] = {0.0, 0.0}; + + double beta_real[GRP_COUNT] = {0.9, 2.0}; + double beta_imag[GRP_COUNT] = {0.0, 0.0}; + + dim_t group_size[GRP_COUNT] = {2,2}; + + obj_t alpha[GRP_COUNT], beta[GRP_COUNT]; + + total_count = 0; + for(i = 0; i < GRP_COUNT; i++) + total_count += group_size[i]; + +#endif + obj_t a[total_count], b[total_count]; + obj_t c[total_count], c_save[total_count]; + f77_int f77_m[GRP_COUNT], f77_n[GRP_COUNT], f77_k[GRP_COUNT]; + f77_int f77_lda[GRP_COUNT], f77_ldb[GRP_COUNT], f77_ldc[GRP_COUNT]; + f77_int f77_group_size[GRP_COUNT]; + f77_int f77_group_count = GRP_COUNT; +#ifdef CHECK_CBLAS + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa[GRP_COUNT]; + enum CBLAS_TRANSPOSE cblas_transb[GRP_COUNT]; + + if(stor_scheme == 'R' || stor_scheme == 'r') + cblas_order = CblasRowMajor; + else + cblas_order = CblasColMajor; + +#else + f77_char f77_transa[GRP_COUNT]; + f77_char f77_transb[GRP_COUNT]; + + if(stor_scheme == 'r' || stor_scheme == 'R' ) + { + printf("BLAS Interface doesn't support row-major order\n"); +#ifdef FILE_IN_OUT + continue; +#else + exit(1); +#endif + } +#endif + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + { + bli_obj_create(dt, 1, 1, 0, 0, &alpha[i]); + bli_obj_create(dt, 1, 1, 0, 0, &beta[i] ); + + bli_setsc(alpha_real[i], alpha_imag[i], &alpha[i]); + bli_setsc(beta_real[i], beta_imag[i], &beta[i] ); + + trans_t blis_transa, blis_transb; + if(transa[i] == 't' || transa[i] == 'T') + blis_transa = BLIS_TRANSPOSE; + else if (transa[i] == 'c' || transa[i] == 'C') + blis_transa = BLIS_CONJ_TRANSPOSE; + else if ( transa[i] == 'n' || transa[i] == 'N') + blis_transa = BLIS_NO_TRANSPOSE; + else + { + printf("Illegal transA setting %c for group %ld\n", transa[i], i); + exit(1); + } + + if(transb[i] == 't' || transb[i] == 'T') + blis_transb = BLIS_TRANSPOSE; + else if (transb[i] == 'c' || transb[i] == 'C') + blis_transb = BLIS_CONJ_TRANSPOSE; + else if (transb[i] == 'n' || transb[i] == 'N') + blis_transb = BLIS_NO_TRANSPOSE; + else + { + printf("Illegal transB setting %c for group %ld\n", transb[i], i); + exit(1); + } +#ifdef CHECK_CBLAS + if(bli_is_trans( blis_transa )) + cblas_transa[i] = CblasTrans; + else if (bli_is_conjtrans( blis_transa )) + cblas_transa[i] = CblasConjTrans; + else + cblas_transa[i] = CblasNoTrans; + + if(bli_is_trans( blis_transb )) + cblas_transb[i] = CblasTrans; + else if (bli_is_conjtrans( blis_transb )) + cblas_transb[i] = CblasConjTrans; + else + cblas_transb[i] = CblasNoTrans; +#else + bli_param_map_blis_to_netlib_trans( blis_transa, &f77_transa[i]); + bli_param_map_blis_to_netlib_trans( blis_transb, &f77_transb[i]); + +#endif + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + bli_set_dims_with_trans( blis_transa, m[i], k[i], &m0_a, &n0_a ); + bli_set_dims_with_trans( blis_transb, k[i], n[i], &m0_b, &n0_b ); + if(stor_scheme == 'C' || stor_scheme == 'c') + { + for(j = 0; j < group_size[i]; j++) + { + bli_obj_create(dt, m0_a, n0_a, 1, lda[i], &a[idx]); + bli_obj_create(dt, m0_b, n0_b, 1, ldb[i], &b[idx]); + bli_obj_create(dt, m[i], n[i], 1, ldc[i], &c[idx]); + bli_obj_create(dt, m[i], n[i], 1, ldc[i], &c_save[idx]); + + bli_randm( &a[idx] ); + bli_randm( &b[idx] ); + bli_randm( &c[idx] ); + + bli_obj_set_conjtrans(blis_transa, &a[idx]); + bli_obj_set_conjtrans(blis_transb, &b[idx]); + idx++; + } + } + else if(stor_scheme == 'R' || stor_scheme == 'r') + { + for(j = 0; j < group_size[i]; j++) + { + bli_obj_create(dt, m0_a, n0_a, lda[i], 1, &a[idx]); + bli_obj_create(dt, m0_b, n0_b, ldb[i], 1, &b[idx]); + bli_obj_create(dt, m[i], n[i], ldc[i], 1, &c[idx]); + bli_obj_create(dt, m[i], n[i], ldc[i], 1, &c_save[idx]); + + bli_randm( &a[idx] ); + bli_randm( &b[idx] ); + bli_randm( &c[idx] ); + + bli_obj_set_conjtrans(blis_transa, &a[idx]); + bli_obj_set_conjtrans(blis_transb, &b[idx]); + idx++; + } + } + f77_m[i] = m[i]; + f77_n[i] = n[i]; + f77_k[i] = k[i]; + f77_lda[i] = lda[i]; + f77_ldb[i] = ldb[i]; + f77_ldc[i] = ldc[i]; + f77_group_size[i] = group_size[i]; + + } + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + for(j = 0; j < group_size[i]; j++) + { + bli_copym(&c[idx], &c_save[idx]); + idx++; + } + + dtime_save = DBL_MAX; + + for( r = 0; r < n_repeats; ++r ) + { + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + for(j = 0; j < group_size[i]; j++) + { + bli_copym( &c_save[idx], &c[idx]); + idx++; + } + + dtime = bli_clock(); + +#ifdef PRINT + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + for(j = 0; j < group_size[i]; j++) + { + printf("Group: %ld Member: %ld\n", i, j); + + bli_printm("a", &a[idx], "%4.1f", ""); + bli_printm("b", &b[idx], "%4.1f", ""); + bli_printm("c", &c[idx], "%4.1f", ""); + + idx++; + } +#endif + + if(bli_is_float(dt)) + { + const float *ap[total_count], *bp[total_count]; + float *cp[total_count]; + float alphap[GRP_COUNT], betap[GRP_COUNT]; + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + { + for(j = 0; j < group_size[i]; j++) + { + ap[idx] = bli_obj_buffer( &a[idx] ); + bp[idx] = bli_obj_buffer( &b[idx] ); + cp[idx] = bli_obj_buffer( &c[idx] ); + + idx++; + } + alphap[i] = *(float*)bli_obj_buffer_for_1x1(dt, &alpha[i]); + betap[i] = *(float*)bli_obj_buffer_for_1x1(dt, &beta[i] ); + } + +#ifdef CHECK_CBLAS + cblas_sgemm_batch( cblas_order, + cblas_transa, + cblas_transb, + f77_m, f77_n, f77_k, + alphap, ap, f77_lda, + bp, f77_ldb, + betap, cp, f77_ldc, + f77_group_count, + f77_group_size + ); +#else + sgemm_batch_( f77_transa, + f77_transb, + f77_m, f77_n, f77_k, + alphap, ap, f77_lda, + bp, f77_ldb, + betap, cp, f77_ldc, + &f77_group_count, + f77_group_size + ); +#endif + + } + else if(bli_is_double(dt)) + { + const double *ap[total_count], *bp[total_count]; + double *cp[total_count]; + double alphap[GRP_COUNT], betap[GRP_COUNT]; + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + { + for(j = 0; j < group_size[i]; j++) + { + ap[idx] = bli_obj_buffer( &a[idx] ); + bp[idx] = bli_obj_buffer( &b[idx] ); + cp[idx] = bli_obj_buffer( &c[idx] ); + + idx++; + } + alphap[i] = *(double*)bli_obj_buffer_for_1x1(dt, &alpha[i]); + betap[i] = *(double*)bli_obj_buffer_for_1x1(dt, &beta[i] ); + } +#ifdef CHECK_CBLAS + cblas_dgemm_batch( cblas_order, + cblas_transa, + cblas_transb, + f77_m, f77_n, f77_k, + alphap, ap, f77_lda, + bp, f77_ldb, + betap, cp, f77_ldc, + f77_group_count, + f77_group_size + ); +#else + dgemm_batch_( f77_transa, + f77_transb, + f77_m, f77_n, f77_k, + alphap, ap, f77_lda, + bp, f77_ldb, + betap, cp, f77_ldc, + &f77_group_count, + f77_group_size + ); +#endif + + } + else if(bli_is_scomplex(dt)) + { + const scomplex *ap[total_count], *bp[total_count]; + scomplex *cp[total_count]; + scomplex alphap[GRP_COUNT], betap[GRP_COUNT]; + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + { + for(j = 0; j < group_size[i]; j++) + { + ap[idx] = bli_obj_buffer( &a[idx] ); + bp[idx] = bli_obj_buffer( &b[idx] ); + cp[idx] = bli_obj_buffer( &c[idx] ); + + idx++; + } + alphap[i] = *(scomplex*)bli_obj_buffer_for_1x1(dt, &alpha[i]); + betap[i] = *(scomplex*)bli_obj_buffer_for_1x1(dt, &beta[i] ); + } +#ifdef CHECK_CBLAS + cblas_cgemm_batch( cblas_order, + cblas_transa, + cblas_transb, + f77_m, f77_n, f77_k, + (const void*)alphap, + (const void**)ap, f77_lda, + (const void**)bp, f77_ldb, + (const void*)betap, (void**)cp, f77_ldc, + f77_group_count, + f77_group_size + ); +#else + cgemm_batch_( f77_transa, + f77_transb, + f77_m, f77_n, f77_k, + alphap, ap, f77_lda, + bp, f77_ldb, + betap, cp, f77_ldc, + &f77_group_count, + f77_group_size + ); +#endif + } + else if(bli_is_dcomplex(dt)) + { + const dcomplex *ap[total_count], *bp[total_count]; + dcomplex *cp[total_count]; + dcomplex alphap[GRP_COUNT], betap[GRP_COUNT]; + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + { + for(j = 0; j < group_size[i]; j++) + { + ap[idx] = bli_obj_buffer( &a[idx] ); + bp[idx] = bli_obj_buffer( &b[idx] ); + cp[idx] = bli_obj_buffer( &c[idx] ); + + idx++; + } + alphap[i] = *(dcomplex*)bli_obj_buffer_for_1x1(dt, &alpha[i]); + betap[i] = *(dcomplex*)bli_obj_buffer_for_1x1(dt, &beta[i] ); + } + +#ifdef CHECK_CBLAS + cblas_zgemm_batch( cblas_order, + cblas_transa, + cblas_transb, + f77_m, f77_n, f77_k, + (const void*)alphap, + (const void**)ap, f77_lda, + (const void**)bp, f77_ldb, + (const void*)betap, (void**)cp, f77_ldc, + f77_group_count, + f77_group_size + ); +#else + zgemm_batch_( f77_transa, + f77_transb, + f77_m, f77_n, f77_k, + alphap, ap, f77_lda, + bp, f77_ldb, + betap, cp, f77_ldc, + &f77_group_count, + f77_group_size + ); +#endif + } +#ifdef PRINT + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + for(j = 0; j < group_size[i]; j++) + { + printf("Group: %ld Member: %ld\n", i, j); + bli_printm("c after", &c[idx], "%4.1f", ""); + + idx++; + } +#endif + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + dim_t fp_ops = 0; + for(i = 0; i < GRP_COUNT; i++) + fp_ops += 2.0 * m[i] * k[i] * n[i] * group_size[i]; + + gflops = fp_ops / (dtime_save * 1.0e9 ); + + if(bli_is_complex( dt ) ) gflops *= 4.0; + +#ifdef FILE_IN_OUT + fprintf(fout, "Stor_scheme = %c, group_count = %lu, gflops = %7.2f\n", stor_scheme, GRP_COUNT, gflops); + for(i = 0; i < GRP_COUNT; i++) + fprintf(fout, "%4lu \t %4lu\t %4lu\t %4lu\t %4lu\t %4lu\t %c\t %c\t %4lu\n", m[i], n[i], k[i], lda[i], ldb[i], ldc[i], transa[i], transb[i], group_size[i]); + + fflush(fout); +#else + printf( "Stor_scheme = %c, group_count = %d, gflops = %7.2f\n", stor_scheme, GRP_COUNT, gflops); + for(i = 0; i < GRP_COUNT; i++) + printf("%4lu \t %4lu\t %4lu\t %4lu\t %4lu\t %4lu\t %c\t %c\t %4lu\n", m[i], n[i], k[i], lda[i], ldb[i], ldc[i], transa[i], transb[i], group_size[i]); + +#endif + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + { + bli_obj_free( &alpha[i]); + bli_obj_free( &beta[i] ); + + for(j = 0; j < group_size[i]; j++ ) + { + bli_obj_free( &a[idx]); + bli_obj_free( &b[idx]); + bli_obj_free( &c[idx]); + bli_obj_free( &c_save[idx]); + + idx++; + } + } +#ifdef FILE_IN_OUT + } + fclose(fin); + fclose(fout); +#endif + return 0; +} + diff --git a/test/test_gemmt.c b/test/test_gemmt.c new file mode 100644 index 0000000000..881b5500b9 --- /dev/null +++ b/test/test_gemmt.c @@ -0,0 +1,483 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + +//#define CBLAS +//#define C_STOR_R + +//#define PRINT + +int main( int argc, char** argv ) +{ + obj_t a, b, c; + obj_t c_save; + obj_t alpha, beta; + dim_t m, k; + dim_t p; + dim_t p_begin, p_end, p_inc; + int m_input, k_input; + num_t dt; + int r, n_repeats; + uplo_t uploc; + trans_t transa; + trans_t transb; + f77_char f77_uploc; + f77_char f77_transa; + f77_char f77_transb; + + double dtime; + double dtime_save; + double gflops; + + //bli_init(); + + //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); + + n_repeats = 3; + +#ifndef PRINT + p_begin = 200; + p_end = 2000; + p_inc = 200; + + m_input = -1; + k_input = -1; +#else + p_begin = 16; + p_end = 16; + p_inc = 1; + + m_input = 5; + k_input = 4; +#endif + +#if 1 + //dt = BLIS_FLOAT; + dt = BLIS_DOUBLE; +#else + //dt = BLIS_SCOMPLEX; + dt = BLIS_DCOMPLEX; +#endif + + uploc = BLIS_LOWER; + //uploc = BLIS_UPPER; + + transa = BLIS_NO_TRANSPOSE; + transb = BLIS_NO_TRANSPOSE; + + bli_param_map_blis_to_netlib_uplo( uploc, &f77_uploc ); + bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + + char uplocl = tolower( f77_uploc ); + char transal = tolower( f77_transa ); + char transbl = tolower( f77_transb ); + + f77_int cbla_uploc = ( uplocl == 'l' ? CblasLower : CblasUpper ); + f77_int cbla_transa = ( transal == 'n' ? CblasNoTrans : CblasTrans ); + f77_int cbla_transb = ( transbl == 'n' ? CblasNoTrans : CblasTrans ); + + ( void )cbla_uploc; + ( void )cbla_transa; + ( void )cbla_transb; + + // Begin with initializing the last entry to zero so that + // matlab allocates space for the entire array once up-front. + for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ; +#ifdef BLIS + printf( "data_gemmt_blis" ); +#else + printf( "data_gemmt_%s", BLAS ); +#endif + printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )0, + ( unsigned long )0, 0.0 ); + + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) + { + if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); + else m = ( dim_t ) m_input; + if ( k_input < 0 ) k = p * ( dim_t )abs(k_input); + else k = ( dim_t ) k_input; + + bli_obj_create( dt, 1, 1, 0, 0, &alpha ); + bli_obj_create( dt, 1, 1, 0, 0, &beta ); + +#ifndef C_STOR_R + if ( bli_does_trans( transa ) ) + bli_obj_create( dt, k, m, 0, 0, &a ); + else + bli_obj_create( dt, m, k, 0, 0, &a ); + + if ( bli_does_trans( transb ) ) + bli_obj_create( dt, m, k, 0, 0, &b ); + else + bli_obj_create( dt, k, m, 0, 0, &b ); + + bli_obj_create( dt, m, m, 0, 0, &c ); + bli_obj_create( dt, m, m, 0, 0, &c_save ); +#else + if ( bli_does_trans( transa ) ) + bli_obj_create( dt, k, m, -1, -1, &a ); + else + bli_obj_create( dt, m, k, -1, -1, &a ); + + if ( bli_does_trans( transb ) ) + bli_obj_create( dt, m, k, -1, -1, &b ); + else + bli_obj_create( dt, k, m, -1, -1, &b ); + + bli_obj_create( dt, m, m, -1, -1, &c ); + bli_obj_create( dt, m, m, -1, -1, &c_save ); +#endif + + bli_randm( &a ); + bli_randm( &b ); + bli_randm( &c ); + + bli_obj_set_uplo( uploc, &c ); + + bli_obj_set_conjtrans( transa, &a ); + bli_obj_set_conjtrans( transb, &b ); + + bli_setsc( (0.9/1.0), 0.2, &alpha ); + bli_setsc( -(1.1/1.0), 0.3, &beta ); + + + bli_copym( &c, &c_save ); + + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + bli_copym( &c_save, &c ); + + + dtime = bli_clock(); + + +#ifdef PRINT + bli_printm( "a", &a, "%4.1f", "" ); + bli_printm( "b", &b, "%4.1f", "" ); + bli_printm( "c", &c, "%4.1f", "" ); +#endif + +#ifdef BLIS + + bli_gemmt( &alpha, + &a, + &b, + &beta, + &c ); + +#else + +#ifndef CBLAS + + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* bp = bli_obj_buffer( &b ); + float* betap = bli_obj_buffer( &beta ); + float* cp = bli_obj_buffer( &c ); + + sgemmt_( &f77_uploc, + &f77_transa, + &f77_transb, + &mm, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + double* bp = bli_obj_buffer( &b ); + double* betap = bli_obj_buffer( &beta ); + double* cp = bli_obj_buffer( &c ); + + dgemmt_( &f77_uploc, + &f77_transa, + &f77_transb, + &mm, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + scomplex* bp = bli_obj_buffer( &b ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* cp = bli_obj_buffer( &c ); + + cgemmt_( &f77_uploc, + &f77_transa, + &f77_transb, + &mm, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* bp = bli_obj_buffer( &b ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* cp = bli_obj_buffer( &c ); + + zgemmt_( &f77_uploc, + &f77_transa, + &f77_transb, + &mm, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + +#else // #ifdef CBLAS + + f77_int cbla_storage = ( bli_obj_is_row_stored( &c ) ? CblasRowMajor + : CblasColMajor ); + + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); +#ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); +#else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); +#endif + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* bp = bli_obj_buffer( &b ); + float* betap = bli_obj_buffer( &beta ); + float* cp = bli_obj_buffer( &c ); + + cblas_sgemmt( cbla_storage, + cbla_uploc, + cbla_transa, + cbla_transb, + mm, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); +#ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); +#else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); +#endif + double* alphap = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + double* bp = bli_obj_buffer( &b ); + double* betap = bli_obj_buffer( &beta ); + double* cp = bli_obj_buffer( &c ); + + cblas_dgemmt( cbla_storage, + cbla_uploc, + cbla_transa, + cbla_transb, + mm, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); +#ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); +#else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); +#endif + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + scomplex* bp = bli_obj_buffer( &b ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* cp = bli_obj_buffer( &c ); + + cblas_cgemmt( cbla_storage, + cbla_uploc, + cbla_transa, + cbla_transb, + mm, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); +#ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); +#else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); +#endif + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* bp = bli_obj_buffer( &b ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* cp = bli_obj_buffer( &c ); + + cblas_zgemmt( cbla_storage, + cbla_uploc, + cbla_transa, + cbla_transb, + mm, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc ); + } +#endif + +#endif + +#ifdef PRINT + bli_printm( "c after", &c, "%4.1f", "" ); + exit(1); +#endif + + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + gflops = ( 1.0 * m * k * m ) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 4.0; + +#ifdef BLIS + printf( "data_gemmt_blis" ); +#else + printf( "data_gemmt_%s", BLAS ); +#endif + printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )m, + ( unsigned long )k, gflops ); + + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + bli_obj_free( &c_save ); + } + + //bli_finalize(); + + return 0; +} + diff --git a/test/test_gemv.c b/test/test_gemv.c index 7d15c3249a..4cc60eefab 100644 --- a/test/test_gemv.c +++ b/test/test_gemv.c @@ -32,7 +32,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" // transa m n alpha a lda x incx beta y incy @@ -88,11 +92,12 @@ int main( int argc, char** argv ) printf( "data_gemv_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); @@ -183,7 +188,7 @@ int main( int argc, char** argv ) printf( "data_gemv_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )n, gflops ); diff --git a/test/test_ger.c b/test/test_ger.c index e3497703e6..7d69654959 100644 --- a/test/test_ger.c +++ b/test/test_ger.c @@ -32,7 +32,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" // m n alpha x incx y incy a lda @@ -88,11 +92,12 @@ int main( int argc, char** argv ) printf( "data_ger_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); @@ -176,7 +181,7 @@ int main( int argc, char** argv ) printf( "data_ger_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )n, gflops ); diff --git a/test/test_hemm.c b/test/test_hemm.c index 40068c5f95..4e004546f0 100644 --- a/test/test_hemm.c +++ b/test/test_hemm.c @@ -32,7 +32,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" @@ -106,11 +110,12 @@ int main( int argc, char** argv ) printf( "data_hemm_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); else m = ( dim_t ) m_input; @@ -298,7 +303,7 @@ int main( int argc, char** argv ) printf( "data_hemm_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )n, gflops ); diff --git a/test/test_hemv.c b/test/test_hemv.c index 0250d31b8d..2ec544fdd9 100644 --- a/test/test_hemv.c +++ b/test/test_hemv.c @@ -32,7 +32,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" // uploa m alpha a lda x incx beta y incy @@ -93,10 +97,11 @@ int main( int argc, char** argv ) printf( "data_hemv_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); @@ -190,7 +195,7 @@ int main( int argc, char** argv ) printf( "data_hemv_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, gflops ); bli_obj_free( &alpha ); diff --git a/test/test_her.c b/test/test_her.c index 026b91261b..267e1bfe02 100644 --- a/test/test_her.c +++ b/test/test_her.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,7 +33,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" // uplo m alpha x incx a lda @@ -77,11 +82,8 @@ int main( int argc, char** argv ) m_input = 6; #endif -#if 1 - dt_alpha = dt_x = dt_a = BLIS_DOUBLE; -#else + // her supports complex and double complex dt_alpha = dt_x = dt_a = BLIS_DCOMPLEX; -#endif uplo = BLIS_LOWER; @@ -94,10 +96,11 @@ int main( int argc, char** argv ) printf( "data_her_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); @@ -122,7 +125,7 @@ int main( int argc, char** argv ) bli_copym( &a, &a_save ); - + dtime_save = DBL_MAX; for ( r = 0; r < n_repeats; ++r ) @@ -138,33 +141,76 @@ int main( int argc, char** argv ) #endif #ifdef BLIS - //bli_obj_toggle_conj( &x ); - //bli_syr( &alpha, bli_her( &alpha, &x, &a ); #else - - f77_char uplo = 'L'; - f77_int mm = bli_obj_length( &a ); - f77_int incx = bli_obj_vector_inc( &x ); - f77_int lda = bli_obj_col_stride( &a ); - double* alphap = bli_obj_buffer( &alpha ); - double* xp = bli_obj_buffer( &x ); - double* ap = bli_obj_buffer( &a ); -/* - dcomplex* xp = bli_obj_buffer( x ); - dcomplex* ap = bli_obj_buffer( &a ); -*/ - - dsyr_( &uplo, - //zher_( &uplo, - &mm, - alphap, - xp, &incx, - ap, &lda ); + if ( bli_is_float( dt_a ) ) + { + f77_char uplo = 'L'; + f77_int mm = bli_obj_length( &a ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int lda = bli_obj_col_stride( &a ); + float* alphap = bli_obj_buffer( &alpha ); + float* xp = bli_obj_buffer( &x ); + float* ap = bli_obj_buffer( &a ); + + ssyr_( &uplo, + &mm, + alphap, + xp, &incx, + ap, &lda ); + } + else if ( bli_is_double( dt_a ) ) + { + f77_char uplo = 'L'; + f77_int mm = bli_obj_length( &a ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int lda = bli_obj_col_stride( &a ); + double* alphap = bli_obj_buffer( &alpha ); + double* xp = bli_obj_buffer( &x ); + double* ap = bli_obj_buffer( &a ); + + dsyr_( &uplo, + &mm, + alphap, + xp, &incx, + ap, &lda ); + } + else if ( bli_is_scomplex( dt_a ) ) + { + f77_char uplo = 'L'; + f77_int mm = bli_obj_length( &a ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int lda = bli_obj_col_stride( &a ); + float* alphap = bli_obj_buffer( &alpha ); + scomplex* xp = bli_obj_buffer( &x ); + scomplex* ap = bli_obj_buffer( &a ); + + cher_( &uplo, + &mm, + alphap, + xp, &incx, + ap, &lda ); + } + else if ( bli_is_dcomplex( dt_a ) ) + { + f77_char uplo = 'L'; + f77_int mm = bli_obj_length( &a ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int lda = bli_obj_col_stride( &a ); + double* alphap = bli_obj_buffer( &alpha ); + dcomplex* xp = bli_obj_buffer( &x ); + dcomplex* ap = bli_obj_buffer( &a ); + + zher_( &uplo, + &mm, + alphap, + xp, &incx, + ap, &lda ); + } #endif #ifdef PRINT @@ -184,7 +230,7 @@ int main( int argc, char** argv ) printf( "data_her_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, gflops ); bli_obj_free( &alpha ); diff --git a/test/test_her2.c b/test/test_her2.c index 7428dde4ec..3672051dd2 100644 --- a/test/test_her2.c +++ b/test/test_her2.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,12 +33,16 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" // uplo m alpha x incx y incy a lda //void dsyr2_( char*, int*, double*, double*, int*, double*, int*, double*, int* ); - + //#define PRINT int main( int argc, char** argv ) @@ -76,11 +81,8 @@ int main( int argc, char** argv ) m_input = 6; #endif -#if 1 - dt_alpha = dt_x = dt_y = dt_a = BLIS_DOUBLE; -#else - dt_alpha = dt_x = dt_y = dt_a = BLIS_DCOMPLEX; -#endif + // her2 supports complex and double complex + dt_alpha = dt_x = dt_y = dt_a = BLIS_SCOMPLEX; uplo = BLIS_LOWER; @@ -93,10 +95,11 @@ int main( int argc, char** argv ) printf( "data_her2_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); @@ -123,7 +126,7 @@ int main( int argc, char** argv ) bli_copym( &a, &a_save ); - + dtime_save = DBL_MAX; for ( r = 0; r < n_repeats; ++r ) @@ -137,37 +140,93 @@ int main( int argc, char** argv ) bli_printm( "x", &x, "%4.1f", "" ); bli_printm( "y", &y, "%4.1f", "" ); bli_printm( "a", &a, "%4.1f", "" ); -#endif +#endif #ifdef BLIS - //bli_obj_toggle_conj( &x ); - //bli_obj_toggle_conj( &y ); - - //bli_syr2( &alpha, bli_her2( &alpha, &x, &y, &a ); #else + if ( bli_is_float( dt_a ) ) + { + f77_char uplo = 'L'; + f77_int mm = bli_obj_length( &a ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + f77_int lda = bli_obj_col_stride( &a ); + float* alphap = bli_obj_buffer( &alpha ); + float* xp = bli_obj_buffer( &x ); + float* yp = bli_obj_buffer( &y ); + float* ap = bli_obj_buffer( &a ); + + ssyr2_( &uplo, + &mm, + alphap, + xp, &incx, + yp, &incy, + ap, &lda ); + } + else if ( bli_is_double( dt_a ) ) + { + f77_char uplo = 'L'; + f77_int mm = bli_obj_length( &a ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + f77_int lda = bli_obj_col_stride( &a ); + double* alphap = bli_obj_buffer( &alpha ); + double* xp = bli_obj_buffer( &x ); + double* yp = bli_obj_buffer( &y ); + double* ap = bli_obj_buffer( &a ); + + dsyr2_( &uplo, + &mm, + alphap, + xp, &incx, + yp, &incy, + ap, &lda ); + } + else if ( bli_is_scomplex( dt_a ) ) + { + f77_char uplo = 'L'; + f77_int mm = bli_obj_length( &a ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + f77_int lda = bli_obj_col_stride( &a ); + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* xp = bli_obj_buffer( &x ); + scomplex* yp = bli_obj_buffer( &y ); + scomplex* ap = bli_obj_buffer( &a ); + + cher2_( &uplo, + &mm, + alphap, + xp, &incx, + yp, &incy, + ap, &lda ); + } + else if ( bli_is_dcomplex( dt_a ) ) + { + f77_char uplo = 'L'; + f77_int mm = bli_obj_length( &a ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + f77_int lda = bli_obj_col_stride( &a ); + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* xp = bli_obj_buffer( &x ); + dcomplex* yp = bli_obj_buffer( &y ); + dcomplex* ap = bli_obj_buffer( &a ); + + zher2_( &uplo, + &mm, + alphap, + xp, &incx, + yp, &incy, + ap, &lda ); + } - f77_char uplo = 'L'; - f77_int mm = bli_obj_length( &a ); - f77_int incx = bli_obj_vector_inc( &x ); - f77_int incy = bli_obj_vector_inc( &y ); - f77_int lda = bli_obj_col_stride( &a ); - double* alphap = bli_obj_buffer( &alpha ); - double* xp = bli_obj_buffer( &x ); - double* yp = bli_obj_buffer( &y ); - double* ap = bli_obj_buffer( &a ); - - dsyr2_( &uplo, - &mm, - alphap, - xp, &incx, - yp, &incy, - ap, &lda ); #endif #ifdef PRINT @@ -186,7 +245,7 @@ int main( int argc, char** argv ) printf( "data_her2_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, gflops ); bli_obj_free( &alpha ); diff --git a/test/test_her2k.c b/test/test_her2k.c index a73e849554..7e8a7b8fe8 100644 --- a/test/test_her2k.c +++ b/test/test_her2k.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,7 +33,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" @@ -80,13 +85,10 @@ int main( int argc, char** argv ) k_input = 1; #endif -#if 1 - //dt = BLIS_FLOAT; - dt = BLIS_DOUBLE; -#else + // her2k supports complex and double complex //dt = BLIS_SCOMPLEX; dt = BLIS_DCOMPLEX; -#endif + uploc = BLIS_LOWER; //uploc = BLIS_UPPER; @@ -105,11 +107,12 @@ int main( int argc, char** argv ) printf( "data_her2k_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); else m = ( dim_t ) m_input; @@ -148,7 +151,7 @@ int main( int argc, char** argv ) bli_copym( &c, &c_save ); - + dtime_save = DBL_MAX; for ( r = 0; r < n_repeats; ++r ) @@ -176,16 +179,16 @@ int main( int argc, char** argv ) #else if ( bli_is_float( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* bp = bli_obj_buffer( &b ); - float* betap = bli_obj_buffer( &beta ); - float* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* bp = bli_obj_buffer( &b ); + float* betap = bli_obj_buffer( &beta ); + float* cp = bli_obj_buffer( &c ); ssyr2k_( &f77_uploc, &f77_transa, @@ -199,16 +202,16 @@ int main( int argc, char** argv ) } else if ( bli_is_double( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* bp = bli_obj_buffer( &b ); - double* betap = bli_obj_buffer( &beta ); - double* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + double* bp = bli_obj_buffer( &b ); + double* betap = bli_obj_buffer( &beta ); + double* cp = bli_obj_buffer( &c ); dsyr2k_( &f77_uploc, &f77_transa, @@ -222,16 +225,16 @@ int main( int argc, char** argv ) } else if ( bli_is_scomplex( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* bp = bli_obj_buffer( &b ); - float* betap = bli_obj_buffer( &beta ); - scomplex* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + scomplex* bp = bli_obj_buffer( &b ); + float* betap = bli_obj_buffer( &beta ); + scomplex* cp = bli_obj_buffer( &c ); cher2k_( &f77_uploc, &f77_transa, @@ -245,16 +248,16 @@ int main( int argc, char** argv ) } else if ( bli_is_dcomplex( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* bp = bli_obj_buffer( &b ); - double* betap = bli_obj_buffer( &beta ); - dcomplex* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* bp = bli_obj_buffer( &b ); + double* betap = bli_obj_buffer( &beta ); + dcomplex* cp = bli_obj_buffer( &c ); zher2k_( &f77_uploc, &f77_transa, @@ -287,7 +290,7 @@ int main( int argc, char** argv ) printf( "data_her2k_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )k, gflops ); diff --git a/test/test_herk.c b/test/test_herk.c index db8f826c9f..cbf963a339 100644 --- a/test/test_herk.c +++ b/test/test_herk.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,7 +33,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" @@ -79,14 +84,10 @@ int main( int argc, char** argv ) m_input = 3; k_input = 1; #endif - -#if 1 - //dt = BLIS_FLOAT; - dt = BLIS_DOUBLE; -#else + + // herk supports complex and double complex //dt = BLIS_SCOMPLEX; dt = BLIS_DCOMPLEX; -#endif uploc = BLIS_LOWER; //uploc = BLIS_UPPER; @@ -105,11 +106,12 @@ int main( int argc, char** argv ) printf( "data_herk_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); else m = ( dim_t ) m_input; @@ -140,7 +142,7 @@ int main( int argc, char** argv ) bli_copym( &c, &c_save ); - + dtime_save = DBL_MAX; for ( r = 0; r < n_repeats; ++r ) @@ -166,14 +168,14 @@ int main( int argc, char** argv ) #else if ( bli_is_float( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* betap = bli_obj_buffer( &beta ); - float* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* betap = bli_obj_buffer( &beta ); + float* cp = bli_obj_buffer( &c ); ssyrk_( &f77_uploc, &f77_transa, @@ -186,14 +188,14 @@ int main( int argc, char** argv ) } else if ( bli_is_double( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* betap = bli_obj_buffer( &beta ); - double* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + double* betap = bli_obj_buffer( &beta ); + double* cp = bli_obj_buffer( &c ); dsyrk_( &f77_uploc, &f77_transa, @@ -206,14 +208,14 @@ int main( int argc, char** argv ) } else if ( bli_is_scomplex( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - float* betap = bli_obj_buffer( &beta ); - scomplex* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + float* betap = bli_obj_buffer( &beta ); + scomplex* cp = bli_obj_buffer( &c ); cherk_( &f77_uploc, &f77_transa, @@ -226,14 +228,14 @@ int main( int argc, char** argv ) } else if ( bli_is_dcomplex( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - double* betap = bli_obj_buffer( &beta ); - dcomplex* cp = bli_obj_buffer( &c ); + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + double* betap = bli_obj_buffer( &beta ); + dcomplex* cp = bli_obj_buffer( &c ); zherk_( &f77_uploc, &f77_transa, @@ -265,7 +267,7 @@ int main( int argc, char** argv ) printf( "data_herk_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )k, gflops ); diff --git a/test/test_swapv.c b/test/test_swapv.c new file mode 100644 index 0000000000..4d8d35eac9 --- /dev/null +++ b/test/test_swapv.c @@ -0,0 +1,180 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" + +// n x incx y incy +//void dswap_( int*, double*, int*, double*, int* ); +//#define PRINT + +int main( int argc, char** argv ) +{ + obj_t x, y; + dim_t n; + dim_t p; + dim_t p_begin, p_end, p_inc; + int n_input; + int r, n_repeats; + num_t dt; + + double dtime; + double dtime_save; + double gflops; + + bli_init(); + + n_repeats = 3; + +#ifndef PRINT + p_begin = 40; + p_end = 8000; + p_inc = 40; + + n_input = -1; +#else + p_begin = 16; + p_end = 16; + p_inc = 1; + + n_input = -1; +#endif + +#if 1 + dt = BLIS_FLOAT; + //dt = BLIS_DOUBLE; +#else + //dt = BLIS_SCOMPLEX; + dt = BLIS_DCOMPLEX; +#endif + + // Begin with initializing the last entry to zero so that + // matlab allocates space for the entire array once up-front. + for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ; +#ifdef BLIS + printf( "data_swapv_blis" ); +#else + printf( "data_swapv_%s", BLAS ); +#endif + printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )0, 0.0 ); + + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) + { + + if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); + else n = ( dim_t ) n_input; + + bli_obj_create( dt, n, 1, 0, 0, &x ); + bli_obj_create( dt, n, 1, 0, 0, &y ); + + bli_randm( &x ); + bli_randm( &y ); + + dtime_save = 1.0e9; + + for ( r = 0; r < n_repeats; ++r ) + { + + dtime = bli_clock(); + +#ifdef PRINT + bli_printm( "x", &x, "%4.1f", "" ); + bli_printm( "y", &y, "%4.1f", "" ); +#endif + +#ifdef BLIS + + bli_swapv( &x, + &y ); +#else + if ( bli_is_float( dt ) ) + { + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + float* xp = bli_obj_buffer( &x ); + float* yp = bli_obj_buffer( &y ); + + sswap_( &nn, + xp, &incx, + yp, &incy ); + + } + else if ( bli_is_double( dt ) ) + { + + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + double* xp = bli_obj_buffer( &x ); + double* yp = bli_obj_buffer( &y ); + + dswap_( &nn, + xp, &incx, + yp, &incy ); + } +#endif + +#ifdef PRINT + bli_printm( "X after", &x, "%4.1f", "" ); + bli_printm( "Y after", &y, "%4.1f", "" ); + + exit(1); +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + gflops = ( n ) / ( dtime_save * 1.0e9 ); + +#ifdef BLIS + printf( "data_swapv_blis" ); +#else + printf( "data_swapv_%s", BLAS ); +#endif + printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )n, gflops ); + + bli_obj_free( &x ); + bli_obj_free( &y ); + } + + bli_finalize(); + + return 0; +} diff --git a/test/test_trmm.c b/test/test_trmm.c index 214ea32beb..1372675431 100644 --- a/test/test_trmm.c +++ b/test/test_trmm.c @@ -32,7 +32,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" @@ -116,11 +120,12 @@ int main( int argc, char** argv ) printf( "data_trmm_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); else m = ( dim_t ) m_input; @@ -282,7 +287,7 @@ int main( int argc, char** argv ) printf( "data_trmm_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )n, gflops ); diff --git a/test/test_trmv.c b/test/test_trmv.c index bd737de9f2..816cb8b40c 100644 --- a/test/test_trmv.c +++ b/test/test_trmv.c @@ -32,7 +32,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" // uploa trans, diag, m a lda x incx @@ -90,10 +94,11 @@ int main( int argc, char** argv ) printf( "data_trmv_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); @@ -176,7 +181,7 @@ int main( int argc, char** argv ) printf( "data_trmv_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, gflops ); bli_obj_free( &alpha ); diff --git a/test/test_trsm.c b/test/test_trsm.c index e5796bad34..273e76235c 100644 --- a/test/test_trsm.c +++ b/test/test_trsm.c @@ -32,7 +32,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" @@ -116,11 +120,12 @@ int main( int argc, char** argv ) printf( "data_trsm_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); else m = ( dim_t ) m_input; @@ -285,7 +290,7 @@ int main( int argc, char** argv ) printf( "data_trsm_%s", BLAS ); #endif printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, ( unsigned long )n, gflops ); diff --git a/test/test_trsv.c b/test/test_trsv.c index 048fe3950d..8bb29e19d8 100644 --- a/test/test_trsv.c +++ b/test/test_trsv.c @@ -32,7 +32,11 @@ */ +#ifdef WIN32 +#include +#else #include +#endif #include "blis.h" // uploa trans, diag, m a lda x incx @@ -90,10 +94,11 @@ int main( int argc, char** argv ) printf( "data_trv_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); @@ -183,7 +188,7 @@ int main( int argc, char** argv ) printf( "data_trsv_%s", BLAS ); #endif printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )(p - p_begin)/p_inc + 1, ( unsigned long )m, gflops ); bli_obj_free( &alpha ); diff --git a/test/thread_ranges/Makefile b/test/thread_ranges/Makefile index 2ed155be1a..5af2ce533c 100644 --- a/test/thread_ranges/Makefile +++ b/test/thread_ranges/Makefile @@ -104,7 +104,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Datatype diff --git a/test/thread_ranges/test_ranges.c b/test/thread_ranges/test_ranges.c index 9bf293ca54..b597ab300b 100644 --- a/test/thread_ranges/test_ranges.c +++ b/test/thread_ranges/test_ranges.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -110,7 +110,7 @@ int main( int argc, char** argv ) dim_t bf; dim_t n_way; char part_dim_ch; - bool_t go_fwd; + dim_t go_fwd; char out_ch; obj_t a; @@ -119,8 +119,8 @@ int main( int argc, char** argv ) thrinfo_t thrinfo; dim_t m, n; uplo_t uploa; - bool_t part_m_dim, part_n_dim; - bool_t go_bwd; + dim_t part_m_dim, part_n_dim; + dim_t go_bwd; dim_t p; num_t dt; dim_t start, end; diff --git a/testsuite/Makefile b/testsuite/Makefile index 1e97cdcf42..57c1c748d3 100644 --- a/testsuite/Makefile +++ b/testsuite/Makefile @@ -103,7 +103,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Binary executable name. TESTSUITE_BIN := test_libblis.x diff --git a/testsuite/input.general b/testsuite/input.general index 7728402241..ae0d73b110 100644 --- a/testsuite/input.general +++ b/testsuite/input.general @@ -31,11 +31,6 @@ sdcz # Datatype(s) to test: 500 # Problem size: maximum to test 100 # Problem size: increment between experiments # Complex level-3 implementations to test: -0 # 3mh ('1' = enable; '0' = disable) -0 # 3m1 ('1' = enable; '0' = disable) -0 # 4mh ('1' = enable; '0' = disable) -0 # 4m1b ('1' = enable; '0' = disable) -0 # 4m1a ('1' = enable; '0' = disable) 1 # 1m ('1' = enable; '0' = disable) 1 # native ('1' = enable; '0' = disable) 1 # Simulate application-level threading: diff --git a/testsuite/input.general.fast b/testsuite/input.general.fast index 02b30b897d..06a89d16d9 100644 --- a/testsuite/input.general.fast +++ b/testsuite/input.general.fast @@ -31,12 +31,7 @@ sdcz # Datatype(s) to test: 100 # Problem size: maximum to test 100 # Problem size: increment between experiments # Complex level-3 implementations to test: -0 # 3mh ('1' = enable; '0' = disable) -0 # 3m1 ('1' = enable; '0' = disable) -0 # 4mh ('1' = enable; '0' = disable) -0 # 4m1b ('1' = enable; '0' = disable) -0 # 4m1a ('1' = enable; '0' = disable) -0 # 1m ('1' = enable; '0' = disable) +1 # 1m ('1' = enable; '0' = disable) 1 # native ('1' = enable; '0' = disable) 1 # Simulate application-level threading: # '1' = disable / use one testsuite thread; diff --git a/testsuite/input.general.mixed b/testsuite/input.general.mixed index 55a3f56c75..36a3e62a67 100644 --- a/testsuite/input.general.mixed +++ b/testsuite/input.general.mixed @@ -31,11 +31,6 @@ sdcz # Datatype(s) to test: 500 # Problem size: maximum to test 100 # Problem size: increment between experiments # Complex level-3 implementations to test: -0 # 3mh ('1' = enable; '0' = disable) -0 # 3m1 ('1' = enable; '0' = disable) -0 # 4mh ('1' = enable; '0' = disable) -0 # 4m1b ('1' = enable; '0' = disable) -0 # 4m1a ('1' = enable; '0' = disable) 1 # 1m ('1' = enable; '0' = disable) 1 # native ('1' = enable; '0' = disable) 1 # Simulate application-level threading: diff --git a/testsuite/input.general.salt b/testsuite/input.general.salt index ad52b68bba..2e8b8a284e 100644 --- a/testsuite/input.general.salt +++ b/testsuite/input.general.salt @@ -31,11 +31,6 @@ sdcz # Datatype(s) to test: 100 # Problem size: maximum to test 100 # Problem size: increment between experiments # Complex level-3 implementations to test: -0 # 3mh ('1' = enable; '0' = disable) -0 # 3m1 ('1' = enable; '0' = disable) -0 # 4mh ('1' = enable; '0' = disable) -0 # 4m1b ('1' = enable; '0' = disable) -0 # 4m1a ('1' = enable; '0' = disable) 1 # 1m ('1' = enable; '0' = disable) 1 # native ('1' = enable; '0' = disable) 4 # Simulate application-level threading: diff --git a/testsuite/input.operations b/testsuite/input.operations index f35e2cd9b5..eebe8b605d 100644 --- a/testsuite/input.operations +++ b/testsuite/input.operations @@ -280,6 +280,10 @@ -1 -1 -1 # dimensions: m n k ?? # parameters: transa transb +1 # gemmt +-1 -1 # dimensions: m k +??? # parameters: uploc transa transb + 1 # hemm -1 -1 # dimensions: m n ???? # parameters: side uploa conja transb diff --git a/testsuite/input.operations.fast b/testsuite/input.operations.fast index c645a35658..b733c672d9 100644 --- a/testsuite/input.operations.fast +++ b/testsuite/input.operations.fast @@ -280,6 +280,10 @@ -1 -1 -1 # dimensions: m n k nn # parameters: transa transb +1 # gemmt +-1 -1 # dimensions: m k +?nn # parameters: uploc transa transb + 1 # hemm -1 -1 # dimensions: m n ??nn # parameters: side uploa conja transb diff --git a/testsuite/input.operations.mixed b/testsuite/input.operations.mixed index f99c6b8717..6292ea8ab4 100644 --- a/testsuite/input.operations.mixed +++ b/testsuite/input.operations.mixed @@ -280,6 +280,10 @@ -1 -1 -1 # dimensions: m n k nn # parameters: transa transb +1 # gemmt +-1 -1 # dimensions: m k +??? # parameters: uploc transa transb + 1 # hemm -1 -1 # dimensions: m n ???? # parameters: side uploa conja transb diff --git a/testsuite/input.operations.salt b/testsuite/input.operations.salt index c645a35658..b733c672d9 100644 --- a/testsuite/input.operations.salt +++ b/testsuite/input.operations.salt @@ -280,6 +280,10 @@ -1 -1 -1 # dimensions: m n k nn # parameters: transa transb +1 # gemmt +-1 -1 # dimensions: m k +?nn # parameters: uploc transa transb + 1 # hemm -1 -1 # dimensions: m n ??nn # parameters: side uploa conja transb diff --git a/testsuite/old/jobscripts/cfig.out b/testsuite/old/jobscripts/cfig.out new file mode 100644 index 0000000000..f8d2707cb4 --- /dev/null +++ b/testsuite/old/jobscripts/cfig.out @@ -0,0 +1,106 @@ +configure: detected Linux kernel version 4.14.0-115.6.1.el7a.ppc64le. +configure: python interpeter search list is: python python3 python2. +configure: using 'python' python interpreter. +configure: found python version 2.7.5 (maj: 2, min: 7, rev: 5). +configure: python 2.7.5 appears to be supported. +configure: C compiler search list is: gcc clang cc. +configure: using 'gcc' C compiler. +configure: C++ compiler search list is: g++ clang++ c++. +configure: using 'g++' C++ compiler (for sandbox only). +configure: found gcc version 8.2.0 (maj: 8, min: 2, rev: 0). +configure: checking for blacklisted configurations due to gcc 8.2.0. +configure: found assembler ('as') version 2.27 (maj: 2, min: 27, rev: ). +configure: checking for blacklisted configurations due to as 2.27. +configure: warning: assembler ('as' 2.27) does not support 'bulldozer'; adding to blacklist. +configure: warning: assembler ('as' 2.27) does not support 'sandybridge'; adding to blacklist. +configure: warning: assembler ('as' 2.27) does not support 'haswell'; adding to blacklist. +configure: warning: assembler ('as' 2.27) does not support 'piledriver'; adding to blacklist. +configure: warning: assembler ('as' 2.27) does not support 'steamroller'; adding to blacklist. +configure: warning: assembler ('as' 2.27) does not support 'excavator'; adding to blacklist. +configure: warning: assembler ('as' 2.27) does not support 'skx'; adding to blacklist. +configure: warning: assembler ('as' 2.27) does not support 'knl'; adding to blacklist. +configure: configuration blacklist: +configure: bulldozer sandybridge haswell piledriver steamroller excavator skx knl +configure: reading configuration registry...done. +configure: determining default version string. +configure: found '.git' directory; assuming git clone. +configure: executing: git describe --tags. +configure: git returned an error: 'Unknown option: -C +usage: git [--version] [--help] [-c name=value] + [--exec-path[=]] [--html-path] [--man-path] [--info-path] + [-p|--paginate|--no-pager] [--no-replace-objects] [--bare] + [--git-dir=] [--work-tree=] [--namespace=] + []'. +configure: using string from unmodified version file. +configure: starting configuration of BLIS 0.6.0. +configure: configuring with official version string. +configure: found shared library .so version '2.0.0'. +configure: .so major version: 2 +configure: .so minor.build version: 0.0 +configure: manual configuration requested; configuring with 'power9'. +configure: checking configuration against contents of 'config_registry'. +configure: configuration 'power9' is registered. +configure: 'power9' is defined as having the following sub-configurations: +configure: power9 +configure: which collectively require the following kernels: +configure: power9 +configure: checking sub-configurations: +configure: 'power9' is registered...and exists. +configure: checking sub-configurations' requisite kernels: +configure: 'power9' kernels...exist. +configure: no install prefix option given; defaulting to '/usr/local'. +configure: no install exec_prefix option given; defaulting to PREFIX. +configure: no install libdir option given; defaulting to EXECPREFIX/lib. +configure: no install includedir option given; defaulting to PREFIX/include. +configure: no install sharedir option given; defaulting to PREFIX/share. +configure: final installation directories: +configure: prefix: /usr/local +configure: exec_prefix: ${prefix} +configure: libdir: ${exec_prefix}/lib +configure: includedir: ${prefix}/include +configure: sharedir: ${prefix}/share +configure: NOTE: the variables above can be overridden when running make. +configure: no preset CFLAGS detected. +configure: no preset LDFLAGS detected. +configure: debug symbols disabled. +configure: disabling verbose make output. (enable with 'make V=1'.) +configure: disabling ARG_MAX hack. +configure: building BLIS as both static and shared libraries. +configure: exporting only public symbols within shared library. +configure: threading is disabled. +configure: requesting slab threading in jr and ir loops. +configure: internal memory pools for packing blocks are enabled. +configure: internal memory pools for small blocks are enabled. +configure: memory tracing output is disabled. +configure: libmemkind not found; disabling. +configure: compiler appears to not support #pragma omp simd. +configure: the BLAS compatibility layer is enabled. +configure: the CBLAS compatibility layer is disabled. +configure: mixed datatype support is enabled. +configure: mixed datatype optimizations requiring extra memory are enabled. +configure: small matrix handling is enabled. +configure: the BLIS API integer size is automatically determined. +configure: the BLAS/CBLAS API integer size is 32-bit. +configure: configuring for conventional gemm implementation. +configure: creating ./config.mk from ./build/config.mk.in +configure: creating ./bli_config.h from ./build/bli_config.h.in +configure: creating ./obj/power9 +configure: creating ./obj/power9/config/power9 +configure: creating ./obj/power9/kernels/power9 +configure: creating ./obj/power9/ref_kernels/power9 +configure: creating ./obj/power9/frame +configure: creating ./obj/power9/blastest +configure: creating ./obj/power9/testsuite +configure: creating ./lib/power9 +configure: creating ./include/power9 +configure: mirroring ./config/power9 to ./obj/power9/config/power9 +configure: mirroring ./kernels/power9 to ./obj/power9/kernels/power9 +configure: mirroring ./ref_kernels to ./obj/power9/ref_kernels +configure: mirroring ./ref_kernels to ./obj/power9/ref_kernels/power9 +configure: mirroring ./frame to ./obj/power9/frame +configure: creating makefile fragments in ./obj/power9/config/power9 +configure: creating makefile fragments in ./obj/power9/kernels/power9 +configure: creating makefile fragments in ./obj/power9/ref_kernels +configure: creating makefile fragments in ./obj/power9/frame +configure: configured to build within top-level directory of source distribution. +CONFIGURE DONE diff --git a/testsuite/old/jobscripts/cfig.sh b/testsuite/old/jobscripts/cfig.sh new file mode 100755 index 0000000000..b8927d7a61 --- /dev/null +++ b/testsuite/old/jobscripts/cfig.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +cd ~/blis +./configure power9 +echo "CONFIGURE DONE" diff --git a/testsuite/old/jobscripts/jb-cfig.sh b/testsuite/old/jobscripts/jb-cfig.sh new file mode 100644 index 0000000000..493fb7e703 --- /dev/null +++ b/testsuite/old/jobscripts/jb-cfig.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# execute in the general partition +#SBATCH --partition=general + +# execute with 40 processes/tasks +#SBATCH --ntasks=1 + +# maximum time is 30 minutes +#SBATCH --time=00:30:00 + +# job name is my_job +#SBATCH --job-name=blis + +# send email for status updates +#SBATCH --mail-type=ALL,TIME_LIMIT +#SBATCH --mail-user=ntukanov + +# change default output file name +#SBATCH --output=cfig.out + +# load environment +module load gcc/8.2 + +# application execution +srun cfig.sh diff --git a/testsuite/old/jobscripts/jb-mk.sh b/testsuite/old/jobscripts/jb-mk.sh new file mode 100644 index 0000000000..b2b56d6656 --- /dev/null +++ b/testsuite/old/jobscripts/jb-mk.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# execute in the general partition +#SBATCH --partition=general + +# execute with 40 processes/tasks +#SBATCH --ntasks=1 + +# maximum time is 30 minutes +#SBATCH --time=00:30:00 + +# job name is my_job +#SBATCH --job-name=blis + +# send email for status updates +#SBATCH --mail-type=ALL,TIME_LIMIT +#SBATCH --mail-user=ntukanov + +# change default output file name +#SBATCH --output=mk.out + +# load environment +module load gcc/8.2 + +# application execution +srun mk.sh diff --git a/testsuite/old/jobscripts/jb-runtest.sh b/testsuite/old/jobscripts/jb-runtest.sh new file mode 100644 index 0000000000..502b35fb6e --- /dev/null +++ b/testsuite/old/jobscripts/jb-runtest.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# execute in the general partition +#SBATCH --partition=general + +# execute with 40 processes/tasks +#SBATCH --ntasks=1 + +# maximum time is 30 minutes +#SBATCH --time=00:30:00 + +# job name is my_job +#SBATCH --job-name=blis + +# send email for status updates +#SBATCH --mail-type=ALL,TIME_LIMIT +#SBATCH --mail-user=ntukanov + +# change default output file name +#SBATCH --output=runtest.out + +# load environment +module load gcc/8.2 + +# application execution +srun runtest.sh diff --git a/testsuite/old/jobscripts/mk.out b/testsuite/old/jobscripts/mk.out new file mode 100644 index 0000000000..df324b23e5 --- /dev/null +++ b/testsuite/old/jobscripts/mk.out @@ -0,0 +1,9 @@ +Removing flattened header files from include/power9 +Removing object files from ./obj/power9 +srun: Job step aborted: Waiting up to 32 seconds for job step to finish. +srun: got SIGCONT +slurmstepd: error: *** JOB 1155 ON lookout00 CANCELLED AT 2019-06-10T17:29:07 *** +srun: forcing job termination +slurmstepd: error: *** STEP 1155.0 ON lookout00 CANCELLED AT 2019-06-10T17:29:07 *** +make: *** [cleanlib] Terminated +srun: error: lookout00: task 0: Terminated diff --git a/testsuite/old/jobscripts/mk.sh b/testsuite/old/jobscripts/mk.sh new file mode 100755 index 0000000000..186ed9f258 --- /dev/null +++ b/testsuite/old/jobscripts/mk.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +cd ~/blis +make clean +make +echo "MAKE DONE" diff --git a/testsuite/old/jobscripts/runtest.sh b/testsuite/old/jobscripts/runtest.sh new file mode 100755 index 0000000000..1650d2b697 --- /dev/null +++ b/testsuite/old/jobscripts/runtest.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +cd ~/blis/testsuite +rm -rf test_libblis.out +make clean +make -j +./test_libblis.x > test_libblis.out +echo "TEST DONE" diff --git a/testsuite/src/test_addm.c b/testsuite/src/test_addm.c index 545f9387bd..f7c21b733d 100644 --- a/testsuite/src/test_addm.c +++ b/testsuite/src/test_addm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -275,7 +275,7 @@ void libblis_test_addm_check // // is functioning correctly if // - // normfv(y) - sqrt( absqsc( beta + conjx(alpha) ) * m * n ) + // normfm(y) - sqrt( absqsc( beta + conjx(alpha) ) * m * n ) // // is negligible. // diff --git a/testsuite/src/test_addm.h b/testsuite/src/test_addm.h index 815f5db859..0dbdbfa2ee 100644 --- a/testsuite/src/test_addm.h +++ b/testsuite/src/test_addm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_addv.c b/testsuite/src/test_addv.c index c394ea1d83..9e216ab4d7 100644 --- a/testsuite/src/test_addv.c +++ b/testsuite/src/test_addv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_addv.h b/testsuite/src/test_addv.h index 1b9982e31f..eba5a9220e 100644 --- a/testsuite/src/test_addv.h +++ b/testsuite/src/test_addv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_amaxv.c b/testsuite/src/test_amaxv.c index d89e08c85a..fd6bad5f7f 100644 --- a/testsuite/src/test_amaxv.c +++ b/testsuite/src/test_amaxv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -266,6 +266,11 @@ void libblis_test_amaxv_check //bli_obj_scalar_init_detached( BLIS_INT, &index ); //bli_amaxv( x, &index ); bli_getsc( index, &i_d, &junk ); i = i_d; + + // If x is length 0, then we can't access any elements, and so we + // return early with a good residual. + if ( bli_obj_vector_dim( x ) == 0 ) { *resid = 0.0; return; } + bli_acquire_vi( i, x, &chi_i ); bli_obj_scalar_init_detached( BLIS_INT, &index_test ); @@ -351,11 +356,18 @@ void PASTEMAC0(opname) \ \ void* buf_index = bli_obj_buffer_at_off( index ); \ \ +/* + FGVZ: Disabling this code since bli_amaxv_check() is supposed to be a + non-public API function, and therefore unavailable unless all symbols + are scheduled to be exported at configure-time (which is not currently + the default behavior). + if ( bli_error_checking_is_enabled() ) \ bli_amaxv_check( x, index ); \ +*/ \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(tname,_vft) f = \ PASTEMAC(opname,_qfp)( dt ); \ \ diff --git a/testsuite/src/test_amaxv.h b/testsuite/src/test_amaxv.h index 4c382593f1..46d87b37f4 100644 --- a/testsuite/src/test_amaxv.h +++ b/testsuite/src/test_amaxv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpbyv.c b/testsuite/src/test_axpbyv.c index 24ed4a5ce1..a82ff6e256 100644 --- a/testsuite/src/test_axpbyv.c +++ b/testsuite/src/test_axpbyv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -296,7 +296,7 @@ void libblis_test_axpbyv_check // // is functioning correctly if // - // normf( y - ( beta * y_orig + alpha * conjx(x) ) ) + // normfv( y - ( beta * y_orig + alpha * conjx(x) ) ) // // is negligible. // diff --git a/testsuite/src/test_axpbyv.h b/testsuite/src/test_axpbyv.h index a8fcd2dfa9..9b318dba10 100644 --- a/testsuite/src/test_axpbyv.h +++ b/testsuite/src/test_axpbyv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpy2v.c b/testsuite/src/test_axpy2v.c index a834aa6a36..eeebf15e73 100644 --- a/testsuite/src/test_axpy2v.c +++ b/testsuite/src/test_axpy2v.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -314,7 +314,7 @@ void libblis_test_axpy2v_check // // is functioning correctly if // - // normf( z - v ) + // normfv( z - v ) // // is negligible, where v contains z as computed by two calls to axpyv. // diff --git a/testsuite/src/test_axpy2v.h b/testsuite/src/test_axpy2v.h index dc465792da..c695a643bb 100644 --- a/testsuite/src/test_axpy2v.h +++ b/testsuite/src/test_axpy2v.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpyf.c b/testsuite/src/test_axpyf.c index 3bd18ca3ef..7a85b22123 100644 --- a/testsuite/src/test_axpyf.c +++ b/testsuite/src/test_axpyf.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -319,7 +319,7 @@ void libblis_test_axpyf_check // // is functioning correctly if // - // normf( y - v ) + // normfv( y - v ) // // is negligible, where v contains y as computed by repeated calls to // axpyv. diff --git a/testsuite/src/test_axpyf.h b/testsuite/src/test_axpyf.h index 179acb9bb4..9dd1dadc29 100644 --- a/testsuite/src/test_axpyf.h +++ b/testsuite/src/test_axpyf.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpym.c b/testsuite/src/test_axpym.c index c79866104e..222fda33db 100644 --- a/testsuite/src/test_axpym.c +++ b/testsuite/src/test_axpym.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -289,7 +289,7 @@ void libblis_test_axpym_check // // is functioning correctly if // - // normf( y - ( y_orig + alpha * conjx(x) ) ) + // normfm( y - ( y_orig + alpha * conjx(x) ) ) // // is negligible. // diff --git a/testsuite/src/test_axpym.h b/testsuite/src/test_axpym.h index 29819640f6..632720284d 100644 --- a/testsuite/src/test_axpym.h +++ b/testsuite/src/test_axpym.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpyv.c b/testsuite/src/test_axpyv.c index ff0326fb8c..81d4f37706 100644 --- a/testsuite/src/test_axpyv.c +++ b/testsuite/src/test_axpyv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -286,7 +286,7 @@ void libblis_test_axpyv_check // // is functioning correctly if // - // normf( y - ( y_orig + alpha * conjx(x) ) ) + // normfv( y - ( y_orig + alpha * conjx(x) ) ) // // is negligible. // diff --git a/testsuite/src/test_axpyv.h b/testsuite/src/test_axpyv.h index a5ce3ea032..c96a9096bb 100644 --- a/testsuite/src/test_axpyv.h +++ b/testsuite/src/test_axpyv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_copym.c b/testsuite/src/test_copym.c index 2532a50c7c..1aab1d287b 100644 --- a/testsuite/src/test_copym.c +++ b/testsuite/src/test_copym.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_copym.h b/testsuite/src/test_copym.h index 2a876ea68b..560de0e9a6 100644 --- a/testsuite/src/test_copym.h +++ b/testsuite/src/test_copym.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_copyv.c b/testsuite/src/test_copyv.c index 70cb80f232..4350e95ee6 100644 --- a/testsuite/src/test_copyv.c +++ b/testsuite/src/test_copyv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_copyv.h b/testsuite/src/test_copyv.h index 1d413f75ba..2beb3212d0 100644 --- a/testsuite/src/test_copyv.h +++ b/testsuite/src/test_copyv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotaxpyv.c b/testsuite/src/test_dotaxpyv.c index cf5563ec95..391c119bbd 100644 --- a/testsuite/src/test_dotaxpyv.c +++ b/testsuite/src/test_dotaxpyv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -345,7 +345,7 @@ void libblis_test_dotaxpyv_check // // and // - // normf( z - z_temp ) + // normfv( z - z_temp ) // // are negligible, where rho_temp and z_temp contain rho and z as // computed by dotv and axpyv, respectively. diff --git a/testsuite/src/test_dotaxpyv.h b/testsuite/src/test_dotaxpyv.h index 7421339707..ce82227f49 100644 --- a/testsuite/src/test_dotaxpyv.h +++ b/testsuite/src/test_dotaxpyv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotv.c b/testsuite/src/test_dotv.c index ff9cd2b59c..347ce9e620 100644 --- a/testsuite/src/test_dotv.c +++ b/testsuite/src/test_dotv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -278,7 +278,7 @@ void libblis_test_dotv_check // // is functioning correctly if // - // sqrtsc( rho.real ) - normf( x ) + // sqrtsc( rho.real ) - normfv( x ) // // and // diff --git a/testsuite/src/test_dotv.h b/testsuite/src/test_dotv.h index d1b3b0e29f..2f000128b1 100644 --- a/testsuite/src/test_dotv.h +++ b/testsuite/src/test_dotv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotxaxpyf.c b/testsuite/src/test_dotxaxpyf.c index e85edff171..a2c3ef3e94 100644 --- a/testsuite/src/test_dotxaxpyf.c +++ b/testsuite/src/test_dotxaxpyf.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,10 +41,10 @@ static char* op_str = "dotxaxpyf"; static char* o_types = "mvvvv"; // A w x y z static char* p_types = "cccc"; // conjat conja conjw conjx -static thresh_t thresh[BLIS_NUM_FP_TYPES] = { { 1e-04, 1e-05 }, // warn, pass for s - { 1e-04, 1e-05 }, // warn, pass for c - { 1e-13, 1e-14 }, // warn, pass for d - { 1e-13, 1e-14 } }; // warn, pass for z +static thresh_t thresh[BLIS_NUM_FP_TYPES] = { { 5e-04, 5e-05 }, // warn, pass for s + { 5e-04, 5e-05 }, // warn, pass for c + { 5e-13, 5e-14 }, // warn, pass for d + { 5e-13, 5e-14 } }; // warn, pass for z // Local prototypes. void libblis_test_dotxaxpyf_deps @@ -366,11 +366,11 @@ void libblis_test_dotxaxpyf_check // // is functioning correctly if // - // normf( y - v ) + // normfv( y - v ) // // and // - // normf( z - q ) + // normfv( z - q ) // // are negligible, where v and q contain y and z as computed by repeated // calls to dotxv and axpyv, respectively. diff --git a/testsuite/src/test_dotxaxpyf.h b/testsuite/src/test_dotxaxpyf.h index 72b93a637e..6bfcd2655e 100644 --- a/testsuite/src/test_dotxaxpyf.h +++ b/testsuite/src/test_dotxaxpyf.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotxf.c b/testsuite/src/test_dotxf.c index d73fd0609e..8a1eca4eba 100644 --- a/testsuite/src/test_dotxf.c +++ b/testsuite/src/test_dotxf.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -324,7 +324,7 @@ void libblis_test_dotxf_check // // is functioning correctly if // - // normf( y - v ) + // normfv( y - v ) // // is negligible, where v contains y as computed by repeated calls to // dotxv. diff --git a/testsuite/src/test_dotxf.h b/testsuite/src/test_dotxf.h index 8940e6a759..06cac584e3 100644 --- a/testsuite/src/test_dotxf.h +++ b/testsuite/src/test_dotxf.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotxv.c b/testsuite/src/test_dotxv.c index 76a47a08dc..da42e6ae4d 100644 --- a/testsuite/src/test_dotxv.c +++ b/testsuite/src/test_dotxv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -304,7 +304,7 @@ void libblis_test_dotxv_check // // is functioning correctly if // - // sqrtsc( rho.real ) - sqrtsc( alpha ) * normf( x ) + // sqrtsc( rho.real ) - sqrtsc( alpha ) * normfv( x ) // // and // diff --git a/testsuite/src/test_dotxv.h b/testsuite/src/test_dotxv.h index 02009b5a9a..a3e2ca48f2 100644 --- a/testsuite/src/test_dotxv.h +++ b/testsuite/src/test_dotxv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_gemm.c b/testsuite/src/test_gemm.c index 24ecf6d618..65f910f9b1 100644 --- a/testsuite/src/test_gemm.c +++ b/testsuite/src/test_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -236,11 +236,11 @@ void libblis_test_gemm_experiment libblis_test_mobj_create( params, datatype, transa, sc_str[1], m, k, &a ); libblis_test_mobj_create( params, datatype, transb, - sc_str[1], k, n, &b ); + sc_str[2], k, n, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c ); + sc_str[0], m, n, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c_save ); + sc_str[0], m, n, &c_save ); // Set alpha and beta. if ( bli_obj_is_real( &c ) ) @@ -254,6 +254,12 @@ void libblis_test_gemm_experiment bli_setsc( 0.9, 1.0, &beta ); } + #if 0 + //bli_setm( &BLIS_ONE, &a ); + bli_setsc( 1.0, 0.0, &alpha ); + bli_setsc( 1.0, 0.0, &beta ); + #endif + // Randomize A, B, and C, and save C. libblis_test_mobj_randomize( params, TRUE, &a ); libblis_test_mobj_randomize( params, TRUE, &b ); @@ -349,13 +355,13 @@ void libblis_test_gemm_md // Create test operands (vectors and/or matrices). libblis_test_mobj_create( params, dt_a, transa, - sc_str[0], m, k, &a ); + sc_str[1], m, k, &a ); libblis_test_mobj_create( params, dt_b, transb, - sc_str[1], k, n, &b ); + sc_str[2], k, n, &b ); libblis_test_mobj_create( params, dt_c, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c ); + sc_str[0], m, n, &c ); libblis_test_mobj_create( params, dt_c, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c_save ); + sc_str[0], m, n, &c_save ); // For mixed-precision, set the computation precision of C. if ( params->mixed_precision ) @@ -400,17 +406,7 @@ void libblis_test_gemm_md time = bli_clock(); -#if 0 -bli_printm( "a", &a, "%5.2f", "" ); -bli_printm( "b", &b, "%5.2f", "" ); -bli_printm( "c", &c, "%5.2f", "" ); -bli_printm( "alpha", &alpha, "%5.2f", "" ); -bli_printm( "beta", &beta, "%5.2f", "" ); -#endif libblis_test_gemm_impl( iface, &alpha, &a, &b, &beta, &c ); -#if 0 -bli_printm( "c after", &c, "%5.2f", "" ); -#endif time_min = bli_clock_min_diff( time_min, time ); } @@ -449,16 +445,25 @@ void libblis_test_gemm_impl { case BLIS_TEST_SEQ_FRONT_END: #if 0 +//bli_printm( "alpha", alpha, "%5.2f", "" ); +//bli_printm( "beta", beta, "%5.2f", "" ); +if ( bli_obj_dt( c ) == BLIS_DCOMPLEX ) +{ bli_printm( "a", a, "%5.2f", "" ); bli_printm( "b", b, "%5.2f", "" ); bli_printm( "c", c, "%5.2f", "" ); -bli_printm( "alpha", alpha, "%5.2f", "" ); -bli_printm( "beta", beta, "%5.2f", "" ); +} #endif +//if ( bli_obj_length( b ) == 16 && +// bli_obj_stor3_from_strides( c, a, b ) == BLIS_CRR ) +//bli_printm( "c before", c, "%6.3f", "" ); bli_gemm( alpha, a, b, beta, c ); + //bls_gemm( alpha, a, b, beta, c ); #if 0 -bli_printm( "c after", c, "%5.2f", "" ); +if ( bli_obj_dt( c ) == BLIS_DCOMPLEX ) +bli_printm( "c after", c, "%6.3f", "" ); #endif +//bli_printm( "c after", c, "%5.2f", "" ); break; default: @@ -617,7 +622,7 @@ void libblis_test_gemm_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // @@ -660,14 +665,14 @@ double libblis_test_gemm_flops obj_t* c ) { - bool_t a_is_real = bli_obj_is_real( a ); - bool_t a_is_complex = bli_obj_is_complex( a ); + bool a_is_real = bli_obj_is_real( a ); + bool a_is_complex = bli_obj_is_complex( a ); - bool_t b_is_real = bli_obj_is_real( b ); - bool_t b_is_complex = bli_obj_is_complex( b ); + bool b_is_real = bli_obj_is_real( b ); + bool b_is_complex = bli_obj_is_complex( b ); - bool_t c_is_real = bli_obj_is_real( c ); - bool_t c_is_complex = bli_obj_is_complex( c ); + bool c_is_real = bli_obj_is_real( c ); + bool c_is_complex = bli_obj_is_complex( c ); double m = ( double )bli_obj_length( c ); double n = ( double )bli_obj_width( c ); diff --git a/testsuite/src/test_gemm.h b/testsuite/src/test_gemm.h index f1c41bb950..78364bc249 100644 --- a/testsuite/src/test_gemm.h +++ b/testsuite/src/test_gemm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_gemm_ukr.c b/testsuite/src/test_gemm_ukr.c index 616532491d..d37005b285 100644 --- a/testsuite/src/test_gemm_ukr.c +++ b/testsuite/src/test_gemm_ukr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -169,7 +169,6 @@ void libblis_test_gemm_ukr_experiment num_t datatype; dim_t m, n, k; - inc_t ldap, ldbp; char sc_a = 'c'; char sc_b = 'r'; @@ -194,11 +193,6 @@ void libblis_test_gemm_ukr_experiment m = bli_cntx_get_blksz_def_dt( datatype, BLIS_MR, cntx ); n = bli_cntx_get_blksz_def_dt( datatype, BLIS_NR, cntx ); - // Also query PACKMR and PACKNR as the leading dimensions to ap and bp, - // respectively. - ldap = bli_cntx_get_blksz_max_dt( datatype, BLIS_MR, cntx ); - ldbp = bli_cntx_get_blksz_max_dt( datatype, BLIS_NR, cntx ); - // Store the register blocksizes so that the driver can retrieve the // values later when printing results. op->dim_aux[0] = m; @@ -237,7 +231,13 @@ void libblis_test_gemm_ukr_experiment libblis_test_mobj_randomize( params, TRUE, &c ); bli_copym( &c, &c_save ); -#if 0 + rntm_t rntm; + bli_rntm_init( &rntm ); + bli_pba_rntm_set_pba( &rntm ); + + // Transpose B to B^T for packing. + bli_obj_induce_trans( &b ); + // Create pack objects for a and b, and pack them to ap and bp, // respectively. cntl_t* cntl_a = libblis_test_pobj_create @@ -248,46 +248,26 @@ void libblis_test_gemm_ukr_experiment BLIS_PACKED_ROW_PANELS, BLIS_BUFFER_FOR_A_BLOCK, &a, &ap, - cntx + cntx, + &rntm ); cntl_t* cntl_b = libblis_test_pobj_create ( - BLIS_KR, BLIS_NR, + BLIS_KR, BLIS_NO_INVERT_DIAG, BLIS_PACKED_COL_PANELS, BLIS_BUFFER_FOR_B_PANEL, &b, &bp, - cntx + cntx, + &rntm ); -#endif - - // Create the packed objects. Use packmr and packnr as the leading - // dimensions of ap and bp, respectively. - bli_obj_create( datatype, m, k, 1, ldap, &ap ); - bli_obj_create( datatype, k, n, ldbp, 1, &bp ); - - // Set up the objects for packing. Calling packm_init_pack() does everything - // except checkout a memory pool block and save its address to the obj_t's. - // However, it does overwrite the buffer field of packed object with that of - // the source object. So, we have to save the buffer address that was - // allocated. - void* buf_ap = bli_obj_buffer( &ap ); - void* buf_bp = bli_obj_buffer( &bp ); - bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_PACKED_ROW_PANELS, - BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, - BLIS_MR, BLIS_KR, &a, &ap, cntx ); - bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_PACKED_COL_PANELS, - BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, - BLIS_KR, BLIS_NR, &b, &bp, cntx ); - bli_obj_set_buffer( buf_ap, &ap ); - bli_obj_set_buffer( buf_bp, &bp ); - - // Pack the data from the source objects. - bli_packm_blk_var1( &a, &ap, cntx, NULL, &BLIS_PACKM_SINGLE_THREADED ); - bli_packm_blk_var1( &b, &bp, cntx, NULL, &BLIS_PACKM_SINGLE_THREADED ); - - // Repeat the experiment n_repeats times and record results. + + // Transpose B^T back to B and Bp^T back to Bp. + bli_obj_induce_trans( &b ); + bli_obj_induce_trans( &bp ); + + // Repeat the experiment n_repeats times and record results. for ( i = 0; i < n_repeats; ++i ) { bli_copym( &c_save, &c ); @@ -311,12 +291,10 @@ void libblis_test_gemm_ukr_experiment // Zero out performance and residual if output matrix is empty. libblis_test_check_empty_problem( &c, perf, resid ); -#if 0 // Free the control tree nodes and release their cached mem_t entries - // back to the memory broker. - bli_cntl_free( cntl_a, &BLIS_PACKM_SINGLE_THREADED ); - bli_cntl_free( cntl_b, &BLIS_PACKM_SINGLE_THREADED ); -#endif + // back to the pba. + bli_cntl_free( &rntm, cntl_a, &BLIS_PACKM_SINGLE_THREADED ); + bli_cntl_free( &rntm, cntl_b, &BLIS_PACKM_SINGLE_THREADED ); // Free the test objects. bli_obj_free( &a ); @@ -390,7 +368,7 @@ void libblis_test_gemm_ukr_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_gemm_ukr.h b/testsuite/src/test_gemm_ukr.h index d20c47c627..cd09ef3f69 100644 --- a/testsuite/src/test_gemm_ukr.h +++ b/testsuite/src/test_gemm_ukr.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_gemmt.c b/testsuite/src/test_gemmt.c new file mode 100644 index 0000000000..3b7b08748a --- /dev/null +++ b/testsuite/src/test_gemmt.c @@ -0,0 +1,398 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "test_libblis.h" + + +// Static variables. +static char* op_str = "gemmt"; +static char* o_types = "mmm"; // a b c +static char* p_types = "uhh"; // uploc transa transb +static thresh_t thresh[BLIS_NUM_FP_TYPES] = { { 1e-04, 1e-05 }, // warn, pass for s + { 1e-04, 1e-05 }, // warn, pass for c + { 1e-13, 1e-14 }, // warn, pass for d + { 1e-13, 1e-14 } }; // warn, pass for z + +// Local prototypes. +void libblis_test_gemmt_deps + ( + thread_data_t* tdata, + test_params_t* params, + test_op_t* op + ); + +void libblis_test_gemmt_experiment + ( + test_params_t* params, + test_op_t* op, + iface_t iface, + char* dc_str, + char* pc_str, + char* sc_str, + unsigned int p_cur, + double* perf, + double* resid + ); + +void libblis_test_gemmt_impl + ( + iface_t iface, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c + ); + +void libblis_test_gemmt_check + ( + test_params_t* params, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + obj_t* c_orig, + double* resid + ); + + + +void libblis_test_gemmt_deps + ( + thread_data_t* tdata, + test_params_t* params, + test_op_t* op + ) +{ + libblis_test_randv( tdata, params, &(op->ops->randv) ); + libblis_test_randm( tdata, params, &(op->ops->randm) ); + libblis_test_setv( tdata, params, &(op->ops->setv) ); + libblis_test_normfv( tdata, params, &(op->ops->normfv) ); + libblis_test_subv( tdata, params, &(op->ops->subv) ); + libblis_test_scalv( tdata, params, &(op->ops->scalv) ); + libblis_test_copym( tdata, params, &(op->ops->copym) ); + libblis_test_scalm( tdata, params, &(op->ops->scalm) ); + libblis_test_gemv( tdata, params, &(op->ops->gemv) ); + libblis_test_gemm( tdata, params, &(op->ops->gemm) ); +} + + + +void libblis_test_gemmt + ( + thread_data_t* tdata, + test_params_t* params, + test_op_t* op + ) +{ + + // Return early if this test has already been done. + if ( libblis_test_op_is_done( op ) ) return; + + // Return early if operation is disabled. + if ( libblis_test_op_is_disabled( op ) || + libblis_test_l3_is_disabled( op ) ) return; + + // Call dependencies first. + if ( TRUE ) libblis_test_gemmt_deps( tdata, params, op ); + + // Execute the test driver for each implementation requested. + //if ( op->front_seq == ENABLE ) + { + libblis_test_op_driver( tdata, + params, + op, + BLIS_TEST_SEQ_FRONT_END, + op_str, + p_types, + o_types, + thresh, + libblis_test_gemmt_experiment ); + } +} + + + +void libblis_test_gemmt_experiment + ( + test_params_t* params, + test_op_t* op, + iface_t iface, + char* dc_str, + char* pc_str, + char* sc_str, + unsigned int p_cur, + double* perf, + double* resid + ) +{ + unsigned int n_repeats = params->n_repeats; + unsigned int i; + + double time_min = DBL_MAX; + double time; + + num_t datatype; + + dim_t m, k; + + uplo_t uploc; + trans_t transa; + trans_t transb; + + obj_t alpha, a, b, beta, c; + obj_t c_save; + + + // Use the datatype of the first char in the datatype combination string. + bli_param_map_char_to_blis_dt( dc_str[0], &datatype ); + + // Map the dimension specifier to actual dimensions. + m = libblis_test_get_dim_from_prob_size( op->dim_spec[0], p_cur ); + k = libblis_test_get_dim_from_prob_size( op->dim_spec[1], p_cur ); + + // Map parameter characters to BLIS constants. + bli_param_map_char_to_blis_uplo( pc_str[0], &uploc ); + bli_param_map_char_to_blis_trans( pc_str[1], &transa ); + bli_param_map_char_to_blis_trans( pc_str[2], &transb ); + + // Create test scalars. + bli_obj_scalar_init_detached( datatype, &alpha ); + bli_obj_scalar_init_detached( datatype, &beta ); + + // Create test operands (vectors and/or matrices). + libblis_test_mobj_create( params, datatype, transa, + sc_str[1], m, k, &a ); + libblis_test_mobj_create( params, datatype, transb, + sc_str[2], k, m, &b ); + libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, + sc_str[0], m, m, &c ); + libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, + sc_str[0], m, m, &c_save ); + + // Set alpha and beta. + if ( bli_obj_is_real( &c ) ) + { + bli_setsc( 1.2, 0.0, &alpha ); + bli_setsc( 0.9, 0.0, &beta ); + } + else + { + bli_setsc( 1.2, 0.8, &alpha ); + bli_setsc( 0.9, 1.0, &beta ); + } + + // Randomize A and B. + libblis_test_mobj_randomize( params, TRUE, &a ); + libblis_test_mobj_randomize( params, TRUE, &b ); + +//bli_setm( &BLIS_ONE, &a ); +//bli_setm( &BLIS_ONE, &b ); +//bli_setsc( 1.0, 0.0, &alpha ); +//bli_setsc( 0.0, 0.0, &beta ); + + // Set the uplo property of C. + bli_obj_set_uplo( uploc, &c ); + + // Randomize C, make it densely symmetric, and zero the unstored triangle + // to ensure the implementation reads only from the stored region. + libblis_test_mobj_randomize( params, TRUE, &c ); + bli_mksymm( &c ); + bli_mktrim( &c ); + + // Save C and set its uplo property. + bli_setm( &BLIS_ZERO, &c_save ); + bli_obj_set_uplo( uploc, &c_save ); + bli_copym( &c, &c_save ); + + // Apply the parameters. + bli_obj_set_conjtrans( transa, &a ); + bli_obj_set_conjtrans( transb, &b ); + + // Repeat the experiment n_repeats times and record results. + for ( i = 0; i < n_repeats; ++i ) + { + bli_copym( &c_save, &c ); + + time = bli_clock(); + + libblis_test_gemmt_impl( iface, &alpha, &a, &b, &beta, &c ); + + time_min = bli_clock_min_diff( time_min, time ); + } + + // Estimate the performance of the best experiment repeat. + *perf = ( 1.0 * m * m * k ) / time_min / FLOPS_PER_UNIT_PERF; + if ( bli_obj_is_complex( &c ) ) *perf *= 4.0; + + // Perform checks. + libblis_test_gemmt_check( params, &alpha, &a, &b, &beta, &c, &c_save, resid ); + + // Zero out performance and residual if output matrix is empty. + libblis_test_check_empty_problem( &c, perf, resid ); + + // Free the test objects. + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + bli_obj_free( &c_save ); +} + + + +void libblis_test_gemmt_impl + ( + iface_t iface, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c + ) +{ + switch ( iface ) + { + case BLIS_TEST_SEQ_FRONT_END: +#if 0 +//bli_printm( "alpha", alpha, "%5.2f", "" ); +//bli_printm( "beta", beta, "%5.2f", "" ); +bli_printm( "a", a, "%5.2f", "" ); +bli_printm( "b", b, "%5.2f", "" ); +bli_printm( "c", c, "%5.2f", "" ); +#endif +//if ( bli_obj_length( b ) == 16 && +// bli_obj_stor3_from_strides( c, a, b ) == BLIS_CRR ) +//bli_printm( "c before", c, "%6.3f", "" ); + bli_gemmt( alpha, a, b, beta, c ); +#if 0 +//if ( bli_obj_length( c ) == 12 && +// bli_obj_stor3_from_strides( c, a, b ) == BLIS_RRR ) +bli_printm( "c after", c, "%5.2f", "" ); +#endif + break; + + default: + libblis_test_printf_error( "Invalid interface type.\n" ); + } +} + + + +void libblis_test_gemmt_check + ( + test_params_t* params, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + obj_t* c_orig, + double* resid + ) +{ + num_t dt = bli_obj_dt( c ); + num_t dt_real = bli_obj_dt_proj_to_real( c ); + uplo_t uploc = bli_obj_uplo( c ); + + dim_t m = bli_obj_length( c ); + //dim_t k = bli_obj_width_after_trans( a ); + + obj_t norm; + obj_t t, v, q, z; + + double junk; + + // + // Pre-conditions: + // - a is randomized. + // - b is randomized. + // - c_orig is randomized. + // Note: + // - alpha and beta should have non-zero imaginary components in the + // complex cases in order to more fully exercise the implementation. + // + // Under these conditions, we assume that the implementation for + // + // C := beta * C_orig + alpha * transa(A) * transb(B) + // + // is functioning correctly if + // + // normfv( v - z ) + // + // is negligible, where + // + // v = C * t + // z = ( beta * C_orig + alpha * transa(A) * transb(B) ) * t + // = beta * C_orig * t + alpha * transa(A) * transb(B) * t + // = beta * C_orig * t + alpha * uplo(Q) * t + // = beta * C_orig * t + z + // + + bli_obj_scalar_init_detached( dt_real, &norm ); + + bli_obj_create( dt, m, 1, 0, 0, &t ); + bli_obj_create( dt, m, 1, 0, 0, &v ); + bli_obj_create( dt, m, 1, 0, 0, &z ); + + bli_obj_create( dt, m, m, 0, 0, &q ); + bli_obj_set_uplo( uploc, &q ); + + libblis_test_vobj_randomize( params, TRUE, &t ); + + bli_gemv( &BLIS_ONE, c, &t, &BLIS_ZERO, &v ); + + bli_gemm( &BLIS_ONE, a, b, &BLIS_ZERO, &q ); +#if 1 + bli_mktrim( &q ); + bli_gemv( alpha, &q, &t, &BLIS_ZERO, &z ); +#else + bli_obj_set_struc( BLIS_TRIANGULAR, &q ); + bli_copyv( &t, &z ); + bli_trmv( alpha, &q, &z ); +#endif + bli_gemv( beta, c_orig, &t, &BLIS_ONE, &z ); + + bli_subv( &z, &v ); + bli_normfv( &v, &norm ); + bli_getsc( &norm, resid, &junk ); + + bli_obj_free( &t ); + bli_obj_free( &v ); + bli_obj_free( &z ); + bli_obj_free( &q ); +} + diff --git a/frame/3/syrk/bli_syrk_front.h b/testsuite/src/test_gemmt.h similarity index 91% rename from frame/3/syrk/bli_syrk_front.h rename to testsuite/src/test_gemmt.h index 28d1e13f61..5468bf48dc 100644 --- a/frame/3/syrk/bli_syrk_front.h +++ b/testsuite/src/test_gemmt.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,13 +33,10 @@ */ -void bli_syrk_front +void libblis_test_gemmt ( - obj_t* alpha, - obj_t* a, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl + thread_data_t* tdata, + test_params_t* params, + test_op_t* op ); + diff --git a/testsuite/src/test_gemmtrsm_ukr.c b/testsuite/src/test_gemmtrsm_ukr.c index 6d2f028d23..48fcb78db7 100644 --- a/testsuite/src/test_gemmtrsm_ukr.c +++ b/testsuite/src/test_gemmtrsm_ukr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -283,7 +283,10 @@ void libblis_test_gemmtrsm_ukr_experiment bli_copym( &b11, &c11 ); bli_copym( &c11, &c11_save ); -#if 0 + rntm_t rntm; + bli_rntm_init( &rntm ); + bli_pba_rntm_set_pba( &rntm ); + // Create pack objects for a and b, and pack them to ap and bp, // respectively. cntl_t* cntl_a = libblis_test_pobj_create @@ -294,40 +297,9 @@ void libblis_test_gemmtrsm_ukr_experiment BLIS_PACKED_ROW_PANELS, BLIS_BUFFER_FOR_A_BLOCK, &a, &ap, - &cntx - ); - cntl_t* cntl_b = libblis_test_pobj_create - ( - BLIS_MR, - BLIS_NR, - BLIS_NO_INVERT_DIAG, - BLIS_PACKED_COL_PANELS, - BLIS_BUFFER_FOR_B_PANEL, - &b, &bp, - &cntx + cntx, + &rntm ); -#endif - - // Create the packed objects. Use packmr and packnr as the leading - // dimensions of ap and bp, respectively. - bli_obj_create( datatype, m, k+m, 1, ldap, &ap ); - bli_obj_create( datatype, k+m, n, ldbp, 1, &bp ); - - // Set up the objects for packing. Calling packm_init_pack() does everything - // except checkout a memory pool block and save its address to the obj_t's. - // However, it does overwrite the buffer field of packed object with that of - // the source object. So, we have to save the buffer address that was - // allocated. - void* buf_ap = bli_obj_buffer( &ap ); - void* buf_bp = bli_obj_buffer( &bp ); - bli_packm_init_pack( BLIS_INVERT_DIAG, BLIS_PACKED_ROW_PANELS, - BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, - BLIS_MR, BLIS_KR, &a, &ap, cntx ); - bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_PACKED_COL_PANELS, - BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, - BLIS_KR, BLIS_NR, &b, &bp, cntx ); - bli_obj_set_buffer( buf_ap, &ap ); - bli_obj_set_buffer( buf_bp, &bp ); // Set the diagonal offset of ap. if ( bli_is_lower( uploa ) ) { bli_obj_set_diag_offset( k, &ap ); } @@ -338,32 +310,45 @@ void libblis_test_gemmtrsm_ukr_experiment // to know how to initialize the subpartitions. bli_obj_set_uplo( uploa, &ap ); - // Pack the data from the source objects. - bli_packm_blk_var1( &a, &ap, cntx, NULL, &BLIS_PACKM_SINGLE_THREADED ); - bli_packm_blk_var1( &b, &bp, cntx, NULL, &BLIS_PACKM_SINGLE_THREADED ); - - // Create subpartitions from the a and b panels. - bli_gemmtrsm_ukr_make_subparts( k, &ap, &bp, - &a1xp, &a11p, &bx1p, &b11p ); - - // Set the uplo field of a11p since the default for packed objects is - // BLIS_DENSE, and the _ukernel() wrapper needs this information to - // know which set of micro-kernels (lower or upper) to choose from. - bli_obj_set_uplo( uploa, &a11p ); - #if 0 bli_printm( "a", &a, "%5.2f", "" ); bli_printm( "ap", &ap, "%5.2f", "" ); #endif - // Repeat the experiment n_repeats times and record results. + cntl_t* cntl_b = NULL; + + // Repeat the experiment n_repeats times and record results. for ( i = 0; i < n_repeats; ++i ) { bli_copym( &c11_save, &c11 ); - // Re-pack (restore) the contents of b to bp. - //bli_packm_blk_var1( &b, &bp, &cntx, cntl_b, &BLIS_PACKM_SINGLE_THREADED ); - bli_packm_blk_var1( &b, &bp, cntx, NULL, &BLIS_PACKM_SINGLE_THREADED ); + // Transpose B to B^T for packing. + bli_obj_induce_trans( &b ); + + cntl_b = libblis_test_pobj_create + ( + BLIS_NR, + BLIS_MR, + BLIS_NO_INVERT_DIAG, + BLIS_PACKED_COL_PANELS, + BLIS_BUFFER_FOR_B_PANEL, + &b, &bp, + cntx, + &rntm + ); + + // Transpose B^T back to B and Bp^T back to Bp. + bli_obj_induce_trans( &b ); + bli_obj_induce_trans( &bp ); + + // Create subpartitions from the a and b panels. + bli_gemmtrsm_ukr_make_subparts( k, &ap, &bp, + &a1xp, &a11p, &bx1p, &b11p ); + + // Set the uplo field of a11p since the default for packed objects is + // BLIS_DENSE, and the _ukernel() wrapper needs this information to + // know which set of micro-kernels (lower or upper) to choose from. + bli_obj_set_uplo( uploa, &a11p ); time = bli_clock(); @@ -372,12 +357,43 @@ bli_printm( "ap", &ap, "%5.2f", "" ); cntx ); time_min = bli_clock_min_diff( time_min, time ); + + // On the last pass, we must keep the packed B buffer checked out in order + // to perform the correctness check later. + if ( i < n_repeats - 1 ) + { + // Free the control tree nodes and release their cached mem_t entries + // back to the memory broker. + bli_cntl_free( &rntm, cntl_b, &BLIS_PACKM_SINGLE_THREADED ); + } } // Estimate the performance of the best experiment repeat. *perf = ( 2.0 * m * n * k + 1.0 * m * m * n ) / time_min / FLOPS_PER_UNIT_PERF; if ( bli_obj_is_complex( &b ) ) *perf *= 4.0; + // A hack to support subconfigs such as power9, which duplicate/broadcast + // more than one stored element per logical element in the packed copy of + // B. We assume that the ratio ldbp/n gives us the duplication factor used + // within B while the ratio ldap/m gives us the duplication factor used + // within A (not entirely a safe assumption, though I think it holds for + // all gemm ukernels currently supported within BLIS). This duplication + // factor must be used as the column stride of B (or the row stride of A) + // in order for the bli_gemmv() operation (called within the + // libblis_test_gemmtrsm_ukr_check()) to operate properly. + if ( ldbp / n > 1 ) + { + const dim_t bfac = ldbp / n; + bli_obj_set_col_stride( bfac, &b11p ); + bli_obj_set_col_stride( bfac, &bx1p ); + } + if ( ldap / m > 1 ) + { + const dim_t bfac = ldap / m; + bli_obj_set_row_stride( bfac, &a11p ); + bli_obj_set_row_stride( bfac, &a1xp ); + } + // Perform checks. libblis_test_gemmtrsm_ukr_check( params, side, &alpha, &a1xp, &a11p, &bx1p, &b11p, &c11, &c11_save, resid ); @@ -385,12 +401,11 @@ bli_printm( "ap", &ap, "%5.2f", "" ); // Zero out performance and residual if output matrix is empty. //libblis_test_check_empty_problem( &c11, perf, resid ); -#if 0 // Free the control tree nodes and release their cached mem_t entries - // back to the memory broker. - bli_cntl_free( cntl_a, &BLIS_PACKM_SINGLE_THREADED ); - bli_cntl_free( cntl_b, &BLIS_PACKM_SINGLE_THREADED ); -#endif + // back to the pba. + bli_cntl_free( &rntm, cntl_a, &BLIS_PACKM_SINGLE_THREADED ); + if ( cntl_b ) + bli_cntl_free( &rntm, cntl_b, &BLIS_PACKM_SINGLE_THREADED ); // Free the test objects. bli_obj_free( &a_big ); @@ -465,7 +480,7 @@ void libblis_test_gemmtrsm_ukr_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_gemmtrsm_ukr.h b/testsuite/src/test_gemmtrsm_ukr.h index 8bf52c4eba..5fd3cc0ba0 100644 --- a/testsuite/src/test_gemmtrsm_ukr.h +++ b/testsuite/src/test_gemmtrsm_ukr.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_gemv.c b/testsuite/src/test_gemv.c index aa10764b0a..e6090e1c5b 100644 --- a/testsuite/src/test_gemv.c +++ b/testsuite/src/test_gemv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -324,7 +324,7 @@ void libblis_test_gemv_check // // is functioning correctly if // - // normf( y - z ) + // normfv( y - z ) // // is negligible, where // diff --git a/testsuite/src/test_gemv.h b/testsuite/src/test_gemv.h index 2de09095eb..8e7284486a 100644 --- a/testsuite/src/test_gemv.h +++ b/testsuite/src/test_gemv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_ger.c b/testsuite/src/test_ger.c index c611c46614..b44fe6ba64 100644 --- a/testsuite/src/test_ger.c +++ b/testsuite/src/test_ger.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -303,7 +303,7 @@ void libblis_test_ger_check // // is functioning correctly if // - // normf( v - w ) + // normfv( v - w ) // // is negligible, where // diff --git a/testsuite/src/test_ger.h b/testsuite/src/test_ger.h index f053a73b9a..5b75babe60 100644 --- a/testsuite/src/test_ger.h +++ b/testsuite/src/test_ger.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_hemm.c b/testsuite/src/test_hemm.c index 3e27325007..cac5aa73a0 100644 --- a/testsuite/src/test_hemm.c +++ b/testsuite/src/test_hemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -202,13 +202,13 @@ void libblis_test_hemm_experiment // Create test operands (vectors and/or matrices). bli_set_dim_with_side( side, m, n, &mn_side ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[0], mn_side, mn_side, &a ); + sc_str[1], mn_side, mn_side, &a ); libblis_test_mobj_create( params, datatype, transb, - sc_str[1], m, n, &b ); + sc_str[2], m, n, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c ); + sc_str[0], m, n, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c_save ); + sc_str[0], m, n, &c_save ); // Set alpha and beta. if ( bli_obj_is_real( &c ) ) @@ -287,8 +287,6 @@ void libblis_test_hemm_impl { case BLIS_TEST_SEQ_FRONT_END: bli_hemm( side, alpha, a, b, beta, c ); - //bli_hemm4m( side, alpha, a, b, beta, c ); - //bli_hemm3m( side, alpha, a, b, beta, c ); break; default: @@ -338,7 +336,7 @@ void libblis_test_hemm_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_hemm.h b/testsuite/src/test_hemm.h index f3ff7b90f5..7db76afa1e 100644 --- a/testsuite/src/test_hemm.h +++ b/testsuite/src/test_hemm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_hemv.c b/testsuite/src/test_hemv.c index 17204102e3..02e205392b 100644 --- a/testsuite/src/test_hemv.c +++ b/testsuite/src/test_hemv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -322,7 +322,7 @@ void libblis_test_hemv_check // // is functioning correctly if // - // normf( y - v ) + // normfv( y - v ) // // is negligible, where // diff --git a/testsuite/src/test_hemv.h b/testsuite/src/test_hemv.h index 701a79030e..e522690d1e 100644 --- a/testsuite/src/test_hemv.h +++ b/testsuite/src/test_hemv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_her.c b/testsuite/src/test_her.c index c5ca4b14d3..c122f6ce56 100644 --- a/testsuite/src/test_her.c +++ b/testsuite/src/test_her.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -301,7 +301,7 @@ void libblis_test_her_check // // is functioning correctly if // - // normf( v - w ) + // normfv( v - w ) // // is negligible, where // diff --git a/testsuite/src/test_her.h b/testsuite/src/test_her.h index fed1e45ef4..a6aaa55b47 100644 --- a/testsuite/src/test_her.h +++ b/testsuite/src/test_her.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_her2.c b/testsuite/src/test_her2.c index 896497b4ea..1ed6b3bb9e 100644 --- a/testsuite/src/test_her2.c +++ b/testsuite/src/test_her2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -311,7 +311,7 @@ void libblis_test_her2_check // // is functioning correctly if // - // normf( v - w ) + // normfv( v - w ) // // is negligible, where // diff --git a/testsuite/src/test_her2.h b/testsuite/src/test_her2.h index 6e4d26a64e..c2711cfb11 100644 --- a/testsuite/src/test_her2.h +++ b/testsuite/src/test_her2.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_her2k.c b/testsuite/src/test_her2k.c index 067b1c9738..59bbaf5f1d 100644 --- a/testsuite/src/test_her2k.c +++ b/testsuite/src/test_her2k.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -195,13 +195,13 @@ void libblis_test_her2k_experiment // Create test operands (vectors and/or matrices). libblis_test_mobj_create( params, datatype, transa, - sc_str[0], m, k, &a ); + sc_str[1], m, k, &a ); libblis_test_mobj_create( params, datatype, transb, - sc_str[1], m, k, &b ); + sc_str[2], m, k, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, m, &c ); + sc_str[0], m, m, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, m, &c_save ); + sc_str[0], m, m, &c_save ); // Set alpha and beta. if ( bli_obj_is_real( &c ) ) @@ -285,8 +285,6 @@ void libblis_test_her2k_impl { case BLIS_TEST_SEQ_FRONT_END: bli_her2k( alpha, a, b, beta, c ); - //bli_her2k4m( alpha, a, b, beta, c ); - //bli_her2k3m( alpha, a, b, beta, c ); break; default: @@ -336,7 +334,7 @@ void libblis_test_her2k_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_her2k.h b/testsuite/src/test_her2k.h index 36382ed9cb..a481dac720 100644 --- a/testsuite/src/test_her2k.h +++ b/testsuite/src/test_her2k.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_herk.c b/testsuite/src/test_herk.c index 0db92b4d2d..bbb7be9228 100644 --- a/testsuite/src/test_herk.c +++ b/testsuite/src/test_herk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -192,11 +192,11 @@ void libblis_test_herk_experiment // Create test operands (vectors and/or matrices). libblis_test_mobj_create( params, datatype, transa, - sc_str[0], m, k, &a ); + sc_str[1], m, k, &a ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, m, &c ); + sc_str[0], m, m, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, m, &c_save ); + sc_str[0], m, m, &c_save ); // Set alpha and beta. if ( bli_obj_is_real( &c ) ) @@ -276,8 +276,6 @@ void libblis_test_herk_impl { case BLIS_TEST_SEQ_FRONT_END: bli_herk( alpha, a, beta, c ); - //bli_herk4m( alpha, a, beta, c ); - //bli_herk3m( alpha, a, beta, c ); break; default: @@ -323,7 +321,7 @@ void libblis_test_herk_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_herk.h b/testsuite/src/test_herk.h index 6235cb7613..1702bd8b9b 100644 --- a/testsuite/src/test_herk.h +++ b/testsuite/src/test_herk.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_libblis.c b/testsuite/src/test_libblis.c index 96c705c9a6..edab9796d2 100644 --- a/testsuite/src/test_libblis.c +++ b/testsuite/src/test_libblis.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -121,6 +121,8 @@ void* libblis_test_thread_entry( void* tdata_void ) void libblis_test_thread_decorator( test_params_t* params, test_ops_t* ops ) { + err_t r_val; + // Query the total number of threads to simulate. size_t nt = ( size_t )params->n_app_threads; @@ -130,22 +132,22 @@ void libblis_test_thread_decorator( test_params_t* params, test_ops_t* ops ) #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - bli_pthread_t* pthread = bli_malloc_intl( sizeof( bli_pthread_t ) * nt ); + bli_pthread_t* pthread = bli_malloc_user( sizeof( bli_pthread_t ) * nt, &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - thread_data_t* tdata = bli_malloc_intl( sizeof( thread_data_t ) * nt ); + thread_data_t* tdata = bli_malloc_user( sizeof( thread_data_t ) * nt, &r_val ); // Allocate a mutex for the threads to share. - //bli_pthread_mutex_t* mutex = bli_malloc_intl( sizeof( bli_pthread_mutex_t ) ); + //bli_pthread_mutex_t* mutex = bli_malloc_user( sizeof( bli_pthread_mutex_t ) ); // Allocate a barrier for the threads to share. #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - bli_pthread_barrier_t* barrier = bli_malloc_intl( sizeof( bli_pthread_barrier_t ) ); + bli_pthread_barrier_t* barrier = bli_malloc_user( sizeof( bli_pthread_barrier_t ), &r_val ); // Initialize the mutex. //bli_pthread_mutex_init( mutex, NULL ); @@ -191,18 +193,18 @@ void libblis_test_thread_decorator( test_params_t* params, test_ops_t* ops ) #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - bli_free_intl( pthread ); + bli_free_user( pthread ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - bli_free_intl( tdata ); + bli_free_user( tdata ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - //bli_free_intl( mutex ); - bli_free_intl( barrier ); + //bli_free_user( mutex ); + bli_free_user( barrier ); } @@ -314,6 +316,7 @@ void libblis_test_level3_ukrs( thread_data_t* tdata, test_params_t* params, test void libblis_test_level3_ops( thread_data_t* tdata, test_params_t* params, test_ops_t* ops ) { libblis_test_gemm( tdata, params, &(ops->gemm) ); + libblis_test_gemmt( tdata, params, &(ops->gemmt) ); libblis_test_hemm( tdata, params, &(ops->hemm) ); libblis_test_herk( tdata, params, &(ops->herk) ); libblis_test_her2k( tdata, params, &(ops->her2k) ); @@ -408,6 +411,7 @@ void libblis_test_read_ops_file( char* input_filename, test_ops_t* ops ) // Level-3 libblis_test_read_op_info( ops, input_stream, BLIS_GEMM, BLIS_TEST_DIMS_MNK, 2, &(ops->gemm) ); + libblis_test_read_op_info( ops, input_stream, BLIS_GEMMT, BLIS_TEST_DIMS_MK, 3, &(ops->gemmt) ); libblis_test_read_op_info( ops, input_stream, BLIS_HEMM, BLIS_TEST_DIMS_MN, 4, &(ops->hemm) ); libblis_test_read_op_info( ops, input_stream, BLIS_HERK, BLIS_TEST_DIMS_MK, 2, &(ops->herk) ); libblis_test_read_op_info( ops, input_stream, BLIS_HER2K, BLIS_TEST_DIMS_MK, 3, &(ops->her2k) ); @@ -546,26 +550,6 @@ void libblis_test_read_params_file( char* input_filename, test_params_t* params libblis_test_read_next_line( buffer, input_stream ); sscanf( buffer, "%u ", &(params->p_inc) ); - // Read whether to enable 3mh. - libblis_test_read_next_line( buffer, input_stream ); - sscanf( buffer, "%u ", &(params->ind_enable[ BLIS_3MH ]) ); - - // Read whether to enable 3m1. - libblis_test_read_next_line( buffer, input_stream ); - sscanf( buffer, "%u ", &(params->ind_enable[ BLIS_3M1 ]) ); - - // Read whether to enable 4mh. - libblis_test_read_next_line( buffer, input_stream ); - sscanf( buffer, "%u ", &(params->ind_enable[ BLIS_4MH ]) ); - - // Read whether to enable 4m1b (4mb). - libblis_test_read_next_line( buffer, input_stream ); - sscanf( buffer, "%u ", &(params->ind_enable[ BLIS_4M1B ]) ); - - // Read whether to enable 4m1a (4m1). - libblis_test_read_next_line( buffer, input_stream ); - sscanf( buffer, "%u ", &(params->ind_enable[ BLIS_4M1A ]) ); - // Read whether to enable 1m. libblis_test_read_next_line( buffer, input_stream ); sscanf( buffer, "%u ", &(params->ind_enable[ BLIS_1M ]) ); @@ -585,24 +569,13 @@ void libblis_test_read_params_file( char* input_filename, test_params_t* params // threads. if ( params->n_app_threads > 1 ) { - if ( params->ind_enable[ BLIS_3MH ] || - params->ind_enable[ BLIS_3M1 ] || - params->ind_enable[ BLIS_4MH ] || - params->ind_enable[ BLIS_4M1B ] || - params->ind_enable[ BLIS_4M1A ] || - params->ind_enable[ BLIS_1M ] - ) + if ( params->ind_enable[ BLIS_1M ] ) { // Due to an inherent race condition in the way induced methods // are enabled and disabled at runtime, all induced methods must be // disabled when simulating multiple application threads. libblis_test_printf_infoc( "simulating multiple application threads; disabling induced methods.\n" ); - params->ind_enable[ BLIS_3MH ] = 0; - params->ind_enable[ BLIS_3M1 ] = 0; - params->ind_enable[ BLIS_4MH ] = 0; - params->ind_enable[ BLIS_4M1B ] = 0; - params->ind_enable[ BLIS_4M1A ] = 0; params->ind_enable[ BLIS_1M ] = 0; } } @@ -663,7 +636,7 @@ void libblis_test_read_op_info( test_ops_t* ops, int i, p; // Initialize the operation type field. - op->opid = opid; + op->opid = opid; // Read the line for the overall operation switch. libblis_test_read_next_line( buffer, input_stream ); @@ -698,7 +671,7 @@ void libblis_test_read_op_info( test_ops_t* ops, //printf( "buffer[p]: %s\n", &buffer[p] ); // Advance until we hit non-whitespace (ie: the next number). - for ( ; isspace( buffer[p] ); ++p ) ; + for ( ; isspace( buffer[p] ); ++p ) ; //printf( "buffer[p] after: %s\n", &buffer[p] ); @@ -707,7 +680,7 @@ void libblis_test_read_op_info( test_ops_t* ops, //printf( "dim[%d] = %d\n", i, op->dim_spec[i] ); // Advance until we hit whitespace (ie: the space before the next number). - for ( ; !isspace( buffer[p] ); ++p ) ; + for ( ; !isspace( buffer[p] ); ++p ) ; } } @@ -805,11 +778,11 @@ void libblis_test_output_params_struct( FILE* os, test_params_t* params ) // convert these values into strings, with "unset" being used if the // value returned was -1 (indicating the environment variable was unset). dim_t nt = bli_thread_get_num_threads(); - dim_t jc_nt = bli_thread_get_jc_nt(); - dim_t pc_nt = bli_thread_get_pc_nt(); - dim_t ic_nt = bli_thread_get_ic_nt(); - dim_t jr_nt = bli_thread_get_jr_nt(); - dim_t ir_nt = bli_thread_get_ir_nt(); + dim_t jc_nt = bli_thread_get_jc_nt(); + dim_t pc_nt = bli_thread_get_pc_nt(); + dim_t ic_nt = bli_thread_get_ic_nt(); + dim_t jr_nt = bli_thread_get_jr_nt(); + dim_t ir_nt = bli_thread_get_ir_nt(); if ( nt == -1 ) sprintf( nt_str, "unset" ); else sprintf( nt_str, "%d", ( int ) nt ); @@ -829,12 +802,12 @@ void libblis_test_output_params_struct( FILE* os, test_params_t* params ) rntm_t gemm, herk, trmm_l, trmm_r, trsm_l, trsm_r; dim_t m = 1000, n = 1000, k = 1000; - bli_thread_init_rntm( &gemm ); - bli_thread_init_rntm( &herk ); - bli_thread_init_rntm( &trmm_l ); - bli_thread_init_rntm( &trmm_r ); - bli_thread_init_rntm( &trsm_l ); - bli_thread_init_rntm( &trsm_r ); + bli_rntm_init_from_global( &gemm ); + bli_rntm_init_from_global( &herk ); + bli_rntm_init_from_global( &trmm_l ); + bli_rntm_init_from_global( &trmm_r ); + bli_rntm_init_from_global( &trsm_l ); + bli_rntm_init_from_global( &trsm_r ); bli_rntm_set_ways_for_op( BLIS_GEMM, BLIS_LEFT, m, n, k, &gemm ); bli_rntm_set_ways_for_op( BLIS_HERK, BLIS_LEFT, m, n, k, &herk ); @@ -869,7 +842,8 @@ void libblis_test_output_params_struct( FILE* os, test_params_t* params ) libblis_test_fprintf_c( os, " stack address %d\n", ( int )bli_info_get_stack_buf_align_size() ); libblis_test_fprintf_c( os, " obj_t address %d\n", ( int )bli_info_get_heap_addr_align_size() ); libblis_test_fprintf_c( os, " obj_t stride %d\n", ( int )bli_info_get_heap_stride_align_size() ); - libblis_test_fprintf_c( os, " pool block addr %d\n", ( int )bli_info_get_pool_addr_align_size() ); + libblis_test_fprintf_c( os, " pool block addr A (+offset) %d (+%d)\n", ( int )bli_info_get_pool_addr_align_size_a(), ( int )bli_info_get_pool_addr_offset_size_a() ); + libblis_test_fprintf_c( os, " pool block addr B (+offset) %d (+%d)\n", ( int )bli_info_get_pool_addr_align_size_b(), ( int )bli_info_get_pool_addr_offset_size_b() ); libblis_test_fprintf_c( os, "\n" ); libblis_test_fprintf_c( os, "BLAS/CBLAS compatibility layers \n" ); libblis_test_fprintf_c( os, " BLAS API enabled? %d\n", ( int )bli_info_get_enable_blas() ); @@ -1226,11 +1200,6 @@ void libblis_test_output_params_struct( FILE* os, test_params_t* params ) libblis_test_fprintf_c( os, "problem size: max to test %u\n", params->p_max ); libblis_test_fprintf_c( os, "problem size increment %u\n", params->p_inc ); libblis_test_fprintf_c( os, "complex implementations \n" ); - libblis_test_fprintf_c( os, " 3mh? %u\n", params->ind_enable[ BLIS_3MH ] ); - libblis_test_fprintf_c( os, " 3m1? %u\n", params->ind_enable[ BLIS_3M1 ] ); - libblis_test_fprintf_c( os, " 4mh? %u\n", params->ind_enable[ BLIS_4MH ] ); - libblis_test_fprintf_c( os, " 4m1b (4mb)? %u\n", params->ind_enable[ BLIS_4M1B ] ); - libblis_test_fprintf_c( os, " 4m1a (4m1)? %u\n", params->ind_enable[ BLIS_4M1A ] ); libblis_test_fprintf_c( os, " 1m? %u\n", params->ind_enable[ BLIS_1M ] ); libblis_test_fprintf_c( os, " native? %u\n", params->ind_enable[ BLIS_NAT ] ); libblis_test_fprintf_c( os, "simulated app-level threads %u\n", params->n_app_threads ); @@ -1770,7 +1739,7 @@ void libblis_test_op_driver = ( char* ) malloc( ( n_operands + 1 ) * sizeof( char ) ); for ( o = 0; o < n_operands; ++o ) - { + { unsigned int ij; operand_t operand_type = libblis_test_get_operand_type_for_char( o_types[o] ); @@ -1785,8 +1754,8 @@ void libblis_test_op_driver } } - // Enumerate all combinations of datatype domains requested, but only - // for the gemm operation. + // Enumerate all combinations of datatypes requested, but only for the + // gemm operation. if ( !mixed_domain && mixed_precision && op->opid == BLIS_GEMM ) { @@ -2108,6 +2077,11 @@ void libblis_test_op_driver // Loop over the requested storage schemes. for ( sci = 0; sci < n_store_combos; ++sci ) + //for ( sci = 0; sci < 5; ( sci == 0 || sci == 2 ? sci+=2 : ++sci ) ) + //for ( sci = 0; sci < 5; ( sci == 2 ? sci+=2 : ++sci ) ) + //for ( sci = 3; sci < 8; ( sci == 3 ? sci+=2 : ++sci ) ) + //for ( sci = 0; sci < 1; ++sci ) + //for ( sci = 7; sci < 8; ++sci ) { // Loop over the requested datatypes. for ( dci = 0; dci < n_dt_combos; ++dci ) @@ -2207,7 +2181,7 @@ void libblis_test_op_driver ind_str = bli_ind_oper_get_avail_impl_string( op->opid, datatype ); // Loop over the requested parameter combinations. - for ( pci = 0; pci < n_param_combos; ++pci ) + for ( pci = 0; pci < n_param_combos; ++pci ) { // Loop over the requested problem sizes. for ( p_cur = p_first, pi = 1; p_cur <= p_max; p_cur += p_inc, ++pi ) @@ -2378,7 +2352,11 @@ void libblis_test_op_driver // Mark this operation as done. - op->test_done = TRUE; + if ( tdata->id == 0 ) + op->test_done = TRUE; + + // Wait here so that all threads know we are done + bli_pthread_barrier_wait( tdata->barrier ); } @@ -2425,7 +2403,7 @@ void libblis_test_build_function_string if ( strlen( funcname_str ) > MAX_FUNC_STRING_LENGTH ) libblis_test_printf_error( "Function name string length (%d) exceeds maximum (%d).\n", strlen( funcname_str ), MAX_FUNC_STRING_LENGTH ); - + } @@ -2561,13 +2539,13 @@ void fill_string_with_n_spaces( char* str, unsigned int n_spaces ) void libblis_test_mobj_create( test_params_t* params, num_t dt, trans_t trans, char storage, dim_t m, dim_t n, obj_t* a ) { dim_t gs = params->gs_spacing; - bool_t alignment = params->alignment; + bool alignment = params->alignment; siz_t elem_size = bli_dt_size( dt ); dim_t m_trans = m; dim_t n_trans = n; dim_t rs = 1; // Initialization avoids a compiler warning. dim_t cs = 1; // Initialization avoids a compiler warning. - + // Apply the trans parameter to the dimensions (if needed). bli_set_dims_with_trans( trans, m, n, &m_trans, &n_trans ); @@ -2613,12 +2591,9 @@ void libblis_test_mobj_create( test_params_t* params, num_t dt, trans_t trans, c } - -#if 0 -cntl_t* libblis_test_pobj_create( bszid_t bmult_id_m, bszid_t bmult_id_n, invdiag_t inv_diag, pack_t pack_schema, packbuf_t pack_buf, obj_t* a, obj_t* p, cntx_t* cntx ) +cntl_t* libblis_test_pobj_create( bszid_t bmult_id_m, bszid_t bmult_id_n, invdiag_t inv_diag, pack_t pack_schema, packbuf_t pack_buf, obj_t* a, obj_t* p, cntx_t* cntx, rntm_t* rntm ) { - bool_t does_inv_diag; - rntm_t rntm; + bool does_inv_diag; if ( inv_diag == BLIS_NO_INVERT_DIAG ) does_inv_diag = FALSE; else does_inv_diag = TRUE; @@ -2628,7 +2603,6 @@ cntl_t* libblis_test_pobj_create( bszid_t bmult_id_m, bszid_t bmult_id_n, invdia ( NULL, // we don't need the small block allocator from the runtime. NULL, // func ptr is not referenced b/c we don't call via l3 _int(). - bli_packm_blk_var1, bmult_id_m, bmult_id_n, does_inv_diag, @@ -2639,20 +2613,13 @@ cntl_t* libblis_test_pobj_create( bszid_t bmult_id_m, bszid_t bmult_id_n, invdia NULL // no child node needed ); - // Initialize a local-to-BLIS rntm_t. This is simply so we have something - // to pass into bli_l3_packm(). The function doesn't (currently) use the - // runtime object, and even if it did, one with default values would work - // fine here. - bli_rntm_init( &rntm ); - // Pack the contents of A to P. - bli_l3_packm( a, p, cntx, &rntm, cntl, &BLIS_PACKM_SINGLE_THREADED ); + bli_packm_blk_var1( a, p, cntx, rntm, cntl, &BLIS_PACKM_SINGLE_THREADED ); // Return the control tree pointer so the caller can free the cntl_t and its // mem_t entry later on. return cntl; } -#endif void libblis_test_vobj_create( test_params_t* params, num_t dt, char storage, dim_t m, obj_t* x ) @@ -2675,7 +2642,7 @@ void libblis_test_vobj_create( test_params_t* params, num_t dt, char storage, di -void libblis_test_vobj_randomize( test_params_t* params, bool_t normalize, obj_t* x ) +void libblis_test_vobj_randomize( test_params_t* params, bool normalize, obj_t* x ) { if ( params->rand_method == BLIS_TEST_RAND_REAL_VALUES ) bli_randv( x ); @@ -2704,7 +2671,7 @@ void libblis_test_vobj_randomize( test_params_t* params, bool_t normalize, obj_t -void libblis_test_mobj_randomize( test_params_t* params, bool_t normalize, obj_t* a ) +void libblis_test_mobj_randomize( test_params_t* params, bool normalize, obj_t* a ) { if ( params->rand_method == BLIS_TEST_RAND_REAL_VALUES ) bli_randm( a ); @@ -2997,7 +2964,7 @@ void libblis_test_parse_message( FILE* output_stream, char* message, va_list arg char* the_string; char the_char; - // Begin looping over message to insert variables wherever there are + // Begin looping over message to insert variables wherever there are // format specifiers. for ( c = 0; message[c] != '\0'; ) { @@ -3070,8 +3037,8 @@ void libblis_test_parse_message( FILE* output_stream, char* message, va_list arg void libblis_test_parse_command_line( int argc, char** argv ) { - bool_t gave_option_g = FALSE; - bool_t gave_option_o = FALSE; + bool gave_option_g = FALSE; + bool gave_option_o = FALSE; int opt; char opt_ch; getopt_t state; @@ -3177,7 +3144,7 @@ int libblis_test_op_is_disabled( test_op_t* op ) return r_val; } -int libblis_test_op_is_done( test_op_t* op ) +bool libblis_test_op_is_done( test_op_t* op ) { return op->test_done; } diff --git a/testsuite/src/test_libblis.h b/testsuite/src/test_libblis.h index 4a3dd0ffc7..cdb3c6dac4 100644 --- a/testsuite/src/test_libblis.h +++ b/testsuite/src/test_libblis.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -168,8 +168,8 @@ typedef struct unsigned int n_datatypes; char datatype_char[ MAX_NUM_DATATYPES + 1 ]; num_t datatype[ MAX_NUM_DATATYPES + 1 ]; - unsigned int mixed_domain; - unsigned int mixed_precision; + unsigned int mixed_domain; + unsigned int mixed_precision; unsigned int p_first; unsigned int p_max; unsigned int p_inc; @@ -198,7 +198,7 @@ typedef struct int dim_aux[ MAX_NUM_DIMENSIONS ]; unsigned int n_params; char params[ MAX_NUM_PARAMETERS ]; - bool_t test_done; + bool test_done; } test_op_t; @@ -273,6 +273,7 @@ typedef struct test_ops_s // level-3 test_op_t gemm; + test_op_t gemmt; test_op_t hemm; test_op_t herk; test_op_t her2k; @@ -417,13 +418,13 @@ void fill_string_with_n_spaces( char* str, unsigned int n_spaces ); // --- Create object --- void libblis_test_mobj_create( test_params_t* params, num_t dt, trans_t trans, char storage, dim_t m, dim_t n, obj_t* a ); -cntl_t* libblis_test_pobj_create( bszid_t bmult_id_m, bszid_t bmult_id_n, invdiag_t inv_diag, pack_t pack_schema, packbuf_t pack_buf, obj_t* a, obj_t* p, cntx_t* cntx ); +cntl_t* libblis_test_pobj_create( bszid_t bmult_id_m, bszid_t bmult_id_n, invdiag_t inv_diag, pack_t pack_schema, packbuf_t pack_buf, obj_t* a, obj_t* p, cntx_t* cntx, rntm_t* rntm ); void libblis_test_vobj_create( test_params_t* params, num_t dt, char storage, dim_t m, obj_t* x ); // --- Randomize/initialize object --- -void libblis_test_vobj_randomize( test_params_t* params, bool_t normalize, obj_t* x ); -void libblis_test_mobj_randomize( test_params_t* params, bool_t normalize, obj_t* a ); +void libblis_test_vobj_randomize( test_params_t* params, bool normalize, obj_t* x ); +void libblis_test_mobj_randomize( test_params_t* params, bool normalize, obj_t* a ); void libblis_test_mobj_load_diag( test_params_t* params, obj_t* a ); void libblis_test_ceil_pow2( obj_t* alpha ); @@ -460,22 +461,22 @@ void libblis_test_parse_command_line( int argc, char** argv ); void libblis_test_check_empty_problem( obj_t* c, double* perf, double* resid ); int libblis_test_op_is_disabled( test_op_t* op ); -int libblis_test_op_is_done( test_op_t* op ); -int libblis_test_util_is_disabled( test_op_t* op ); -int libblis_test_l1v_is_disabled( test_op_t* op ); -int libblis_test_l1m_is_disabled( test_op_t* op ); -int libblis_test_l1f_is_disabled( test_op_t* op ); -int libblis_test_l2_is_disabled( test_op_t* op ); -int libblis_test_l3ukr_is_disabled( test_op_t* op ); -int libblis_test_l3_is_disabled( test_op_t* op ); -int libblis_test_dt_str_has_sp_char( test_params_t* params ); -int libblis_test_dt_str_has_sp_char_str( int n, char* str ); -int libblis_test_dt_str_has_dp_char( test_params_t* params ); -int libblis_test_dt_str_has_dp_char_str( int n, char* str ); -int libblis_test_dt_str_has_rd_char( test_params_t* params ); -int libblis_test_dt_str_has_rd_char_str( int n, char* str ); -int libblis_test_dt_str_has_cd_char( test_params_t* params ); -int libblis_test_dt_str_has_cd_char_str( int n, char* str ); +bool libblis_test_op_is_done( test_op_t* op ); +int libblis_test_util_is_disabled( test_op_t* op ); +int libblis_test_l1v_is_disabled( test_op_t* op ); +int libblis_test_l1m_is_disabled( test_op_t* op ); +int libblis_test_l1f_is_disabled( test_op_t* op ); +int libblis_test_l2_is_disabled( test_op_t* op ); +int libblis_test_l3ukr_is_disabled( test_op_t* op ); +int libblis_test_l3_is_disabled( test_op_t* op ); +int libblis_test_dt_str_has_sp_char( test_params_t* params ); +int libblis_test_dt_str_has_sp_char_str( int n, char* str ); +int libblis_test_dt_str_has_dp_char( test_params_t* params ); +int libblis_test_dt_str_has_dp_char_str( int n, char* str ); +int libblis_test_dt_str_has_rd_char( test_params_t* params ); +int libblis_test_dt_str_has_rd_char_str( int n, char* str ); +int libblis_test_dt_str_has_cd_char( test_params_t* params ); +int libblis_test_dt_str_has_cd_char_str( int n, char* str ); unsigned int libblis_test_count_combos ( @@ -546,6 +547,7 @@ char libblis_test_proj_dtchar_to_precchar( char dt_char ); // Level-3 #include "test_gemm.h" +#include "test_gemmt.h" #include "test_hemm.h" #include "test_herk.h" #include "test_her2k.h" diff --git a/testsuite/src/test_normfm.c b/testsuite/src/test_normfm.c index e8882ed54e..c4b9a0105e 100644 --- a/testsuite/src/test_normfm.c +++ b/testsuite/src/test_normfm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -259,7 +259,7 @@ void libblis_test_normfm_check // // Under these conditions, we assume that the implementation for // - // norm := normf( x ) + // norm := normfm( x ) // // is functioning correctly if // diff --git a/testsuite/src/test_normfm.h b/testsuite/src/test_normfm.h index b79f6b7bb0..a24b5e5ba2 100644 --- a/testsuite/src/test_normfm.h +++ b/testsuite/src/test_normfm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_normfv.c b/testsuite/src/test_normfv.c index 1622a2e897..3bcce35af4 100644 --- a/testsuite/src/test_normfv.c +++ b/testsuite/src/test_normfv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -256,7 +256,7 @@ void libblis_test_normfv_check // // Under these conditions, we assume that the implementation for // - // norm := normf( x ) + // norm := normfv( x ) // // is functioning correctly if // diff --git a/testsuite/src/test_normfv.h b/testsuite/src/test_normfv.h index 2193c43ee3..afa5350063 100644 --- a/testsuite/src/test_normfv.h +++ b/testsuite/src/test_normfv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_randm.c b/testsuite/src/test_randm.c index f5ef20629d..223007dba9 100644 --- a/testsuite/src/test_randm.c +++ b/testsuite/src/test_randm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_randm.h b/testsuite/src/test_randm.h index 9c8c87886a..e444649629 100644 --- a/testsuite/src/test_randm.h +++ b/testsuite/src/test_randm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_randv.c b/testsuite/src/test_randv.c index 03f98a9b9f..951c8c3eca 100644 --- a/testsuite/src/test_randv.c +++ b/testsuite/src/test_randv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_randv.h b/testsuite/src/test_randv.h index 3574c19e56..bb658dfd7c 100644 --- a/testsuite/src/test_randv.h +++ b/testsuite/src/test_randv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_scal2m.c b/testsuite/src/test_scal2m.c index d6f29f996c..e8440fc46d 100644 --- a/testsuite/src/test_scal2m.c +++ b/testsuite/src/test_scal2m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -288,7 +288,7 @@ void libblis_test_scal2m_check // // is functioning correctly if // - // normf( y - alpha * conjx(x) ) + // normfm( y - alpha * conjx(x) ) // // is negligible. // diff --git a/testsuite/src/test_scal2m.h b/testsuite/src/test_scal2m.h index 3abcd9a14f..262723f4e7 100644 --- a/testsuite/src/test_scal2m.h +++ b/testsuite/src/test_scal2m.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_scal2v.c b/testsuite/src/test_scal2v.c index 7a28479dbd..c200e13fcb 100644 --- a/testsuite/src/test_scal2v.c +++ b/testsuite/src/test_scal2v.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -285,7 +285,7 @@ void libblis_test_scal2v_check // // is functioning correctly if // - // normf( y - alpha * conjx(x) ) + // normfv( y - alpha * conjx(x) ) // // is negligible. // diff --git a/testsuite/src/test_scal2v.h b/testsuite/src/test_scal2v.h index 3ab6b3c42a..75b5cfe4a6 100644 --- a/testsuite/src/test_scal2v.h +++ b/testsuite/src/test_scal2v.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_scalm.c b/testsuite/src/test_scalm.c index 3e9d5069f6..6219c71df4 100644 --- a/testsuite/src/test_scalm.c +++ b/testsuite/src/test_scalm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -280,7 +280,7 @@ void libblis_test_scalm_check // // is functioning correctly if // - // normf( y + -conjbeta(beta) * y_orig ) + // normfm( y + -conjbeta(beta) * y_orig ) // // is negligible. // diff --git a/testsuite/src/test_scalm.h b/testsuite/src/test_scalm.h index 1723f51dcb..3b98617b29 100644 --- a/testsuite/src/test_scalm.h +++ b/testsuite/src/test_scalm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_scalv.c b/testsuite/src/test_scalv.c index ef3b980cae..142b5e410b 100644 --- a/testsuite/src/test_scalv.c +++ b/testsuite/src/test_scalv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -276,7 +276,7 @@ void libblis_test_scalv_check // // is functioning correctly if // - // normf( y + -conjbeta(beta) * y_orig ) + // normfv( y + -conjbeta(beta) * y_orig ) // // is negligible. // diff --git a/testsuite/src/test_scalv.h b/testsuite/src/test_scalv.h index 9092ae3598..144b416759 100644 --- a/testsuite/src/test_scalv.h +++ b/testsuite/src/test_scalv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_setm.c b/testsuite/src/test_setm.c index 630ced831e..80cebd64e0 100644 --- a/testsuite/src/test_setm.c +++ b/testsuite/src/test_setm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_setm.h b/testsuite/src/test_setm.h index f5cd32aa19..0271840312 100644 --- a/testsuite/src/test_setm.h +++ b/testsuite/src/test_setm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_setv.c b/testsuite/src/test_setv.c index a0ed3ee97e..10f0348c75 100644 --- a/testsuite/src/test_setv.c +++ b/testsuite/src/test_setv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_setv.h b/testsuite/src/test_setv.h index b2494a17c9..4e02d489e5 100644 --- a/testsuite/src/test_setv.h +++ b/testsuite/src/test_setv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_subm.c b/testsuite/src/test_subm.c index d28eb28008..63b48eedcf 100644 --- a/testsuite/src/test_subm.c +++ b/testsuite/src/test_subm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -275,7 +275,7 @@ void libblis_test_subm_check // // is functioning correctly if // - // normfv(y) - sqrt( absqsc( beta - conjx(alpha) ) * m * n ) + // normfm(y) - sqrt( absqsc( beta - conjx(alpha) ) * m * n ) // // is negligible. // diff --git a/testsuite/src/test_subm.h b/testsuite/src/test_subm.h index c7e7e93cee..e39eff8282 100644 --- a/testsuite/src/test_subm.h +++ b/testsuite/src/test_subm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_subv.c b/testsuite/src/test_subv.c index 7d7a107dda..3a48f02a46 100644 --- a/testsuite/src/test_subv.c +++ b/testsuite/src/test_subv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_subv.h b/testsuite/src/test_subv.h index 30fd8bba88..5dbe465898 100644 --- a/testsuite/src/test_subv.h +++ b/testsuite/src/test_subv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_symm.c b/testsuite/src/test_symm.c index f64ad7e76a..03d74e8691 100644 --- a/testsuite/src/test_symm.c +++ b/testsuite/src/test_symm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -202,13 +202,13 @@ void libblis_test_symm_experiment // Create test operands (vectors and/or matrices). bli_set_dim_with_side( side, m, n, &mn_side ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[0], mn_side, mn_side, &a ); + sc_str[1], mn_side, mn_side, &a ); libblis_test_mobj_create( params, datatype, transb, - sc_str[1], m, n, &b ); + sc_str[2], m, n, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c ); + sc_str[0], m, n, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c_save ); + sc_str[0], m, n, &c_save ); // Set alpha and beta. if ( bli_obj_is_real( &c ) ) @@ -287,8 +287,6 @@ void libblis_test_symm_impl { case BLIS_TEST_SEQ_FRONT_END: bli_symm( side, alpha, a, b, beta, c ); - //bli_symm4m( side, alpha, a, b, beta, c ); - //bli_symm3m( side, alpha, a, b, beta, c ); break; default: @@ -338,7 +336,7 @@ void libblis_test_symm_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_symm.h b/testsuite/src/test_symm.h index fe960d4fea..bf50bf65d7 100644 --- a/testsuite/src/test_symm.h +++ b/testsuite/src/test_symm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_symv.c b/testsuite/src/test_symv.c index c654685dfd..5ae5f30be0 100644 --- a/testsuite/src/test_symv.c +++ b/testsuite/src/test_symv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -322,7 +322,7 @@ void libblis_test_symv_check // // is functioning correctly if // - // normf( y - v ) + // normfv( y - v ) // // is negligible, where // diff --git a/testsuite/src/test_symv.h b/testsuite/src/test_symv.h index 0a0a833c52..5dba0624ca 100644 --- a/testsuite/src/test_symv.h +++ b/testsuite/src/test_symv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_syr.c b/testsuite/src/test_syr.c index efdc67b842..69376b9708 100644 --- a/testsuite/src/test_syr.c +++ b/testsuite/src/test_syr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -301,7 +301,7 @@ void libblis_test_syr_check // // is functioning correctly if // - // normf( v - w ) + // normfv( v - w ) // // is negligible, where // diff --git a/testsuite/src/test_syr.h b/testsuite/src/test_syr.h index d616f969b1..455e18ff1d 100644 --- a/testsuite/src/test_syr.h +++ b/testsuite/src/test_syr.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_syr2.c b/testsuite/src/test_syr2.c index e87cd13e52..42d65c00e4 100644 --- a/testsuite/src/test_syr2.c +++ b/testsuite/src/test_syr2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -313,7 +313,7 @@ void libblis_test_syr2_check // // is functioning correctly if // - // normf( v - w ) + // normfv( v - w ) // // is negligible, where // diff --git a/testsuite/src/test_syr2.h b/testsuite/src/test_syr2.h index 6f0998354e..d6c1f3c104 100644 --- a/testsuite/src/test_syr2.h +++ b/testsuite/src/test_syr2.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_syr2k.c b/testsuite/src/test_syr2k.c index 9b57502b74..2e1fcf2374 100644 --- a/testsuite/src/test_syr2k.c +++ b/testsuite/src/test_syr2k.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -195,13 +195,13 @@ void libblis_test_syr2k_experiment // Create test operands (vectors and/or matrices). libblis_test_mobj_create( params, datatype, transa, - sc_str[0], m, k, &a ); + sc_str[1], m, k, &a ); libblis_test_mobj_create( params, datatype, transb, - sc_str[1], m, k, &b ); + sc_str[2], m, k, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, m, &c ); + sc_str[0], m, m, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, m, &c_save ); + sc_str[0], m, m, &c_save ); // Set alpha and beta. if ( bli_obj_is_real( &c ) ) @@ -285,8 +285,6 @@ void libblis_test_syr2k_impl { case BLIS_TEST_SEQ_FRONT_END: bli_syr2k( alpha, a, b, beta, c ); - //bli_syr2k4m( alpha, a, b, beta, c ); - //bli_syr2k3m( alpha, a, b, beta, c ); break; default: @@ -335,7 +333,7 @@ void libblis_test_syr2k_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_syr2k.h b/testsuite/src/test_syr2k.h index 9ceb3befd2..edf893c291 100644 --- a/testsuite/src/test_syr2k.h +++ b/testsuite/src/test_syr2k.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_syrk.c b/testsuite/src/test_syrk.c index 0ae4a1802d..be3e33fe31 100644 --- a/testsuite/src/test_syrk.c +++ b/testsuite/src/test_syrk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -192,11 +192,11 @@ void libblis_test_syrk_experiment // Create test operands (vectors and/or matrices). libblis_test_mobj_create( params, datatype, transa, - sc_str[0], m, k, &a ); + sc_str[1], m, k, &a ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, m, &c ); + sc_str[0], m, m, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, m, &c_save ); + sc_str[0], m, m, &c_save ); // Set alpha and beta. if ( bli_obj_is_real( &c ) ) @@ -276,8 +276,6 @@ void libblis_test_syrk_impl { case BLIS_TEST_SEQ_FRONT_END: bli_syrk( alpha, a, beta, c ); - //bli_syrk4m( alpha, a, beta, c ); - //bli_syrk3m( alpha, a, beta, c ); break; default: @@ -324,7 +322,7 @@ void libblis_test_syrk_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_syrk.h b/testsuite/src/test_syrk.h index e0d461c107..8cad724566 100644 --- a/testsuite/src/test_syrk.h +++ b/testsuite/src/test_syrk.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trmm.c b/testsuite/src/test_trmm.c index 39524b9c10..0504b33158 100644 --- a/testsuite/src/test_trmm.c +++ b/testsuite/src/test_trmm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -197,11 +197,11 @@ void libblis_test_trmm_experiment // Create test operands (vectors and/or matrices). bli_set_dim_with_side( side, m, n, &mn_side ); libblis_test_mobj_create( params, datatype, transa, - sc_str[0], mn_side, mn_side, &a ); + sc_str[1], mn_side, mn_side, &a ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, n, &b ); + sc_str[0], m, n, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, n, &b_save ); + sc_str[0], m, n, &b_save ); // Set alpha and beta. if ( bli_obj_is_real( &b ) ) @@ -272,8 +272,6 @@ void libblis_test_trmm_impl { case BLIS_TEST_SEQ_FRONT_END: bli_trmm( side, alpha, a, b ); - //bli_trmm4m( side, alpha, a, b ); - //bli_trmm3m( side, alpha, a, b ); break; default: @@ -320,7 +318,7 @@ void libblis_test_trmm_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_trmm.h b/testsuite/src/test_trmm.h index 0c0b617ddc..a84ca1d296 100644 --- a/testsuite/src/test_trmm.h +++ b/testsuite/src/test_trmm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trmm3.c b/testsuite/src/test_trmm3.c index 77e8f26497..d0644252ff 100644 --- a/testsuite/src/test_trmm3.c +++ b/testsuite/src/test_trmm3.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -204,13 +204,13 @@ void libblis_test_trmm3_experiment // Create test operands (vectors and/or matrices). bli_set_dim_with_side( side, m, n, &mn_side ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[0], mn_side, mn_side, &a ); + sc_str[1], mn_side, mn_side, &a ); libblis_test_mobj_create( params, datatype, transb, - sc_str[1], m, n, &b ); + sc_str[2], m, n, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c ); + sc_str[0], m, n, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c_save ); + sc_str[0], m, n, &c_save ); // Set alpha and beta. if ( bli_obj_is_real( &c ) ) @@ -288,8 +288,6 @@ void libblis_test_trmm3_impl { case BLIS_TEST_SEQ_FRONT_END: bli_trmm3( side, alpha, a, b, beta, c ); - //bli_trmm34m( side, alpha, a, b, beta, c ); - //bli_trmm33m( side, alpha, a, b, beta, c ); break; default: @@ -339,7 +337,7 @@ void libblis_test_trmm3_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_trmm3.h b/testsuite/src/test_trmm3.h index 6150b70237..ee9490036c 100644 --- a/testsuite/src/test_trmm3.h +++ b/testsuite/src/test_trmm3.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trmv.c b/testsuite/src/test_trmv.c index cd1b130cf1..71acc90ba0 100644 --- a/testsuite/src/test_trmv.c +++ b/testsuite/src/test_trmv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -304,7 +304,7 @@ void libblis_test_trmv_check // // is functioning correctly if // - // normf( y - x ) + // normfv( y - x ) // // is negligible, where // diff --git a/testsuite/src/test_trmv.h b/testsuite/src/test_trmv.h index 185aeb8a38..1fae8331ff 100644 --- a/testsuite/src/test_trmv.h +++ b/testsuite/src/test_trmv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trsm.c b/testsuite/src/test_trsm.c index 30435a7737..fa0d8e7c30 100644 --- a/testsuite/src/test_trsm.c +++ b/testsuite/src/test_trsm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -197,11 +197,11 @@ void libblis_test_trsm_experiment // Create test operands (vectors and/or matrices). bli_set_dim_with_side( side, m, n, &mn_side ); libblis_test_mobj_create( params, datatype, transa, - sc_str[0], mn_side, mn_side, &a ); + sc_str[1], mn_side, mn_side, &a ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, n, &b ); + sc_str[0], m, n, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, n, &b_save ); + sc_str[0], m, n, &b_save ); // Set alpha. if ( bli_obj_is_real( &b ) ) @@ -327,7 +327,7 @@ void libblis_test_trsm_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_trsm.h b/testsuite/src/test_trsm.h index 2738511f2b..ee23b2c7a1 100644 --- a/testsuite/src/test_trsm.h +++ b/testsuite/src/test_trsm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trsm_ukr.c b/testsuite/src/test_trsm_ukr.c index 5476e1daf2..b07da91cc8 100644 --- a/testsuite/src/test_trsm_ukr.c +++ b/testsuite/src/test_trsm_ukr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -171,7 +171,6 @@ void libblis_test_trsm_ukr_experiment num_t datatype; dim_t m, n; - inc_t ldap, ldbp; char sc_a = 'c'; char sc_b = 'r'; @@ -196,11 +195,6 @@ void libblis_test_trsm_ukr_experiment m = bli_cntx_get_blksz_def_dt( datatype, BLIS_MR, cntx ); n = bli_cntx_get_blksz_def_dt( datatype, BLIS_NR, cntx ); - // Also query PACKMR and PACKNR as the leading dimensions to ap and bp, - // respectively. - ldap = bli_cntx_get_blksz_max_dt( datatype, BLIS_MR, cntx ); - ldbp = bli_cntx_get_blksz_max_dt( datatype, BLIS_NR, cntx ); - // Store the register blocksizes so that the driver can retrieve the // values later when printing results. op->dim_aux[0] = m; @@ -238,7 +232,10 @@ void libblis_test_trsm_ukr_experiment libblis_test_mobj_randomize( params, TRUE, &c ); bli_copym( &c, &c_save ); -#if 0 + rntm_t rntm; + bli_rntm_init( &rntm ); + bli_pba_rntm_set_pba( &rntm ); + // Create pack objects for a and b, and pack them to ap and bp, // respectively. cntl_t* cntl_a = libblis_test_pobj_create @@ -249,40 +246,9 @@ void libblis_test_trsm_ukr_experiment BLIS_PACKED_ROW_PANELS, BLIS_BUFFER_FOR_A_BLOCK, &a, &ap, - cntx + cntx, + &rntm ); - cntl_t* cntl_b = libblis_test_pobj_create - ( - BLIS_MR, - BLIS_NR, - BLIS_NO_INVERT_DIAG, - BLIS_PACKED_COL_PANELS, - BLIS_BUFFER_FOR_B_PANEL, - &b, &bp, - cntx - ); -#endif - - // Create the packed objects. Use packmr and packnr as the leading - // dimensions of ap and bp, respectively. - bli_obj_create( datatype, m, m, 1, ldap, &ap ); - bli_obj_create( datatype, m, n, ldbp, 1, &bp ); - - // Set up the objects for packing. Calling packm_init_pack() does everything - // except checkout a memory pool block and save its address to the obj_t's. - // However, it does overwrite the buffer field of packed object with that of - // the source object. So, we have to save the buffer address that was - // allocated. - void* buf_ap = bli_obj_buffer( &ap ); - void* buf_bp = bli_obj_buffer( &bp ); - bli_packm_init_pack( BLIS_INVERT_DIAG, BLIS_PACKED_ROW_PANELS, - BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, - BLIS_MR, BLIS_KR, &a, &ap, cntx ); - bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_PACKED_COL_PANELS, - BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, - BLIS_KR, BLIS_NR, &b, &bp, cntx ); - bli_obj_set_buffer( buf_ap, &ap ); - bli_obj_set_buffer( buf_bp, &bp ); // Set the diagonal offset of ap. bli_obj_set_diag_offset( 0, &ap ); @@ -292,24 +258,35 @@ void libblis_test_trsm_ukr_experiment // know which set of micro-kernels (lower or upper) to choose from. bli_obj_set_uplo( uploa, &ap ); - // Pack the data from the source objects. - bli_packm_blk_var1( &a, &ap, cntx, NULL, &BLIS_PACKM_SINGLE_THREADED ); - bli_packm_blk_var1( &b, &bp, cntx, NULL, &BLIS_PACKM_SINGLE_THREADED ); - #if 0 bli_printm( "a", &a, "%5.2f", "" ); bli_printm( "ap", &ap, "%5.2f", "" ); #endif - // Repeat the experiment n_repeats times and record results. + // Repeat the experiment n_repeats times and record results. for ( i = 0; i < n_repeats; ++i ) { - // Re-pack the contents of b to bp. - //bli_packm_blk_var1( &b, &bp, cntx, cntl_b, &BLIS_PACKM_SINGLE_THREADED ); - bli_packm_blk_var1( &b, &bp, cntx, NULL, &BLIS_PACKM_SINGLE_THREADED ); - bli_copym( &c_save, &c ); + // Transpose B to B^T for packing. + bli_obj_induce_trans( &b ); + + cntl_t* cntl_b = libblis_test_pobj_create + ( + BLIS_NR, + BLIS_MR, + BLIS_NO_INVERT_DIAG, + BLIS_PACKED_COL_PANELS, + BLIS_BUFFER_FOR_B_PANEL, + &b, &bp, + cntx, + &rntm + ); + + // Transpose B^T back to B and Bp^T back to Bp. + bli_obj_induce_trans( &b ); + bli_obj_induce_trans( &bp ); + time = bli_clock(); libblis_test_trsm_ukr_impl( iface, side, @@ -317,6 +294,10 @@ bli_printm( "ap", &ap, "%5.2f", "" ); cntx ); time_min = bli_clock_min_diff( time_min, time ); + + // Free the control tree nodes and release their cached mem_t entries + // back to the memory broker. + bli_cntl_free( &rntm, cntl_b, &BLIS_PACKM_SINGLE_THREADED ); } // Estimate the performance of the best experiment repeat. @@ -329,12 +310,9 @@ bli_printm( "ap", &ap, "%5.2f", "" ); // Zero out performance and residual if output matrix is empty. //libblis_test_check_empty_problem( &c, perf, resid ); -#if 0 // Free the control tree nodes and release their cached mem_t entries // back to the memory broker. - bli_cntl_free( NULL, cntl_a, &BLIS_PACKM_SINGLE_THREADED ); - bli_cntl_free( NULL, cntl_b, &BLIS_PACKM_SINGLE_THREADED ); -#endif + bli_cntl_free( &rntm, cntl_a, &BLIS_PACKM_SINGLE_THREADED ); // Free the test objects. bli_obj_free( &a ); @@ -401,7 +379,7 @@ void libblis_test_trsm_ukr_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_trsm_ukr.h b/testsuite/src/test_trsm_ukr.h index 71685bb2a0..22c6676368 100644 --- a/testsuite/src/test_trsm_ukr.h +++ b/testsuite/src/test_trsm_ukr.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trsv.c b/testsuite/src/test_trsv.c index cb3138c920..12543cd9a0 100644 --- a/testsuite/src/test_trsv.c +++ b/testsuite/src/test_trsv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -305,7 +305,7 @@ void libblis_test_trsv_check // // is functioning correctly if // - // normf( y - x_orig ) + // normfv( y - x_orig ) // // is negligible, where // diff --git a/testsuite/src/test_trsv.h b/testsuite/src/test_trsv.h index b2e85469dc..5f5fa4eb0f 100644 --- a/testsuite/src/test_trsv.h +++ b/testsuite/src/test_trsv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_xpbym.c b/testsuite/src/test_xpbym.c index b7acc654ef..2340b4e11f 100644 --- a/testsuite/src/test_xpbym.c +++ b/testsuite/src/test_xpbym.c @@ -288,7 +288,7 @@ void libblis_test_xpbym_check // // is functioning correctly if // - // normf( y - ( beta * y_orig + conjx(x) ) ) + // normfm( y - ( beta * y_orig + conjx(x) ) ) // // is negligible. // diff --git a/testsuite/src/test_xpbyv.c b/testsuite/src/test_xpbyv.c index fa0abdb828..197de86e71 100644 --- a/testsuite/src/test_xpbyv.c +++ b/testsuite/src/test_xpbyv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -283,7 +283,7 @@ void libblis_test_xpbyv_check // // is functioning correctly if // - // normf( y - ( beta * y_orig + conjx(x) ) ) + // normfv( y - ( beta * y_orig + conjx(x) ) ) // // is negligible. // diff --git a/testsuite/src/test_xpbyv.h b/testsuite/src/test_xpbyv.h index 3b2e7bee2b..16eb772164 100644 --- a/testsuite/src/test_xpbyv.h +++ b/testsuite/src/test_xpbyv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/travis/cpuid/zen2.def b/travis/cpuid/zen2.def new file mode 100644 index 0000000000..1e2cc63906 --- /dev/null +++ b/travis/cpuid/zen2.def @@ -0,0 +1,87 @@ +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2018, The University of Texas at Austin +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# CPU: AMD EPYC 7742 +# NOTE: This file was copied from zen.def and then the appropriate bits +# in the first field (eax) of leaf 1 were updated to reflect the Zen2 +# "Rome" processor. See [1] for details. +# [1] https://en.wikichip.org/wiki/amd/cpuid +# +00000000 ******** => 0000000D 68747541 444D4163 69746E65 +00000001 ******** => 00830F12 00400800 7ED8320B 178BFBFF +00000002 ******** => 00000000 00000000 00000000 00000000 +00000003 ******** => 00000000 00000000 00000000 00000000 +00000005 ******** => 00000040 00000040 00000003 00000011 +00000006 ******** => 00000004 00000000 00000001 00000000 +00000007 ******** => 00000000 209C01A9 00000000 00000000 +00000008 ******** => 00000000 00000000 00000000 00000000 +00000009 ******** => 00000000 00000000 00000000 00000000 +0000000A ******** => 00000000 00000000 00000000 00000000 +0000000C ******** => 00000000 00000000 00000000 00000000 +0000000D 00000000 => 00000007 00000340 00000340 00000000 +0000000D 00000001 => 0000000F 00000340 00000000 00000000 +0000000D 00000002 => 00000100 00000240 00000000 00000000 +80000000 ******** => 8000001F 68747541 444D4163 69746E65 +80000001 ******** => 00800F12 40000000 35C233FF 2FD3FBFF +80000002 ******** => 20444D41 43595045 35353720 33205031 +80000003 ******** => 6F432D32 50206572 65636F72 726F7373 +80000004 ******** => 20202020 20202020 20202020 00202020 +80000005 ******** => FF40FF40 FF40FF40 20080140 40040140 +80000006 ******** => 36006400 56006400 02006140 0200C140 +80000007 ******** => 00000000 0000001B 00000000 00006799 +80000008 ******** => 00003030 00000007 0000603F 00000000 +80000009 ******** => 00000000 00000000 00000000 00000000 +8000000A ******** => 00000001 00008000 00000000 0001BCFF +8000000B ******** => 00000000 00000000 00000000 00000000 +8000000C ******** => 00000000 00000000 00000000 00000000 +8000000D ******** => 00000000 00000000 00000000 00000000 +8000000E ******** => 00000000 00000000 00000000 00000000 +8000000F ******** => 00000000 00000000 00000000 00000000 +80000010 ******** => 00000000 00000000 00000000 00000000 +80000011 ******** => 00000000 00000000 00000000 00000000 +80000012 ******** => 00000000 00000000 00000000 00000000 +80000013 ******** => 00000000 00000000 00000000 00000000 +80000014 ******** => 00000000 00000000 00000000 00000000 +80000015 ******** => 00000000 00000000 00000000 00000000 +80000016 ******** => 00000000 00000000 00000000 00000000 +80000017 ******** => 00000000 00000000 00000000 00000000 +80000018 ******** => 00000000 00000000 00000000 00000000 +80000019 ******** => F040F040 00000000 00000000 00000000 +8000001A ******** => 00000003 00000000 00000000 00000000 +8000001B ******** => 000003FF 00000000 00000000 00000000 +8000001C ******** => 00000000 00000000 00000000 00000000 +8000001D 00000000 => 00004121 01C0003F 0000003F 00000000 +8000001D 00000001 => 00004122 00C0003F 000000FF 00000000 +8000001D 00000002 => 00004143 01C0003F 000003FF 00000002 +8000001D 00000003 => 0001C163 03C0003F 00001FFF 00000001 +8000001E ******** => 00000000 00000100 00000300 00000000 +8000001F ******** => 0000000F 0000016F 0000000F 00000001 +8FFFFFFF ******** => 00000000 00000000 00000000 00000000 diff --git a/travis/cpuid/zen3.def b/travis/cpuid/zen3.def new file mode 100644 index 0000000000..ed791813ea --- /dev/null +++ b/travis/cpuid/zen3.def @@ -0,0 +1,87 @@ +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2018, The University of Texas at Austin +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# CPU: AMD EPYC 7xxx +# NOTE: This file was copied from zen.def and then the appropriate bits +# in the first field (eax) of leaf 1 were updated to reflect the Zen3 +# "Milan" processor. See [1] for details. +# [1] https://en.wikichip.org/wiki/amd/cpuid +# +00000000 ******** => 0000000D 68747541 444D4163 69746E65 +00000001 ******** => 00A00F12 00400800 7ED8320B 178BFBFF +00000002 ******** => 00000000 00000000 00000000 00000000 +00000003 ******** => 00000000 00000000 00000000 00000000 +00000005 ******** => 00000040 00000040 00000003 00000011 +00000006 ******** => 00000004 00000000 00000001 00000000 +00000007 ******** => 00000000 209C01A9 00000000 00000000 +00000008 ******** => 00000000 00000000 00000000 00000000 +00000009 ******** => 00000000 00000000 00000000 00000000 +0000000A ******** => 00000000 00000000 00000000 00000000 +0000000C ******** => 00000000 00000000 00000000 00000000 +0000000D 00000000 => 00000007 00000340 00000340 00000000 +0000000D 00000001 => 0000000F 00000340 00000000 00000000 +0000000D 00000002 => 00000100 00000240 00000000 00000000 +80000000 ******** => 8000001F 68747541 444D4163 69746E65 +80000001 ******** => 00800F12 40000000 35C233FF 2FD3FBFF +80000002 ******** => 20444D41 43595045 35353720 33205031 +80000003 ******** => 6F432D32 50206572 65636F72 726F7373 +80000004 ******** => 20202020 20202020 20202020 00202020 +80000005 ******** => FF40FF40 FF40FF40 20080140 40040140 +80000006 ******** => 36006400 56006400 02006140 0200C140 +80000007 ******** => 00000000 0000001B 00000000 00006799 +80000008 ******** => 00003030 00000007 0000603F 00000000 +80000009 ******** => 00000000 00000000 00000000 00000000 +8000000A ******** => 00000001 00008000 00000000 0001BCFF +8000000B ******** => 00000000 00000000 00000000 00000000 +8000000C ******** => 00000000 00000000 00000000 00000000 +8000000D ******** => 00000000 00000000 00000000 00000000 +8000000E ******** => 00000000 00000000 00000000 00000000 +8000000F ******** => 00000000 00000000 00000000 00000000 +80000010 ******** => 00000000 00000000 00000000 00000000 +80000011 ******** => 00000000 00000000 00000000 00000000 +80000012 ******** => 00000000 00000000 00000000 00000000 +80000013 ******** => 00000000 00000000 00000000 00000000 +80000014 ******** => 00000000 00000000 00000000 00000000 +80000015 ******** => 00000000 00000000 00000000 00000000 +80000016 ******** => 00000000 00000000 00000000 00000000 +80000017 ******** => 00000000 00000000 00000000 00000000 +80000018 ******** => 00000000 00000000 00000000 00000000 +80000019 ******** => F040F040 00000000 00000000 00000000 +8000001A ******** => 00000003 00000000 00000000 00000000 +8000001B ******** => 000003FF 00000000 00000000 00000000 +8000001C ******** => 00000000 00000000 00000000 00000000 +8000001D 00000000 => 00004121 01C0003F 0000003F 00000000 +8000001D 00000001 => 00004122 00C0003F 000000FF 00000000 +8000001D 00000002 => 00004143 01C0003F 000003FF 00000002 +8000001D 00000003 => 0001C163 03C0003F 00001FFF 00000001 +8000001E ******** => 00000000 00000100 00000300 00000000 +8000001F ******** => 0000000F 0000016F 0000000F 00000001 +8FFFFFFF ******** => 00000000 00000000 00000000 00000000 diff --git a/travis/cxx/Makefile b/travis/cxx/Makefile new file mode 100644 index 0000000000..0f8da14e3b --- /dev/null +++ b/travis/cxx/Makefile @@ -0,0 +1,38 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2021, Southern Methodist University +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +.PHONY: all cxx-test + +all: cxx-test + $(CXX) -std=c++0x -o $(BUILD_DIR)/cxx-test.x -I$(INCLUDE_DIR) cxx-test.cxx -L$(LIB_DIR) -lblis diff --git a/travis/cxx/cxx-test.cxx b/travis/cxx/cxx-test.cxx new file mode 100644 index 0000000000..bccbd9e430 --- /dev/null +++ b/travis/cxx/cxx-test.cxx @@ -0,0 +1,50 @@ +// +// +// BLIS +// An object-based framework for developing high-performance BLAS-like +// libraries. +// +// Copyright (C) 2021, Southern Methodist University +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// - Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// - Neither the name(s) of the copyright holder(s) nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// + +#include + +#include "blis.h" + +int main() +{ + const int N = 5; + std::vector A(N*N), B(N*N), C(N*N); + scomplex one{1.0, 0.0}; + scomplex zero{0.0, 0.0}; + + bli_cgemm(BLIS_NO_TRANSPOSE, BLIS_NO_TRANSPOSE, N, N, N, + &one, A.data(), 1, N, + B.data(), 1, N, + &zero, C.data(), 1, N); +} diff --git a/travis/cxx/cxx-test.sh b/travis/cxx/cxx-test.sh new file mode 100755 index 0000000000..c0036611f4 --- /dev/null +++ b/travis/cxx/cxx-test.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2021, Southern Methodist University +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +SOURCE_DIR=$1 +CONFIG=$2 + +if [ -z $SOURCE_DIR ] || [ -z $CONFIG ]; then + echo "usage: cxx-test.sh " + exit 1 +fi + +BUILD_DIR=$(pwd) +INCLUDE_DIR=$BUILD_DIR/include/$CONFIG +LIB_DIR=$BUILD_DIR/lib/$CONFIG + +if [ ! -e $INCLUDE_DIR/blis.h ]; then + echo "could not find blis.h" + exit 1 +fi + +if [ ! -e $SOURCE_DIR/travis/cxx/Makefile ]; then + echo "could not find cxx-test Makefile" + exit 1 +fi + +make -C $SOURCE_DIR/travis/cxx INCLUDE_DIR=$INCLUDE_DIR LIB_DIR=$LIB_DIR BUILD_DIR=$BUILD_DIR diff --git a/travis/do_sde.sh b/travis/do_sde.sh index 6ec9febe5d..de1545886d 100755 --- a/travis/do_sde.sh +++ b/travis/do_sde.sh @@ -3,13 +3,30 @@ set -e set -x -SDE_VERSION=sde-external-8.16.0-2018-01-30-lin +SDE_VERSION=sde-external-8.69.1-2021-07-18-lin SDE_TARBALL=$SDE_VERSION.tar.bz2 SDE=$SDE_VERSION/sde64 -set +x -curl -s -X POST https://content.dropboxapi.com/2/files/download -H "Authorization: Bearer $DROPBOX_TOKEN" -H "Dropbox-API-Arg: {\"path\": \"/$SDE_TARBALL\"}" > $SDE_TARBALL -set -x +# +# This doesn't seem to be necessary anymore +# +#curl --verbose --form accept_license=1 --form form_id=intel_licensed_dls_step_1 \ +# --output /dev/null --cookie-jar jar.txt \ +# --location https://software.intel.com/protected-download/267266/144917 +#curl --verbose --cookie jar.txt --output $SDE_TARBALL \ +# https://software.intel.com/system/files/managed/2a/1a/$SDE_TARBALL + +#curl --verbose --output $SDE_TARBALL \ +# https://software.intel.com/content/dam/develop/external/us/en/documents/downloads/$SDE_TARBALL + +CI_UTILS=ci-utils +CI_UTILS_URL=https://github.com/flame/${CI_UTILS}.git +CI_UTILS_SDE_DIR=sde +SDE_DIRPATH=$CI_UTILS/$CI_UTILS_SDE_DIR + +git clone $CI_UTILS_URL +mv $SDE_DIRPATH/$SDE_TARBALL . + tar xvf $SDE_TARBALL make -j2 testsuite-bin @@ -28,7 +45,8 @@ for LIB in $LD_SO $LIBC_SO $LIBM_SO; do sudo mv .tmp $LIB done -for ARCH in penryn sandybridge haswell skx knl piledriver steamroller excavator zen; do +#for ARCH in penryn sandybridge haswell skx knl piledriver steamroller excavator zen; do +for ARCH in penryn sandybridge haswell skx knl zen zen2 zen3; do if [ "$ARCH" = "knl" ]; then $SDE -knl -- ./test_libblis.x > output.testsuite else diff --git a/travis/do_testsuite.sh b/travis/do_testsuite.sh index bb176b6819..6778f81d85 100755 --- a/travis/do_testsuite.sh +++ b/travis/do_testsuite.sh @@ -8,19 +8,28 @@ export BLIS_IC_NT=2 export BLIS_JR_NT=1 export BLIS_IR_NT=1 -if [ "$TEST" = "FAST" ]; then +if [ "$TEST" = "FAST" -o "$TEST" = "ALL" ]; then make testblis-fast -elif [ "$TEST" = "MD" ]; then + $DIST_PATH/testsuite/check-blistest.sh ./output.testsuite +fi + +if [ "$TEST" = "MD" -o "$TEST" = "ALL" ]; then make testblis-md -elif [ "$TEST" = "SALT" ]; then + $DIST_PATH/testsuite/check-blistest.sh ./output.testsuite +fi + +if [ "$TEST" = "SALT" -o "$TEST" = "ALL" ]; then # Disable multithreading within BLIS. export BLIS_JC_NT=1 BLIS_IC_NT=1 BLIS_JR_NT=1 BLIS_IR_NT=1 make testblis-salt -else + $DIST_PATH/testsuite/check-blistest.sh ./output.testsuite +fi + +if [ "$TEST" = "1" -o "$TEST" = "ALL" ]; then make testblis + $DIST_PATH/testsuite/check-blistest.sh ./output.testsuite fi -$DIST_PATH/testsuite/check-blistest.sh ./output.testsuite make testblas $DIST_PATH/blastest/check-blastest.sh diff --git a/vendor/cpp/blis.hh b/vendor/cpp/blis.hh new file mode 100644 index 0000000000..39dc258647 --- /dev/null +++ b/vendor/cpp/blis.hh @@ -0,0 +1,3820 @@ +/****************************************************************************** +* Copyright (c) 2019 - present Advanced Micro Devices, Inc. All rights reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +* THE SOFTWARE. +*******************************************************************************/ + +/*! @file blis.hh + * blis.hh defines all the BLAS CPP templated public interfaces + * */ +#ifndef BLIS_HH +#define BLIS_HH + +#include "cblas.hh" + +namespace blis { + +/*! \brief Construct plane rotation for arbitrary data types + + \b Purpose: + + ROTG construct plane rotation that eliminates b for arbitrary data types, such that \n + + [ z ] = [ c s ] [ a ] \n + [ 0 ] [ -s c ] [ b ] \n + Data precisions supported include SINGLE/DOUBLE PRECISION REAL + + \param[in, out] a + SINGLE/DOUBLE PRECISION REAL + On entry, scalar a. On exit, set to z. + + \param[in, out] b + SINGLE/DOUBLE PRECISION REAL + On entry, scalar b. On exit, set to s, 1/c, or 0. + + \param[out] c + Cosine of rotation; SINGLE/DOUBLE PRECISION REAL. + + \param[out] s + Sine of rotation; SINGLE/DOUBLE PRECISION REAL. + */ +template< typename T > +void rotg( + T *a, + T *b, + T *c, + T *s ) +{ + cblas_rotg(a, b, c, s); +} + +/*! \brief Construct the modified givens transformation matrix for arbitrary data types + + \b Purpose: + + ROTMG construct modified (fast) plane rotation, H, that eliminates b, such that \n + [ z ] = H [ sqrt(d1) 0 ] [ a ] \n + [ 0 ] [ 0 sqrt(d2) ] [ b ] \n + Data precisions supported include SINGLE/DOUBLE PRECISION REAL + + \param[in, out] d1 + SINGLE/DOUBLE PRECISION REAL + sqrt(d1) is scaling factor for vector x. + + \param[in, out] d2 + SINGLE/DOUBLE PRECISION REAL + sqrt(d2) is scaling factor for vector y. + + \param[in, out] a + On entry, scalar a. On exit, set to z. SINGLE/DOUBLE PRECISION REAL. + + \param[in, out] b + On entry, scalar b. SINGLE/DOUBLE PRECISION REAL. + + \param[out] param + SINGLE/DOUBLE PRECISION REAL array, dimension (5),giving parameters + of modified plane rotation + param(1)=DFLAG + param(2)=DH11 + param(3)=DH21 + param(4)=DH12 + param(5)=DH22 + */ +template< typename T > +void rotmg( + T *d1, + T *d2, + T *a, + T b, + T param[5] ) +{ + cblas_rotmg(d1, d2, a, b, param ); +} + +/*! \brief Apply plane rotation for arbitrary data types + + \b Purpose: + + ROT applies a plane rotation: \n + [ x^T ] [ c s ] [ x^T ] \n + [ y^T ] = [ -s c ] [ y^T ] \n + Data precisions supported include SINGLE/DOUBLE PRECISION REAL + + \param[in] n + Number of elements in x and y. n >= 0. + + \param[in, out] x + SINGLE/DOUBLE PRECISION REAL array + The n-element vector x, in an array of length (n-1)*abs(incx) + 1. + + \param[in] incx + incx is INTEGER + Stride between elements of x. incx must not be zero. + If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). + + \param[in, out] y + SINGLE/DOUBLE PRECISION REAL array + The n-element vector y, in an array of length (n-1)*abs(incy) + 1. + + \param[in] incy + incy is INTEGER + Stride between elements of y. incy must not be zero. + If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). + + \param[in] c + Cosine of rotation; SINGLE/DOUBLE PRECISION REAL. + + \param[in] s + Sine of rotation; SINGLE/DOUBLE PRECISION REAL. + */ +template< typename T > +void rot( + int64_t n, + T *x, int64_t incx, + T *y, int64_t incy, + T c, + T s ) +{ + cblas_rot( n, x, incx, y, incy, c, s ); +} + +/*! \brief Apply the modified givens transformation for arbitrary data types + + \b Purpose: + + ROTM applies modified (fast) plane rotation, H: \n + [ x^T ] = H [ x^T ] \n + [ y^T ] [ y^T ] \n + + Data precisions supported include SINGLE/DOUBLE PRECISION REAL + + \param[in] n + Number of elements in x and y. n >= 0. + + \param[in, out] x + SINGLE/DOUBLE PRECISION REAL array + The n-element vector x, in an array of length (n-1)*abs(incx) + 1. + + \param[in] incx + incx is INTEGER + Stride between elements of x. incx must not be zero. + If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). + + \param[in, out] y + SINGLE/DOUBLE PRECISION REAL array + The n-element vector y, in an array of length (n-1)*abs(incy) + 1. + + \param[in] incy + incy is INTEGER + Stride between elements of y. incy must not be zero. + If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). + + \param[in] P + SINGLE/DOUBLE PRECISION REAL array, dimension (5),giving parameters + of modified plane rotation + param(1)=DFLAG + param(2)=DH11 + param(3)=DH21 + param(4)=DH12 + param(5)=DH22 + */ +template< typename T > +void rotm( + int64_t n, + T *x, int64_t incx, + T *y, int64_t incy, + const T *P) +{ + cblas_rotm( n, x, incx, y, incy, P ); +} + +/*! \brief Interchanges two vectors of arbitrary data types + + \b Purpose: + + SWAP interchanges two vectors uses unrolled loops for increments equal to 1.\n + x <=> y \n + Data precisions supported include SINGLE/DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + \param[in] n + n is INTEGER + Number of elements in x and y. n >= 0. + + \param[in] x + REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array. + The n-element vector x, in an array of length (n-1)*abs(incx) + 1. + + \param[in] incx + incx is INTEGER. + Stride between elements of x. incx must not be zero. + If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). + + \param[in, out] y + REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array. + The n-element vector y, in an array of length (n-1)*abs(incy) + 1. + + \param[in] incy + incy is INTEGER. + Stride between elements of y. incy must not be zero. + If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). + */ +template< typename T > +void swap( + int64_t n, + T *x, int64_t incx, + T *y, int64_t incy ) +{ + cblas_swap( n, x, incx, y, incy ); +} + +/*! \brief Scales a vector of arbitrary data types by a constant. + + \b Purpose: + + SCAL scales a vector by a constant, uses unrolled loops for increment equal to 1.\n + x = alpha * x \n + Data precisions of vector & constant include SINGLE/DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + \param[in] n + n is INTEGER + Number of elements in x. n >= 0. + + \param[in] alpha + alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha. + + \param[in ,out] x + REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array + The n-element vector x, in an array of length (n-1)*abs(incx) + 1. + + \param[in] incx + incx is INTEGER + Stride between elements of x. incx must not be zero. + If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). + */ +template< typename TA, typename TB > +void scal( + int64_t n, + TA alpha, + TB* x, int64_t incx ) +{ + cblas_scal( n, alpha, x, incx ); +} + +/*! \brief Copies a vector x to a vector y for arbitrary data types + + \b Purpose: + + COPY copies a vector x to a vector y.\n + y = x \n + Data precisions supported include SINGLE/DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + \param[in] n + n is INTEGER + Number of elements in x and y. n >= 0. + + \param[in] x + REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array. + The n-element vector x, in an array of length (n-1)*abs(incx) + 1. + + \param[in] incx + incx is INTEGER. + Stride between elements of x. incx must not be zero. + If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). + + \param[out] y + REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array. + The n-element vector y, in an array of length (n-1)*abs(incy) + 1. + + \param[in] incy + incy is INTEGER. + Stride between elements of y. incy must not be zero. + If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). + */ +template< typename T > +void copy( + int64_t n, + T const *x, int64_t incx, + T *y, int64_t incy ) +{ + cblas_copy( n, x, incx, y, incy ); +} + +/*! \brief Performs addition of scaled vector for arbitrary data types + + \b Purpose: + + AXPY constant times a vector plus a vector.\n + y = alpha*x + y \n + Data precisions supported include SINGLE/DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + \param[in] n + n is INTEGER + Number of elements in x and y. n >= 0. + + \param[in] alpha + alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha.\n + If alpha is zero, y is not updated. + + \param[in] x + REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array. + The n-element vector x, in an array of length (n-1)*abs(incx) + 1. + + \param[in] incx + incx is INTEGER. + Stride between elements of x. incx must not be zero. + If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). + + \param[out] y + REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array. + The n-element vector y, in an array of length (n-1)*abs(incy) + 1. + + \param[in] incy + incy is INTEGER. + Stride between elements of y. incy must not be zero. + If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). + */ +template< typename T > +void axpy( + int64_t n, + T alpha, + T const *x, int64_t incx, + T *y, int64_t incy ) +{ + cblas_axpy( n, alpha, x, incx, y, incy ); +} + +/*! \brief Performs the dot product of two vectors for arbitrary data types + + \b Purpose: + + DOT forms the dot product of two vectors + uses unrolled loops for increments equal to one.\n + dot = x^T * y \n + Data precisions supported include SINGLE/DOUBLE PRECISION REAL + + \param[in] n + n is INTEGER + Number of elements in x and y. n >= 0. + + \param[in] x + REAL/DOUBLE PRECISION array. + The n-element vector x, in an array of length (n-1)*abs(incx) + 1. + + \param[in] incx + incx is INTEGER. + Stride between elements of x. incx must not be zero. + If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). + + \param[in] y + REAL/DOUBLE PRECISION array. + The n-element vector y, in an array of length (n-1)*abs(incy) + 1. + + \param[in] incy + incy is INTEGER. + Stride between elements of y. incy must not be zero. + If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). + + \return Unconjugated dot product, x^T * y. + REAL/DOUBLE PRECISION + */ +template< typename T, typename TR > +TR dot( + int64_t n, + T const *x, int64_t incx, + T const *y, int64_t incy ) +{ + return cblas_dot( n, x, incx, y, incy ); +} + +/*! \brief Performs the dot product of two complex vectors + + \b Purpose: + + DOTU forms the dot product of two complex vectors. \n + CDOTU = X^T * Y \n + Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX + + \param[in] n + n is INTEGER + Number of elements in x and y. n >= 0. + + \param[in] x + REAL/DOUBLE PRECISION COMPLEX array. + The n-element vector x, in an array of length (n-1)*abs(incx) + 1. + + \param[in] incx + incx is INTEGER. + Stride between elements of x. incx must not be zero. + If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). + + \param[in] y + REAL/DOUBLE PRECISION COMPLEX array. + The n-element vector y, in an array of length (n-1)*abs(incy) + 1. + + \param[in] incy + incy is INTEGER. + Stride between elements of y. incy must not be zero. + If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). + + \return Unconjugated dot product, x^T * y. + REAL/DOUBLE PRECISION COMPLEX + */ +template< typename T > +T dotu( + int64_t n, + T const *x, int64_t incx, + T const *y, int64_t incy ) +{ + return cblas_dotu( n, x, incx, y, incy ); +} + +/*! \brief Performs the dot product of two complex vectors + + \b Purpose: + + DOTC forms the dot product of two complex vectors. \n + CDOTU = X^H * Y \n + Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX + + \param[in] n + n is INTEGER + Number of elements in x and y. n >= 0. + + \param[in] x + REAL/DOUBLE PRECISION COMPLEX array. + The n-element vector x, in an array of length (n-1)*abs(incx) + 1. + + \param[in] incx + incx is INTEGER. + Stride between elements of x. incx must not be zero. + If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). + + \param[in] y + REAL/DOUBLE PRECISION COMPLEX array. + The n-element vector y, in an array of length (n-1)*abs(incy) + 1. + + \param[in] incy + incy is INTEGER. + Stride between elements of y. incy must not be zero. + If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). + + \return Conjugated dot product, x^H * y. + REAL/DOUBLE PRECISION COMPLEX + */ +template< typename T > +T dotc( + int64_t n, + T const *x, int64_t incx, + T const *y, int64_t incy ) +{ + return cblas_dotc( n, x, incx, y, incy ); +} + +/*! \brief Performs inner product of two vectors with extended precision accumulation + + \b Purpose: + + DOTC forms the inner product of two vectors with extended precision accumulation. \n + Data precisions supported include SINGLE PRECISION REAL + + \param[in] n + n is INTEGER\n + number of elements in input vector(s) + + \param[in] alpha + alpha is REAL\n + single precision scalar to be added to inner product + + \param[in] x + x is REAL array, dimension ( 1 + ( n - 1 )*abs( incx ) )\n + single precision vector with n elements + + \param[in] incx + incx is INTEGER\n + storage spacing between elements of x + + \param[in] y + y is REAL array, dimension ( 1 + ( n - 1 )*abs( incx ) )\n + single precision vector with n elements + + \param[in] incy + incy is INTEGER\n + storage spacing between elements of y + + \return S.P. result with dot product accumulated in D.P. + */ +template< typename T > +T sdsdot( + int64_t n, + T alpha, + T const *x, int64_t incx, + T const *y, int64_t incy ) +{ + return cblas_sdsdot( n, alpha, x, incx, y, incy ); +} + +/*! \brief return 2-norm of vectors of arbitrary data types + + \b Purpose: + + NRM2 returns the euclidean norm of a vector via the function name, so that + SNRM2 := sqrt( x'*x ). \n + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + \param[in] n + n is INTEGER\n + number of elements in input vector(s) + + \param[in] x + x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, + dimension ( 1 + ( n - 1 )*abs( incx ) )\n + single precision vector with n elements + + \param[in] incx + incx is INTEGER\n + storage spacing between elements of x + + \return 2-norm of vector + REAL SINGLE/DOUBLE PRECISION + */ +template< typename T > +real_type +nrm2( + int64_t n, + T const * x, int64_t incx ) +{ + return cblas_nrm2( n, x, incx ); +} + +/*! \brief return 1-norm of vector of arbitrary data types + + \b Purpose: + + ASUM takes the sum of the absolute values, uses unrolled loops for + increment equal to one. \n + ASUM := || Re(x) ||_1 + || Im(x) ||_1. \n + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + \param[in] n + n is INTEGER\n + number of elements in input vector(s) + + \param[in] x + x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, + dimension ( 1 + ( n - 1 )*abs( incx ) )\n + single precision vector with n elements + + \param[in] incx + incx is INTEGER\n + storage spacing between elements of x + + \return 1-norm of vector + REAL SINGLE/DOUBLE PRECISION + */ +template< typename T > +real_type +asum( + int64_t n, + T const *x, int64_t incx ) +{ + return cblas_asum( n, x, incx ); +} + +/*! \brief Return Index of infinity-norm of vectors of arbitrary types. + + \b Purpose: + + IAMAX finds the index of the first element having maximum |Re(.)| + |Im(.)|. \n + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + \param[in] n + n is INTEGER\n + number of elements in input vector(s) + + \param[in] x + x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, + dimension ( 1 + ( n - 1 )*abs( incx ) ) \n + single precision vector with n elements + + \param[in] incx + incx is INTEGER\n + storage spacing between elements of x + + \return Index of infinity-norm of vector + INTEGER + */ +template< typename T > +int64_t iamax( + int64_t n, + T const *x, int64_t incx ) +{ + return cblas_iamax( n, x, incx ); +} + +/*! \brief Solve General matrix-vector multiply for arbitrary data types + + \b Purpose: + + GEMV performs one of the matrix-vector operations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + y := alpha*A*x + beta*y, or y := alpha*A**T*x + beta*y, + + where alpha and beta are scalars, x and y are vectors and A is an + m by n matrix. + + \param[in] layout + layout is enum CBLAS_LAYOUT + layout specifies Matrix storage as follows: + layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the operation to be used as follows: \n + trans = CBLAS_TRANSPOSE::CblasNoTrans,y := alpha*A*x + beta*y. \n + trans = CBLAS_TRANSPOSE::CblasTrans, y := alpha*A**T*x + beta*y. \n + trans = CBLAS_TRANSPOSE::CblasConjTrans, y := alpha*A**T*x + beta*y. + + \param[in] m + m is INTEGER + On entry, m specifies the number of rows of the matrix A. + m must be at least zero. + + \param[in] n + n is INTEGER + On entry, n specifies the number of columns of the matrix A. + n must be at least zero. + + \param[in] alpha + alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + m-by-n , stored in an lda-by-n array [RowMajor: m-by-lda]. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda >= max(1, m) [RowMajor: lda >= max(1, n)]. + + \param[in] x + x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : \n + If trans = CblasNoTrans: + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Otherwise: + at least ( 1 + ( m - 1 )*abs( incx ) ). + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] beta + beta is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, beta specifies the scalar alpha.When beta is + supplied as zero then y need not be set on input. + + \param[in,out] y + y is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, dimension : \n + If trans = CblasNoTrans: + at least ( 1 + ( m - 1 )*abs( incy ) ). \n + Otherwise: + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry with beta non-zero, the incremented array y + must contain the vector y. On exit, y is overwritten by the + updated vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + */ +template< typename T > +void gemv( + CBLAS_ORDER layout, + CBLAS_TRANSPOSE trans, + int64_t m, int64_t n, + T alpha, + T const *A, int64_t lda, + T const *x, int64_t incx, + T beta, + T *y, int64_t incy ) +{ + cblas_gemv(layout, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +/*! \brief Solve General matrix-vector multiply for arbitrary data types + + \b Purpose: + + GBMV performs one of the matrix-vector operations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + y := alpha*A*x + beta*y, or y := alpha*A**T*x + beta*y, or + + y := alpha*A**H*x + beta*y, + + where alpha and beta are scalars, x and y are vectors and A is an + m by n matrix with kl sub-diagonals and ku super-diagonals. + + \param[in] layout + layout is enum CBLAS_LAYOUT + layout specifies Matrix storage as follows: + layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the operation to be used as follows: \n + trans = CBLAS_TRANSPOSE::CblasNoTrans,y := alpha*A*x + beta*y. \n + trans = CBLAS_TRANSPOSE::CblasTrans, y := alpha*A**T*x + beta*y. \n + trans = CBLAS_TRANSPOSE::CblasConjTrans, y := alpha*A**H*x + beta*y. + + \param[in] m + m is INTEGER + On entry, m specifies the number of rows of the matrix A. + m must be at least zero. + + \param[in] n + n is INTEGER + On entry, n specifies the number of columns of the matrix A. + n must be at least zero. + + \param[in] kl + kl is INTEGER + On entry, kl specifies the number of sub-diagonals of the matrix A. + kl must be at least zero. + + \param[in] ku + ku is INTEGER + On entry, ku specifies the number of super-diagonals of the matrix A. + ku must be at least zero. + + \param[in] alpha + alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension lda-by-n. + Before entry, the leading ( kl + ku + 1 ) by n part of the + array A must contain the matrix of coefficients, supplied + column by column, with the leading diagonal of the matrix in + row ( ku + 1 ) of the array, the first super-diagonal + starting at position 2 in row ku, the first sub-diagonal + starting at position 1 in row ( ku + 2 ), and so on. + Elements in the array A that do not correspond to elements + in the band matrix (such as the top left ku by ku triangle) + are not referenced. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda >= ( kl + ku + 1 ) + + \param[in] x + x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : \n + If trans = CblasNoTrans: + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Otherwise: + at least ( 1 + ( m - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the + vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] beta + beta is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, beta specifies the scalar alpha.When beta is + supplied as zero then y need not be set on input. + + \param[in,out] y + y is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, dimension : \n + If trans = CblasNoTrans: + at least ( 1 + ( m - 1 )*abs( incy ) ). \n + Otherwise: + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry with beta non-zero, the incremented array y + must contain the vector y. On exit, y is overwritten by the + updated vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + */ +template< typename T > +void gbmv( + CBLAS_ORDER layout, + CBLAS_TRANSPOSE trans, + int64_t m, int64_t n, + int64_t kl, int64_t ku, + T alpha, + T const *A, int64_t lda, + T const *x, int64_t incx, + T beta, + T *y, int64_t incy ) +{ + cblas_gbmv(layout, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); +} + +/*! \brief Solves Hermitian matrix-vector multiply for arbitrary data types + + \b Purpose: + + HEMV performs one of the matrix-vector operations for arbitrary data types + Data precisions supported include SINGLE PRECISION COMPLEX, + DOUBLE PRECISION COMPLEX(COMPLEX*16) + + y := alpha*A*x + beta*y, + + where alpha and beta are scalars, x and y are n element vectors and + A is an n by n hermitian matrix. + + \param[in] layout + layout is enum CBLAS_LAYOUT + layout specifies Matrix storage as follows: + layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies specifies whether the matrix A is an upper or + lower triangular matrix as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A.n must be at least zero. + + \param[in] alpha + alpha is COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is COMPLEX/COMPLEX*16 array,dimension lda-by-n. \n + Before entry with UPLO = CblasUpper, the leading n by n + upper triangular part of the array A must contain the upper + triangular part of the hermitian matrix and the strictly + lower triangular part of A is not referenced. + Before entry with UPLO = CblasLower, the leading n by n + lower triangular part of the array A must contain the lower + triangular part of the hermitian matrix and the strictly + upper triangular part of A is not referenced. \n + Note that the imaginary parts of the diagonal elements need + not be set and are assumed to be zero. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least max( 1, n ). + + \param[in] x + x is COMPLEX/COMPLEX*16 array,dimension : \n + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the + vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] beta + beta is COMPLEX/COMPLEX*16 + On entry, beta specifies the scalar alpha.When beta is + supplied as zero then y need not be set on input. + + \param[in,out] y + y is COMPLEX/COMPLEX*16 array, dimension : \n + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry with beta non-zero, the incremented array y + must contain the vector y. On exit, y is overwritten by the + updated vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + */ +template< typename T > +void hemv( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, + T alpha, + T const *A, int64_t lda, + T const *x, int64_t incx, + T beta, + T *y, int64_t incy ) +{ + cblas_hemv(layout, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +/*! \brief Solves Hermitian matrix-vector multiply for arbitrary data types + + \b Purpose: + + HBMV performs one of the matrix-vector operations for arbitrary data types + Data precisions supported include SINGLE PRECISION COMPLEX, + DOUBLE PRECISION COMPLEX(COMPLEX*16) + + y := alpha*A*x + beta*y, + + where alpha and beta are scalars, x and y are n element vectors and + A is an n by n hermitian matrix with k super-diagonals. + + \param[in] layout + layout is enum CBLAS_LAYOUT + layout specifies Matrix storage as follows: + layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies specifies whether the the upper or lower triangular + part of the band matrix A is being supplied as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A.n must be at least zero. + + \param[in] k + k is INTEGER + On entry, k specifies the number of super-diagonals of the matrix A. + k must be at least zero. + + \param[in] alpha + alpha is COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is COMPLEX/COMPLEX*16 array,dimension lda-by-n. \n + Before entry with UPLO = CblasUpper, the leading ( k + 1 ) + by n part of the array A must contain the upper triangular + band part of the hermitian matrix, supplied column by + column, with the leading diagonal of the matrix in row + ( k + 1 ) of the array, the first super-diagonal starting at + position 2 in row k, and so on. The top left k by k triangle + of the array A is not referenced. \n + Before entry with UPLO = CblasLower, the leading ( k + 1 ) + by n part of the array A must contain the lower triangular + band part of the hermitian matrix, supplied column by + column, with the leading diagonal of the matrix in row 1 of + the array, the first sub-diagonal starting at position 1 in + row 2, and so on. The bottom right k by k triangle of the + array A is not referenced. \n + Note that the imaginary parts of the diagonal elements need + not be set and are assumed to be zero. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least ( k + 1 ). + + \param[in] x + x is COMPLEX/COMPLEX*16 array,dimension : \n + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the + vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] beta + beta is COMPLEX/COMPLEX*16 + On entry, beta specifies the scalar alpha. + + \param[in,out] y + y is COMPLEX/COMPLEX*16 array, dimension : \n + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry with beta non-zero, the incremented array y + must contain the vector y. On exit, y is overwritten by the + updated vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + */ +template< typename T > +void hbmv( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, int64_t k, + T alpha, + T const *A, int64_t lda, + T const *x, int64_t incx, + T beta, + T *y, int64_t incy ) +{ + cblas_hbmv(layout, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +/*! \brief Solves Hermitian matrix-vector multiply for arbitrary data types + + \b Purpose: + + HPMV performs one of the matrix-vector operations for arbitrary data types + Data precisions supported include SINGLE PRECISION COMPLEX, + DOUBLE PRECISION COMPLEX(COMPLEX*16) + + y := alpha*A*x + beta*y, + + where alpha and beta are scalars, x and y are n element vectors and + A is an n by n hermitian matrix, supplied in packed form. + + \param[in] layout + layout is enum CBLAS_LAYOUT + layout specifies Matrix storage as follows: + layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies specifies whether the the upper or lower triangular + part of the band matrix A is supplied in the packed array Ap as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A.n must be at least zero. + + \param[in] alpha + alpha is COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha. + + \param[in] Ap + Ap is COMPLEX/COMPLEX*16 array,dimension atleast ( ( n*( n + 1 ) )/2 ). \n + Before entry with UPLO = CblasUpper, the array Ap must + contain the upper triangular part of the hermitian matrix + packed sequentially, column by column, so that Ap( 1 ) + contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) + and a( 2, 2 ) respectively, and so on. \n + Before entry with UPLO = CblasLower, the array Ap must + contain the lower triangular part of the hermitian matrix + packed sequentially, column by column, so that Ap( 1 ) + contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) + and a( 3, 1 ) respectively, and so on. \n + Note that the imaginary parts of the diagonal elements need + not be set and are assumed to be zero. + + \param[in] x + x is COMPLEX/COMPLEX*16 array,dimension : \n + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the + vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] beta + beta is COMPLEX/COMPLEX*16 + On entry, beta specifies the scalar alpha.When beta is + supplied as zero then y need not be set on input. + + \param[in,out] y + y is COMPLEX/COMPLEX*16 array, dimension : \n + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry with beta non-zero, the incremented array y + must contain the vector y. On exit, y is overwritten by the + updated vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + */ +template< typename T > +void hpmv( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, + T alpha, + T const *Ap, + T const *x, int64_t incx, + T beta, + T *y, int64_t incy ) +{ + cblas_hpmv(layout, uplo, n, alpha, Ap, x, incx, beta, y, incy); +} + +/*! \brief Solves Symmetric matrix-vector multiply for arbitrary data types + + \b Purpose: + + SYMV performs one of the matrix-vector operations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL + + y := alpha*A*x + beta*y, + + where alpha and beta are scalars, x and y are n element vectors and + A is an n by n symmetric matrix. + + \param[in] layout + layout is enum CBLAS_LAYOUT + layout specifies Matrix storage as follows: + layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies specifies whether the matrix A is an upper or + lower triangular matrix as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A.n must be at least zero. + + \param[in] alpha + alpha is SINGLE/DOUBLE PRECISION REAL + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is SINGLE/DOUBLE PRECISION REAL array,dimension lda-by-n. \n + Before entry with UPLO = CblasUpper, the leading n by n + upper triangular part of the array A must contain the upper + triangular part of the symmetric matrix and the strictly + lower triangular part of A is not referenced. + Before entry with UPLO = CblasLower, the leading n by n + lower triangular part of the array A must contain the lower + triangular part of the symmetric matrix and the strictly + upper triangular part of A is not referenced. \n + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least max( 1, n ). + + \param[in] x + x is SINGLE/DOUBLE PRECISION REAL array,dimension : \n + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the + vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] beta + beta is SINGLE/DOUBLE PRECISION REAL + On entry, beta specifies the scalar alpha.When beta is + supplied as zero then y need not be set on input. + + \param[in,out] y + y is SINGLE/DOUBLE PRECISION REAL array, dimension : \n + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry with beta non-zero, the incremented array y + must contain the vector y. On exit, y is overwritten by the + updated vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + */ +template< typename T > +void symv( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, + T alpha, + T const *A, int64_t lda, + T const *x, int64_t incx, + T beta, + T *y, int64_t incy ) +{ + cblas_symv(layout, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +/*! \brief Solves symmetric matrix-vector multiply for arbitrary data types + + \b Purpose: + + SBMV performs one of the matrix-vector operations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL + + y := alpha*A*x + beta*y, + + where alpha and beta are scalars, x and y are n element vectors and + A is an n by n symmetric matrix with k super-diagonals. + + \param[in] layout + layout is enum CBLAS_LAYOUT + layout specifies Matrix storage as follows: + layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies specifies whether the the upper or lower triangular + part of the band matrix A is being supplied as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A.n must be at least zero. + + \param[in] k + k is INTEGER + On entry, k specifies the number of super-diagonals of the matrix A. + k must be at least zero. + + \param[in] alpha + alpha is SINGLE/DOUBLE PRECISION REAL + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is SINGLE/DOUBLE PRECISION REAL array,dimension lda-by-n. \n + Before entry with UPLO = CblasUpper, the leading ( k + 1 ) + by n part of the array A must contain the upper triangular + band part of the symmetric matrix, supplied column by + column, with the leading diagonal of the matrix in row + ( k + 1 ) of the array, the first super-diagonal starting at + position 2 in row k, and so on. The top left k by k triangle + of the array A is not referenced. \n + Before entry with UPLO = CblasLower, the leading ( k + 1 ) + by n part of the array A must contain the lower triangular + band part of the symmetric matrix, supplied column by + column, with the leading diagonal of the matrix in row 1 of + the array, the first sub-diagonal starting at position 1 in + row 2, and so on. The bottom right k by k triangle of the + array A is not referenced. \n + Note that the imaginary parts of the diagonal elements need + not be set and are assumed to be zero. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least ( k + 1 ). + + \param[in] x + x is SINGLE/DOUBLE PRECISION REAL array,dimension : \n + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the + vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] beta + beta is SINGLE/DOUBLE PRECISION REAL + On entry, beta specifies the scalar alpha. + + \param[in,out] y + y is SINGLE/DOUBLE PRECISION REAL array, dimension : \n + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry with beta non-zero, the incremented array y + must contain the vector y. On exit, y is overwritten by the + updated vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + */ +template< typename T > +void sbmv( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, int64_t k, + T alpha, + T const *A, int64_t lda, + T const *x, int64_t incx, + T beta, + T *y, int64_t incy ) +{ + cblas_sbmv(layout, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +/*! \brief Solves symmetric matrix-vector multiply for arbitrary data types + + \b Purpose: + + SPMV performs one of the matrix-vector operations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL + + y := alpha*A*x + beta*y, + + where alpha and beta are scalars, x and y are n element vectors and + A is an n by n symmetric matrix, supplied in packed form. + + \param[in] layout + layout is enum CBLAS_LAYOUT + layout specifies Matrix storage as follows: + layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies specifies whether the the upper or lower triangular + part of the band matrix A is supplied in the packed array Ap as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A.n must be at least zero. + + \param[in] alpha + alpha is SINGLE/DOUBLE PRECISION REAL + On entry, alpha specifies the scalar alpha. + + \param[in] Ap + Ap is SINGLE/DOUBLE PRECISION REAL array,dimension atleast ( ( n*( n + 1 ) )/2 ). \n + Before entry with UPLO = CblasUpper, the array Ap must + contain the upper triangular part of the symmetric matrix + packed sequentially, column by column, so that Ap( 1 ) + contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) + and a( 2, 2 ) respectively, and so on. \n + Before entry with UPLO = CblasLower, the array Ap must + contain the lower triangular part of the symmetric matrix + packed sequentially, column by column, so that Ap( 1 ) + contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) + and a( 3, 1 ) respectively, and so on. \n + Note that the imaginary parts of the diagonal elements need + not be set and are assumed to be zero. + + \param[in] x + x is SINGLE/DOUBLE PRECISION REAL array,dimension : \n + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the + vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] beta + beta is SINGLE/DOUBLE PRECISION REAL + On entry, beta specifies the scalar alpha.When beta is + supplied as zero then y need not be set on input. + + \param[in,out] y + y is SINGLE/DOUBLE PRECISION REAL array, dimension : \n + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry with beta non-zero, the incremented array y + must contain the vector y. On exit, y is overwritten by the + updated vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + */ +template< typename T > +void spmv( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, + T alpha, + T const *Ap, + T const *x, int64_t incx, + T beta, + T *y, int64_t incy ) +{ + cblas_spmv(layout, uplo, n, alpha, Ap, x, incx, beta, y, incy); +} + +/*! \brief Solve the one of the matrix-vector operations for arbitrary data types + + \b Purpose: + + TRMV performs one of the matrix-vector operations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + x := A*x, or x := A**T*x, + + where x is an n element vector and A is an n by n unit, or non-unit, + upper or lower triangular matrix. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies specifies whether the matrix A is an upper or + lower triangular matrix as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the operation to be performed as follows: + trans = CBLAS_TRANSPOSE::CblasNoTrans, x := A*x. \n + trans = CBLAS_TRANSPOSE::CblasTrans, x := A**T*x. \n + trans = CBLAS_TRANSPOSE::CblasConjTrans, x := A**T*x. + + \param[in] diag + diag is enum CBLAS_DIAG + diag specifies specifies whether or not A is unit triangular + as follows: \n + diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular.\n + diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit + triangular. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A.n must be at least zero. + + \param[in] A + A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension ( lda, n )\n + Before entry with UPLO = CblasUpper, the leading n by n + upper triangular part of the array A must contain the upper + triangular matrix and the strictly lower triangular part of + A is not referenced. \n + Before entry with UPLO = CblasLower, the leading n by n + lower triangular part of the array A must contain the lower + triangular matrix and the strictly upper triangular part of + A is not referenced. \n + Note that when DIAG = CblasUnit, the diagonal elements of + A are not referenced either, but are assumed to be unity. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least max( 1, n ). + + \param[in, out] x + x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : \n + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the + vector x.On exit, x is overwritten with the transformed vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + */ +template< typename T > +void trmv( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, + CBLAS_DIAG diag, + int64_t n, + T const *A, int64_t lda, + T *x, int64_t incx ) +{ + cblas_trmv(layout, uplo, trans, diag, n, A, lda, x, incx); +} + +/*! \brief Solve the one of the matrix-vector operations for arbitrary data types + + \b Purpose: + + TBMV performs one of the matrix-vector operations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + x := A*x, or x := A**T*x, + + where x is an n element vector and A is an n by n unit, or non-unit, + upper or lower triangular band matrix, with ( k + 1 ) diagonals. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies specifies whether the matrix A is an upper or + lower triangular matrix as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the operation to be performed as follows: + trans = CBLAS_TRANSPOSE::CblasNoTrans, x := A*x. \n + trans = CBLAS_TRANSPOSE::CblasTrans, x := A**T*x. \n + trans = CBLAS_TRANSPOSE::CblasConjTrans, x := A**T*x. + + \param[in] diag + diag is enum CBLAS_DIAG + diag specifies specifies whether or not A is unit triangular + as follows: \n + diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular.\n + diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit + triangular. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A.n must be at least zero. + + \param[in] k + k is INTEGER + On entry with UPLO = CblasUpper, k specifies the number of + super-diagonals of the matrix A. + On entry with UPLO = CblasLower, k specifies the number of + sub-diagonals of the matrix A. + k must at least zero. + + \param[in] A + A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension ( lda, n )\n + Before entry with UPLO = CblasUpper, the leading ( k + 1 ) + by n part of the array A must contain the upper triangular + band part of the matrix of coefficients, supplied column by + column, with the leading diagonal of the matrix in row + ( k + 1 ) of the array, the first super-diagonal starting at + position 2 in row k, and so on. The top left k by k triangle + of the array A is not referenced. \n + Before entry with UPLO = CblasLower, the leading ( k + 1 ) + by n part of the array A must contain the lower triangular + band part of the matrix of coefficients, supplied column by + column, with the leading diagonal of the matrix in row 1 of + the array, the first sub-diagonal starting at position 1 in + row 2, and so on. The bottom right k by k triangle of the + array A is not referenced. \n + Note that when DIAG = CblasUnit the elements of the array A + corresponding to the diagonal elements of the matrix are not + referenced, but are assumed to be unity. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least max( 1, ( k + 1 ) ). + + \param[in, out] x + x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : \n + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the + vector x.On exit, x is overwritten with the transformed vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + */ +template< typename T > +void tbmv( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, + CBLAS_DIAG diag, + int64_t n, int64_t k, + T const *A, int64_t lda, + T *x, int64_t incx ) +{ + cblas_tbmv(layout, uplo, trans, diag, n, k, A, lda, x, incx); +} + + +/*! \brief Solve the one of the matrix-vector operations for arbitrary data types + + \b Purpose: + + TPMV performs one of the matrix-vector operations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + x := A*x, or x := A**T*x, + + where x is an n element vector and A is an n by n unit, or non-unit, + upper or lower triangular matrix, supplied in packed form. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies specifies whether the matrix A is an upper or + lower triangular matrix as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the operation to be performed as follows: + trans = CBLAS_TRANSPOSE::CblasNoTrans, x := A*x. \n + trans = CBLAS_TRANSPOSE::CblasTrans, x := A**T*x. \n + trans = CBLAS_TRANSPOSE::CblasConjTrans, x := A**T*x. + + \param[in] diag + diag is enum CBLAS_DIAG + diag specifies specifies whether or not A is unit triangular + as follows: \n + diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular.\n + diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit + triangular. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A.n must be at least zero. + + \param[in] Ap + Ap is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension + ( ( n*( n + 1 ) )/2 ). \n + Before entry with UPLO = CblasUpper, the array Ap must + contain the upper triangular matrix packed sequentially, + column by column, so that Ap( 1 ) contains a( 1, 1 ), + Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) and a( 2, 2 ) + respectively, and so on. \n + Before entry with UPLO = CblasLower, the array Ap must + contain the lower triangular matrix packed sequentially, + column by column, so that Ap( 1 ) contains a( 1, 1 ), + Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) and a( 3, 1 ) + respectively, and so on. \n + Note that when DIAG = CblasUnit, the diagonal elements of + A are not referenced, but are assumed to be unity. + + \param[in, out] x + x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : \n + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the + vector x.On exit, x is overwritten with the transformed vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + */ +template< typename T > +void tpmv( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, + CBLAS_DIAG diag, + int64_t n, + T const *Ap, + T *x, int64_t incx ) +{ + cblas_tpmv(layout, uplo, trans, diag, n, Ap, x, incx); +} + +/*! \brief Solve the one of the triangular matrix-vector equation for arbitrary data types + + \b Purpose: + + TRSV solves one of the systems of equations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + A*x = b, or A**T*x = b, + + where b and x are n element vectors and A is an n by n unit, or + non-unit, upper or lower triangular matrix + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies specifies whether the matrix A is an upper or + lower triangular matrix as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the operation to be performed as follows: + trans = CBLAS_TRANSPOSE::CblasNoTrans, A*x = b. \n + trans = CBLAS_TRANSPOSE::CblasTrans, A**T*x = b. \n + trans = CBLAS_TRANSPOSE::CblasConjTrans, A**T*x = b. + + \param[in] diag + diag is enum CBLAS_DIAG + diag specifies specifies whether or not A is unit triangular + as follows: \n + diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular.\n + diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit + triangular. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A.n must be at least zero. + + \param[in] A + A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension ( lda, n )\n + Before entry with UPLO = CblasUpper, the leading n by n + upper triangular part of the array A must contain the upper + triangular matrix and the strictly lower triangular part of + A is not referenced. \n + Before entry with UPLO = CblasLower, the leading n by n + lower triangular part of the array A must contain the lower + triangular matrix and the strictly upper triangular part of + A is not referenced. \n + Note that when DIAG = CblasUnit, the diagonal elements of + A are not referenced either, but are assumed to be unity. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least max( 1, n ). + + \param[in, out] x + x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the + element right-hand side vector b.On exit, x is overwritten + with the transformed vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + */ +template< typename T > +void trsv( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, + CBLAS_DIAG diag, + int64_t n, + T const *A, int64_t lda, + T *x, int64_t incx ) +{ + cblas_trsv(layout, uplo, trans, diag, n, A, lda, x, incx); +} + +/*! \brief Solve the one of the triangular matrix-vector equation for arbitrary data types + + \b Purpose: + + TBSV solves one of the systems of equations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + A*x = b, or A**T*x = b, + + where b and x are n element vectors and A is an n by n unit, or + non-unit, upper or lower triangular band matrix, with ( k + 1 ) + diagonals. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies specifies whether the matrix A is an upper or + lower triangular matrix as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the operation to be performed as follows: + trans = CBLAS_TRANSPOSE::CblasNoTrans, A*x = b. \n + trans = CBLAS_TRANSPOSE::CblasTrans, A**T*x = b. \n + trans = CBLAS_TRANSPOSE::CblasConjTrans, A**T*x = b. + + \param[in] diag + diag is enum CBLAS_DIAG + diag specifies specifies whether or not A is unit triangular + as follows: \n + diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular.\n + diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit + triangular. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A.n must be at least zero. + + \param[in] k + k is INTEGER + On entry with UPLO = CblasUpper, k specifies the number of + super-diagonals of the matrix A. + On entry with UPLO = CblasLower, k specifies the number of + sub-diagonals of the matrix A. + k must at least zero. + + \param[in] A + A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension ( lda, n )\n + Before entry with UPLO = CblasUpper, the leading ( k + 1 ) + by n part of the array A must contain the upper triangular + band part of the matrix of coefficients, supplied column by + column, with the leading diagonal of the matrix in row + ( k + 1 ) of the array, the first super-diagonal starting at + position 2 in row k, and so on. The top left k by k triangle + of the array A is not referenced. \n + Before entry with UPLO = CblasLower, the leading ( k + 1 ) + by n part of the array A must contain the lower triangular + band part of the matrix of coefficients, supplied column by + column, with the leading diagonal of the matrix in row 1 of + the array, the first sub-diagonal starting at position 1 in + row 2, and so on. The bottom right k by k triangle of the + array A is not referenced. \n + Note that when DIAG = CblasUnit, the elements of the array A + corresponding to the diagonal elements of the matrix are not + referenced, but are assumed to be unity. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least max( 1, k+1 ). + + \param[in, out] x + x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the + element right-hand side vector b.On exit, x is overwritten + with the solution vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + */ +template< typename T > +void tbsv( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, + CBLAS_DIAG diag, + int64_t n, int64_t k, + T const *A, int64_t lda, + T *x, int64_t incx ) +{ + cblas_tbsv(layout, uplo, trans, diag, n, k, A, lda, x, incx); +} + + +/*! \brief Solve the one of the triangular matrix-vector equation for arbitrary data types + + \b Purpose: + + TPSV solves one of the systems of equations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + A*x = b, or A**T*x = b, + + where b and x are n element vectors and A is an n by n unit, or + non-unit, upper or lower triangular band matrix, supplied in packed form. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies specifies whether the matrix A is an upper or + lower triangular matrix as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the operation to be performed as follows: + trans = CBLAS_TRANSPOSE::CblasNoTrans, A*x = b. \n + trans = CBLAS_TRANSPOSE::CblasTrans, A**T*x = b. \n + trans = CBLAS_TRANSPOSE::CblasConjTrans, A**T*x = b. + + \param[in] diag + diag is enum CBLAS_DIAG + diag specifies specifies whether or not A is unit triangular + as follows: \n + diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular.\n + diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit + triangular. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A.n must be at least zero. + + \param[in] Ap + Ap is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension + ( ( n*( n + 1 ) )/2 ). \n + Before entry with UPLO = CblasUpper, the array Ap must + contain the upper triangular matrix packed sequentially, + column by column, so that Ap( 1 ) contains a( 1, 1 ), + Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) and a( 2, 2 ) + respectively, and so on. \n + Before entry with UPLO = CblasLower, the array Ap must + contain the lower triangular matrix packed sequentially, + column by column, so that Ap( 1 ) contains a( 1, 1 ), + Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) and a( 3, 1 ) + respectively, and so on. \n + Note that when DIAG = CblasUnit, the diagonal elements of + A are not referenced, but are assumed to be unity. + + \param[in, out] x + x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the + element right-hand side vector b.On exit, x is overwritten + with the solution vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + */ +template< typename T > +void tpsv( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, + CBLAS_DIAG diag, + int64_t n, + T const *Ap, + T *x, int64_t incx ) +{ + cblas_tpsv(layout, uplo, trans, diag, n, Ap, x, incx); +} + +/*! \brief Perform the General matrix rank-1 update for arbitrary data types + + \b Purpose: + + GER performs the rank 1 operation for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + + A := alpha*x*y**T + A, + + where alpha is a scalar, x is an m element vector, y is an n element + vector and A is an m by n matrix. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] m + m is INTEGER + On entry, m specifies the number of rows of the matrix A. + m must be at least zero. + + \param[in] n + n is INTEGER + On entry, n specifies the number of columns of the matrix A. + n must be at least zero. + + \param[in] alpha + alpha is REAL/DOUBLE PRECISION + On entry, alpha specifies the scalar alpha. + + \param[in] x + x is REAL/DOUBLE PRECISION array,dimension : + at least ( 1 + ( m - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the m + element vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] y + y is REAL/DOUBLE PRECISION array,dimension : + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry, the incremented array y must contain the n + element vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + + \param[in,out] A + A is REAL/DOUBLE PRECISION array,dimension ( lda, n )\n + Before entry, the leading m by n part of the array A must + contain the matrix of coefficients. On exit, A is + overwritten by the updated matrix. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least max( 1, m ). + */ +template< typename T > +void ger( + CBLAS_ORDER layout, + int64_t m, int64_t n, + T alpha, + T const *x, int64_t incx, + T const *y, int64_t incy, + T *A, int64_t lda ) +{ + cblas_ger(layout, m, n, alpha, x, incx, y, incy, A, lda); +} + +/*! \brief Perform the General matrix rank-1 update for arbitrary data types + + \b Purpose: + + GERU performs the rank 1 operation for arbitrary data types + Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX(COMPLEX*16) + + A := alpha*x*y**T + A, + + where alpha is a scalar, x is an m element vector, y is an n element + vector and A is an m by n matrix. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] m + m is INTEGER + On entry, m specifies the number of rows of the matrix A. + m must be at least zero. + + \param[in] n + n is INTEGER + On entry, n specifies the number of columns of the matrix A. + n must be at least zero. + + \param[in] alpha + alpha is SINGLE/DOUBLE PRECISION COMPLEX + On entry, alpha specifies the scalar alpha. + + \param[in] x + x is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : + at least ( 1 + ( m - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the m + element vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] y + y is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry, the incremented array y must contain the n + element vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + + \param[in,out] A + A is SINGLE/DOUBLE PRECISION COMPLEX array,dimension ( lda, n )\n + Before entry, the leading m by n part of the array A must + contain the matrix of coefficients. On exit, A is + overwritten by the updated matrix. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least max( 1, m ). + */ +template< typename T > +void geru( + CBLAS_ORDER layout, + int64_t m, int64_t n, + T alpha, + T const *x, int64_t incx, + T const *y, int64_t incy, + T *A, int64_t lda ) +{ + cblas_geru(layout, m, n, alpha, x, incx, y, incy, A, lda); +} + +/*! \brief Perform the General matrix rank-1 update for arbitrary data types + + \b Purpose: + + GERC performs the rank 1 operation for arbitrary data types + Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX(COMPLEX*16) + + A := alpha*x*y**T + A, + + where alpha is a scalar, x is an m element vector, y is an n element + vector and A is an m by n matrix. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] m + m is INTEGER + On entry, m specifies the number of rows of the matrix A. + m must be at least zero. + + \param[in] n + n is INTEGER + On entry, n specifies the number of columns of the matrix A. + n must be at least zero. + + \param[in] alpha + alpha is SINGLE/DOUBLE PRECISION COMPLEX + On entry, alpha specifies the scalar alpha. + + \param[in] x + x is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : + at least ( 1 + ( m - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the m + element vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] y + y is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry, the incremented array y must contain the n + element vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + + \param[in,out] A + A is SINGLE/DOUBLE PRECISION COMPLEX array,dimension ( lda, n )\n + Before entry, the leading m by n part of the array A must + contain the matrix of coefficients. On exit, A is + overwritten by the updated matrix. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least max( 1, m ). + */ +template< typename T > +void gerc( + CBLAS_ORDER layout, + int64_t m, int64_t n, + T alpha, + T const *x, int64_t incx, + T const *y, int64_t incy, + T *A, int64_t lda ) +{ + cblas_gerc(layout, m, n, alpha, x, incx, y, incy, A, lda); +} + +/*! \brief Perform the hermitian rank 1 operation for arbitrary data types + + \b Purpose: + + HER performs the hermitian rank 1 operation for arbitrary data types + Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX(COMPLEX*16) + + A := alpha*x*x**H + A, + + where alpha is a real scalar, x is an n element vector, A is an n by n + hermitian matrix. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies specifies whether the upper or lower triangular + part of the array A is to be referenced as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A. + n must be at least zero. + + \param[in] alpha + alpha is SINGLE/DOUBLE PRECISION REAL + On entry, alpha specifies the scalar alpha. + + \param[in] x + x is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the n + element vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in,out] A + A is SINGLE/DOUBLE PRECISION COMPLEX array,dimension ( lda, n )\n + Before entry with UPLO = CblasUpper, the leading n by n + upper triangular part of the array A must contain the upper + triangular part of the hermitian matrix and the strictly + lower triangular part of A is not referenced. On exit, the + upper triangular part of the array A is overwritten by the + upper triangular part of the updated matrix. \n + Before entry with UPLO = CblasLower, the leading n by n + lower triangular part of the array A must contain the lower + triangular part of the hermitian matrix and the strictly + upper triangular part of A is not referenced. On exit, the + lower triangular part of the array A is overwritten by the + lower triangular part of the updated matrix. \n + Note that the imaginary parts of the diagonal elements need + not be set, they are assumed to be zero, and on exit they + are set to zero. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least max( 1, n ). + */ +template< typename T > +void her( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, + real_type alpha, // zher takes double alpha; use real + T const *x, int64_t incx, + T *A, int64_t lda ) +{ + cblas_her(layout, uplo, n, alpha, x, incx, A, lda); +} + +/*! \brief Perform the hermitian rank 1 operation for arbitrary data types + + \b Purpose: + + HPR performs the hermitian rank 1 operation for arbitrary data types + Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX(COMPLEX*16) + + A := alpha*x*x**H + A, + + where alpha is a real scalar, x is an n element vector, A is an n by n + hermitian matrix, supplied in packed form. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies specifies whether the upper or lower triangular + part of the array A is to be referenced as follows: \n + uplo = CBLAS_UPLO::CblasUpper The upper triangular part of A is + supplied in Ap. \n + uplo = CBLAS_UPLO::CblasLower The lower triangular part of A is + supplied in Ap. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A. + n must be at least zero. + + \param[in] alpha + alpha is SINGLE/DOUBLE PRECISION REAL + On entry, alpha specifies the scalar alpha. + + \param[in] x + x is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the n + element vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in,out] Ap + Ap is SINGLE/DOUBLE PRECISION COMPLEX array,dimension + atleast ( ( n*( n + 1 ) )/2 ).\n + Before entry with UPLO = CblasUpper, the array Ap must + contain the upper triangular part of the hermitian matrix + packed sequentially, column by column, so that Ap( 1 ) + contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) + and a( 2, 2 ) respectively, and so on. On exit, the array + Ap is overwritten by the upper triangular part of the + updated matrix. \n + Before entry with UPLO = CblasLower, the array Ap must + contain the lower triangular part of the hermitian matrix + packed sequentially, column by column, so that Ap( 1 ) + contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) + and a( 3, 1 ) respectively, and so on. On exit, the array + Ap is overwritten by the lower triangular part of the + updated matrix. \n + Note that the imaginary parts of the diagonal elements need + not be set, they are assumed to be zero, and on exit they + are set to zero. + */ +template< typename T > +void hpr( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, + real_type alpha, // zher takes double alpha; use real + T const *x, int64_t incx, + T *Ap ) +{ + cblas_hpr(layout, uplo, n, alpha, x, incx, Ap); +} + +/*! \brief Perform the hermitian rank 2 operation for arbitrary data types + + \b Purpose: + + HER2 performs the hermitian rank 2 operation for arbitrary data types + Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX(COMPLEX*16) + + A := alpha*x*y**H + conjg( alpha )*y*x**H + A, + + where alpha is a scalar, x and y are n element vector, A is an n by n + hermitian matrix. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies whether the upper or lower triangular part of the + array A is to be referenced as follows: \n + UPLO = CblasUpper Only the upper triangular part of A + is to be referenced. \n + UPLO = CblasLower Only the lower triangular part of A + is to be referenced. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A. + n must be at least zero. + + \param[in] alpha + alpha is SINGLE/DOUBLE PRECISION COMPLEX + On entry, alpha specifies the scalar alpha. + + \param[in] x + x is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the n + element vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] y + y is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry, the incremented array y must contain the n + element vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + + \param[in,out] A + A is SINGLE/DOUBLE PRECISION COMPLEX array,dimension ( lda, n )\n + Before entry with UPLO = CblasUpper, the leading n by n + upper triangular part of the array A must contain the upper + triangular part of the hermitian matrix and the strictly + lower triangular part of A is not referenced. On exit, the + upper triangular part of the array A is overwritten by the + upper triangular part of the updated matrix. \n + Before entry with UPLO = CblasLower, the leading n by n + lower triangular part of the array A must contain the lower + triangular part of the hermitian matrix and the strictly + upper triangular part of A is not referenced. On exit, the + lower triangular part of the array A is overwritten by the + lower triangular part of the updated matrix. \n + Note that the imaginary parts of the diagonal elements need + not be set, they are assumed to be zero, and on exit they + are set to zero. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least max( 1, n ). + */ +template< typename T > +void her2( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, + T alpha, + T const *x, int64_t incx, + T const *y, int64_t incy, + T *A, int64_t lda ) +{ + cblas_her2(layout, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +/*! \brief Perform the hermitian rank 2 operation for arbitrary data types + + \b Purpose: + + HPR2 performs the hermitian rank 2 operation for arbitrary data types + Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX(COMPLEX*16) + + A := alpha*x*y**H + conjg( alpha )*y*x**H + A, + + where alpha is a scalar, x and y are n element vector, A is an n by n + hermitian matrix, supplied in packed form. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies specifies whether the upper or lower triangular + part of the array A is to be referenced as follows: \n + uplo = CBLAS_UPLO::CblasUpper The upper triangular part of A is + supplied in Ap. \n + uplo = CBLAS_UPLO::CblasLower The lower triangular part of A is + supplied in Ap. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A. + n must be at least zero. + + \param[in] alpha + alpha is SINGLE/DOUBLE PRECISION COMPLEX + On entry, alpha specifies the scalar alpha. + + \param[in] x + x is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the n + element vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] y + y is SINGLE/DOUBLE PRECISION REAL array,dimension : + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry, the incremented array y must contain the n + element vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + + \param[in,out] Ap + Ap is SINGLE/DOUBLE PRECISION COMPLEX array,dimension + atleast ( ( n*( n + 1 ) )/2 ).\n + Before entry with UPLO = CblasUpper, the array Ap must + contain the upper triangular part of the hermitian matrix + packed sequentially, column by column, so that Ap( 1 ) + contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) + and a( 2, 2 ) respectively, and so on. On exit, the array + Ap is overwritten by the upper triangular part of the + updated matrix. \n + Before entry with UPLO = CblasLower, the array Ap must + contain the lower triangular part of the hermitian matrix + packed sequentially, column by column, so that Ap( 1 ) + contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) + and a( 3, 1 ) respectively, and so on. On exit, the array + Ap is overwritten by the lower triangular part of the + updated matrix. \n + Note that the imaginary parts of the diagonal elements need + not be set, they are assumed to be zero, and on exit they + are set to zero. + */ +template< typename T > +void hpr2( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, + T alpha, + T const *x, int64_t incx, + T const *y, int64_t incy, + T *Ap ) +{ + cblas_hpr2(layout, uplo, n, alpha, x, incx, y, incy, Ap); +} + +/*! \brief Perform the symmetric rank 1 operation for arbitrary data types + + \b Purpose: + + SYR performs the symmetric rank 1 operation for arbitrary data types + Data precisions supported include SINGLE/DOUBLE PRECISION REAL + + A := alpha*x*x**T + A, + + where alpha is a real scalar, x is an n element vector, A is an n by n + symmetric matrix. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies specifies whether the upper or lower triangular + part of the array A is to be referenced as follows: \n + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A. + n must be at least zero. + + \param[in] alpha + alpha is SINGLE/DOUBLE PRECISION REAL + On entry, alpha specifies the scalar alpha. + + \param[in] x + x is SINGLE/DOUBLE PRECISION REAL array,dimension : + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the n + element vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in,out] A + A is SINGLE/DOUBLE PRECISION REAL array,dimension ( lda, n )\n + Before entry with UPLO = CblasUpper, the leading n by n + upper triangular part of the array A must contain the upper + triangular part of the symmetric matrix and the strictly + lower triangular part of A is not referenced. On exit, the + upper triangular part of the array A is overwritten by the + upper triangular part of the updated matrix. \n + Before entry with UPLO = CblasLower, the leading n by n + lower triangular part of the array A must contain the lower + triangular part of the symmetric matrix and the strictly + upper triangular part of A is not referenced. On exit, the + lower triangular part of the array A is overwritten by the + lower triangular part of the updated matrix. \n + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least max( 1, n ). + */ +template< typename T > +void syr( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, + T alpha, + T const *x, int64_t incx, + T *A, int64_t lda ) +{ + cblas_syr(layout, uplo, n, alpha, x, incx, A, lda); +} + +/*! \brief Perform the symmetric rank 1 operation for arbitrary data types + + \b Purpose: + + SPR performs the symmetric rank 1 operation for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL + + A := alpha*x*x**T + A, + + where alpha is a real scalar, x is an n element vector, A is an n by n + symmetric matrix, supplied in packed form. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies specifies whether the upper or lower triangular + part of the array A is to be referenced as follows: \n + uplo = CBLAS_UPLO::CblasUpper The upper triangular part of A is + supplied in Ap. \n + uplo = CBLAS_UPLO::CblasLower The lower triangular part of A is + supplied in Ap. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A. + n must be at least zero. + + \param[in] alpha + alpha is SINGLE/DOUBLE PRECISION REAL + On entry, alpha specifies the scalar alpha. + + \param[in] x + x is SINGLE/DOUBLE PRECISION REAL array,dimension : + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the n + element vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in,out] Ap + Ap is SINGLE/DOUBLE PRECISION REAL array,dimension + atleast ( ( n*( n + 1 ) )/2 ).\n + Before entry with UPLO = CblasUpper, the array Ap must + contain the upper triangular part of the symmetric matrix + packed sequentially, column by column, so that Ap( 1 ) + contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) + and a( 2, 2 ) respectively, and so on. On exit, the array + Ap is overwritten by the upper triangular part of the + updated matrix. \n + Before entry with UPLO = CblasLower, the array Ap must + contain the lower triangular part of the symmetric matrix + packed sequentially, column by column, so that Ap( 1 ) + contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) + and a( 3, 1 ) respectively, and so on. On exit, the array + Ap is overwritten by the lower triangular part of the + updated matrix. \n + */ +template< typename T > +void spr( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, + T alpha, + T const *x, int64_t incx, + T *Ap ) +{ + cblas_spr(layout, uplo, n, alpha, x, incx, Ap); +} + +/*! \brief Perform the symmetric rank 2 operation for arbitrary data types + + \b Purpose: + + SYR2 performs the symmetric rank 2 operation for arbitrary data types + Data precisions supported include SINGLE/DOUBLE PRECISION REAL + + A := alpha*x*y**T + alpha*y*x**T + A, + + where alpha is a scalar, x and y are n element vector, A is an n by n + symmetric matrix. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies whether the upper or lower triangular part of the + array A is to be referenced as follows: \n + UPLO = CblasUpper Only the upper triangular part of A + is to be referenced. \n + UPLO = CblasLower Only the lower triangular part of A + is to be referenced. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A. + n must be at least zero. + + \param[in] alpha + alpha is SINGLE/DOUBLE PRECISION REAL + On entry, alpha specifies the scalar alpha. + + \param[in] x + x is SINGLE/DOUBLE PRECISION REAL array,dimension : + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the n + element vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] y + y is SINGLE/DOUBLE PRECISION REAL array,dimension : + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry, the incremented array y must contain the n + element vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + + \param[in,out] A + A is SINGLE/DOUBLE PRECISION REAL array,dimension ( lda, n )\n + Before entry with UPLO = CblasUpper, the leading n by n + upper triangular part of the array A must contain the upper + triangular part of the symmetric matrix and the strictly + lower triangular part of A is not referenced. On exit, the + upper triangular part of the array A is overwritten by the + upper triangular part of the updated matrix. \n + Before entry with UPLO = CblasLower, the leading n by n + lower triangular part of the array A must contain the lower + triangular part of the symmetric matrix and the strictly + upper triangular part of A is not referenced. On exit, the + lower triangular part of the array A is overwritten by the + lower triangular part of the updated matrix. \n + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + lda must be at least max( 1, n ). + */ +template< typename T > +void syr2( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, + T alpha, + T const *x, int64_t incx, + T const *y, int64_t incy, + T *A, int64_t lda ) +{ + cblas_syr2(layout, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +/*! \brief Perform the symmetric rank 2 operation for arbitrary data types + + \b Purpose: + + SPR2 performs the symmetric rank 2 operation for arbitrary data types + Data precisions supported include SINGLE/DOUBLE PRECISION REAL + + A := alpha*x*y**T + alpha*y*x**T + A, + + where alpha is a scalar, x and y are n element vector, A is an n by n + symmetric matrix, supplied in packed form. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO. + uplo specifies specifies whether the upper or lower triangular + part of the array A is to be referenced as follows: \n + uplo = CBLAS_UPLO::CblasUpper The upper triangular part of A is + supplied in Ap. \n + uplo = CBLAS_UPLO::CblasLower The lower triangular part of A is + supplied in Ap. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix A. + n must be at least zero. + + \param[in] alpha + alpha is SINGLE/DOUBLE PRECISION REAL + On entry, alpha specifies the scalar alpha. + + \param[in] x + x is SINGLE/DOUBLE PRECISION REAL array,dimension : + at least ( 1 + ( n - 1 )*abs( incx ) ). \n + Before entry, the incremented array x must contain the n + element vector x. + + \param[in] incx + incx is INTEGER + On entry, incx specifies the increment for the elements of x. + incx must not be zero. + + \param[in] y + y is SINGLE/DOUBLE PRECISION REAL array,dimension : + at least ( 1 + ( n - 1 )*abs( incy ) ). \n + Before entry, the incremented array y must contain the n + element vector y. + + \param[in] incy + incy is INTEGER + On entry, incy specifies the increment for the elements of y. + incy must not be zero. + + \param[in,out] Ap + Ap is SINGLE/DOUBLE PRECISION REAL array,dimension + atleast ( ( n*( n + 1 ) )/2 ).\n + Before entry with UPLO = CblasUpper, the array Ap must + contain the upper triangular part of the symmetric matrix + packed sequentially, column by column, so that Ap( 1 ) + contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) + and a( 2, 2 ) respectively, and so on. On exit, the array + Ap is overwritten by the upper triangular part of the + updated matrix. \n + Before entry with UPLO = CblasLower, the array Ap must + contain the lower triangular part of the symmetric matrix + packed sequentially, column by column, so that Ap( 1 ) + contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) + and a( 3, 1 ) respectively, and so on. On exit, the array + Ap is overwritten by the lower triangular part of the + updated matrix. \n + */ +template< typename T > +void spr2( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + int64_t n, + T alpha, + T const *x, int64_t incx, + T const *y, int64_t incy, + T *Ap ) +{ + cblas_spr2(layout, uplo, n, alpha, x, incx, y, incy, Ap); +} + +/*! \brief General matrix-matrix multiply for arbitrary data types + + \b Purpose: + + GEMM performs general matrix-matrix multiply for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + C := alpha*op( A )*op( B ) + beta*C, + + where op( X ) is one of + + op( X ) = X or op( X ) = X**T or op( X ) = X**H, + + alpha and beta are scalars, and A, B and C are matrices, with op( A ) + an m by k matrix, op( B ) a k by n matrix and C an m by n matrix. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] transA + transA is CBLAS_TRANSPOSE + On entry, transA specifies the form of op( A ) to be used in + the matrix multiplication as follows: + transA = CBLAS_TRANSPOSE::CblasNoTrans, op( A ) = A. + transA = CBLAS_TRANSPOSE::CblasTrans, op( A ) = A**T. + transA = CBLAS_TRANSPOSE::CblasConjTrans, op( A ) = A**H. + + \param[in] transB + transB is CBLAS_TRANSPOSE + On entry, transB specifies the form of op( B ) to be used in + the matrix multiplication as follows: + transB = CBLAS_TRANSPOSE::CblasNoTrans, op( B ) = B. + transB = CBLAS_TRANSPOSE::CblasTrans, op( B ) = B**T. + transB = CBLAS_TRANSPOSE::CblasConjTrans, op( B ) = B**H. + + \param[in] m + m is INTEGER + On entry, m specifies the number of rows of the matrix + op( A ) and of the matrix C. m must be at least zero. + + \param[in] n + n is INTEGER + On entry, n specifies the number of columns of the matrix + op( B ) and the number of columns of the matrix C. n must be + at least zero. + + \param[in] k + k is INTEGER + On entry, k specifies the number of columns of the matrix + op( A ) and the number of rows of the matrix op( B ). k must + be at least zero. + + \param[in] alpha + alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + If transA = CblasNoTrans: + m-by-k , stored in an lda-by-k array [RowMajor: m-by-lda]. + Otherwise: + k-by-m , stored in an lda-by-m array [RowMajor: k-by-lda]. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + If transA = CblasNoTrans: lda >= max(1, m) [RowMajor: lda >= max(1, k)]. + Otherwise: lda >= max(1, k) [RowMajor: lda >= max(1, m)]. + + \param[in] B + B is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + If transA = CblasNoTrans: + k-by-n , stored in an ldb-by-n array [RowMajor: k-by-ldb]. + Otherwise: + n-by-k , stored in an ldb-by-k array [RowMajor: n-by-ldb]. + + \param[in] ldb + ldb is INTEGER + On entry, ldb specifies the Leading dimension of B + If transA = CblasNoTrans: ldb >= max(1, k) [RowMajor: ldb >= max(1, n)]. + Otherwise: ldb >= max(1, n) [RowMajor: ldb >= max(1, k)]. + + \param[in] beta + beta is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, beta specifies the scalar alpha.When beta is + supplied as zero then C need not be set on input. + + \param[in,out] C + C is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, dimension : + m-by-n stored in an ldc-by-n array [RowMajor: m-by-ldc]. + Before entry, the leading m by n part of the array C must + contain the matrix C, except when beta is zero, in which + case C need not be set on entry. + On exit, the array C is overwritten by the m by n matrix + ( alpha*op( A )*op( B ) + beta*C ). + + \param[in] ldc + ldc is INTEGER + On entry, ldc specifies the first dimension of C + ldc >= max(1, m) [RowMajor: ldc >= max(1, n)]. + */ +template< typename T > +void gemm( + CBLAS_ORDER layout, + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int64_t m, int64_t n, int64_t k, + T alpha, + T const *A, int64_t lda, + T const *B, int64_t ldb, + T beta, + T *C, int64_t ldc ) +{ + cblas_gemm(layout, transA, transB, m, n, k, alpha, A,lda, B, ldb, beta, C, ldc); +} + +/*! \brief Solve the triangular matrix-matrix equation for arbitrary data types + + \b Purpose: + + TRSM performs one of the matrix equations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + op( A )*X = alpha*B, or X*op( A ) = alpha*B, + + where alpha is a scalar, X and B are m by n matrices, A is a unit, or + non-unit, upper or lower triangular matrix and op( A ) is one of + where op( X ) is one of + + op( A ) = A or op( A ) = A**T or op( A ) = A**H. + + The matrix X is overwritten on B. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] side + side is enum CBLAS_SIDE + side specifies specifies whether op( A ) appears on the left + or right of X as follows: + side = CBLAS_SIDE::CblasLeft op( A )*X = alpha*B. + side = CBLAS_SIDE::CblasRight op( A )*X = alpha*B. + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies specifies whether the matrix A is an upper or + lower triangular matrix as follows: + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the form of op( A ) to be used in + the matrix multiplication as follows: + trans = CBLAS_TRANSPOSE::CblasNoTrans, op( A ) = A. + trans = CBLAS_TRANSPOSE::CblasTrans, op( A ) = A**T. + trans = CBLAS_TRANSPOSE::CblasConjTrans, op( A ) = A**H. + + \param[in] diag + diag is enum CBLAS_DIAG + diag specifies specifies whether or not A is unit triangular + as follows: + diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular. + diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit + triangular. + + \param[in] m + m is INTEGER + On entry, m specifies the number of rows of the matrix + B. m must be at least zero. + + \param[in] n + n is INTEGER + On entry, n specifies the number of columns of the matrix + B. n must be at least zero. + + \param[in] alpha + alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + If side = CblasLeft: + the m-by-m matrix A, stored in an lda-by-m array [RowMajor: m-by-lda]. + If side = CblasRight: + the n-by-n matrix A, stored in an lda-by-n array [RowMajor: n-by-lda]. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + If side = CblasLeft: lda >= max(1, m) . + If side = CblasRight:lda >= max(1, k) . + + \param[in,out] B + B is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + m-by-n , stored in an ldb-by-n array [RowMajor: m-by-ldb]. + on exit is overwritten by the solution matrix X. + + \param[in] ldb + ldb is INTEGER + On entry, ldb specifies the Leading dimension of B + ldb >= max(1, m) [RowMajor: ldb >= max(1, n)]. + */ +template< typename T > +void trsm( + CBLAS_ORDER layout, + CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, + CBLAS_DIAG diag, + int64_t m, + int64_t n, + T alpha, + T const *A, int64_t lda, + T *B, int64_t ldb ) +{ + cblas_trsm( layout, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} +/*! \brief Solve the Triangular matrix-matrix multiply for arbitrary data types + + \b Purpose: + + TRMM performs solves one of the matrix equations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + B := alpha*op( A )*B, or B := alpha*B*op( A ), + + where alpha is a scalar, B is an m by n matrices, A is a unit, or + non-unit, upper or lower triangular matrix and op( A ) is one of + op( A ) = A or op( A ) = A**T. + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] side + side is enum CBLAS_SIDE + side specifies whether op( A ) multiplies B from left or right of X + as follows: + side = CBLAS_SIDE::CblasLeft B := alpha*op( A )*B. + side = CBLAS_SIDE::CblasRight B := alpha*B*op( A ). + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies whether the matrix A is an upper or lower triangular + matrix as follows: + uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. + uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the form of op( A ) to be used in + the matrix multiplication as follows: + trans = CBLAS_TRANSPOSE::CblasNoTrans, op( A ) = A. + trans = CBLAS_TRANSPOSE::CblasTrans, op( A ) = A**T. + trans = CBLAS_TRANSPOSE::CblasConjTrans, op( A ) = A**T. + + \param[in] diag + diag is enum CBLAS_DIAG + diag specifies specifies whether or not A is unit triangular + as follows: + diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular. + diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit + triangular. + + \param[in] m + m is INTEGER + On entry, m specifies the number of rows of the matrix + B. m must be at least zero. + + \param[in] n + n is INTEGER + On entry, n specifies the number of columns of the matrix + B. n must be at least zero. + + \param[in] alpha + alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha.When alpha is + zero then A is not referenced and B need not be set before + entry. + + \param[in] A + A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + If side = CblasLeft: + the m-by-m matrix A, stored in an lda-by-m array [RowMajor: m-by-lda]. + If side = CblasRight: + the n-by-n matrix A, stored in an lda-by-n array [RowMajor: n-by-lda]. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + If side = CblasLeft: lda >= max(1, m) . + If side = CblasRight:lda >= max(1, n) . + + \param[in,out] B + B is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + m-by-n , stored in an ldb-by-n array [RowMajor: m-by-ldb]. + + \param[in] ldb + ldb is INTEGER + On entry, ldb specifies the Leading dimension of B + ldb >= max(1, m) [RowMajor: ldb >= max(1, n)]. + */ +template< typename T > +void trmm( + CBLAS_ORDER layout, + CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, + CBLAS_DIAG diag, + int64_t m, + int64_t n, + T alpha, + T const *A, int64_t lda, + T *B, int64_t ldb ) +{ + cblas_trmm( layout, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +/*! \brief Solve the Hermitian matrix-matrix multiply for arbitrary data types + + \b Purpose: + + HEMM performs solves one of the matrix-matrix operations for arbitrary data types + Data precisions supported include SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + C := alpha*A*B + beta*C + + or + + C := alpha*B*A + beta*C, + + where alpha is a scalar, A is an hermitian matrix + C and B are m by n matrices + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] side + side is enum CBLAS_SIDE + side specifies specifies whether the hermitian matrix A + appears on the left or right in the operation as follows: + side = CBLAS_SIDE::CblasLeft C := alpha*A*B + beta*C, + side = CBLAS_SIDE::CblasRight C := alpha*B*A + beta*C + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies specifies whether the upper or lower + triangular part of the hermitian matrix A is to be + referenced as follows: + uplo = CBLAS_UPLO::CblasUpper Only the upper triangular part of the + hermitian matrix is to be referenced. + uplo = CBLAS_UPLO::CblasLower Only the lower triangular part of the + hermitian matrix is to be referenced. + + \param[in] m + m is INTEGER + On entry, m specifies the number of rows of the matrix + C. m must be at least zero. + + \param[in] n + n is INTEGER + On entry, n specifies the number of columns of the matrix + C. n must be at least zero. + + \param[in] alpha + alpha is COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is COMPLEX/COMPLEX*16 array,dimension : + If side = CblasLeft: + the m-by-m matrix A, stored in an lda-by-m array [RowMajor: m-by-lda]. + If side = CblasRight: + the n-by-n matrix A, stored in an lda-by-n array [RowMajor: n-by-lda]. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + If side = CblasLeft: lda >= max(1, m) . + If side = CblasRight:lda >= max(1, k) . + + \param[in] B + B is COMPLEX/COMPLEX*16 array,dimension : + m-by-n , stored in an ldb-by-n array [RowMajor: m-by-ldb]. + + \param[in] ldb + ldb is INTEGER + On entry, ldb specifies the Leading dimension of B + ldb >= max(1, m) [RowMajor: ldb >= max(1, n)]. + + \param[in] beta + beta is COMPLEX/COMPLEX*16 + On entry, beta specifies the scalar beta. + If beta is zero, C need not be set on input + + \param[in,out] C + C is COMPLEX/COMPLEX*16 array,dimension : + m-by-n , stored in an ldc-by-n array [RowMajor: m-by-ldc]. + + \param[in] ldc + ldc is INTEGER + On entry, ldc specifies the Leading dimension of C + ldc >= max(1, m) [RowMajor: ldc >= max(1, n)]. + */ +template< typename T > +void hemm( + CBLAS_ORDER layout, + CBLAS_SIDE side, + CBLAS_UPLO uplo, + int64_t m, int64_t n, + T alpha, + T const *A, int64_t lda, + T const *B, int64_t ldb, + T beta, + T *C, int64_t ldc ) +{ + cblas_hemm( layout, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +/*! \brief Solve the Symmetric matrix-matrix multiply for arbitrary data types + + \b Purpose: + + SYMM performs solves one of the matrix-matrix operations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + C := alpha*A*B + beta*C + + or + + C := alpha*B*A + beta*C, + + where alpha is a scalar, A is an symmetric matrix + C and B are m by n matrices + + \param[in] layout + layout is enum CBLAS_ORDER + layout specifies Matrix storage as follows: + layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. + + \param[in] side + side is enum CBLAS_SIDE + side specifies specifies whether the symmetric matrix A + appears on the left or right in the operation as follows: + side = CBLAS_SIDE::CblasLeft C := alpha*A*B + beta*C, + side = CBLAS_SIDE::CblasRight C := alpha*B*A + beta*C + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies specifies whether the upper or lower + triangular part of the symmetric matrix A is to be + referenced as follows: + uplo = CBLAS_UPLO::CblasUpper Only the upper triangular part of the + symmetric matrix is to be referenced. + uplo = CBLAS_UPLO::CblasLower Only the lower triangular part of the + symmetric matrix is to be referenced. + + \param[in] m + m is INTEGER + On entry, m specifies the number of rows of the matrix + C. m must be at least zero. + + \param[in] n + n is INTEGER + On entry, n specifies the number of columns of the matrix + C. n must be at least zero. + + \param[in] alpha + alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + If side = CblasLeft: + the m-by-m matrix A, stored in an lda-by-m array [RowMajor: m-by-lda]. + If side = CblasRight: + the n-by-n matrix A, stored in an lda-by-n array [RowMajor: n-by-lda]. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + If side = CblasLeft: lda >= max(1, m) . + If side = CblasRight:lda >= max(1, k) . + + \param[in] B + B is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + m-by-n , stored in an ldb-by-n array [RowMajor: m-by-ldb]. + + \param[in] ldb + ldb is INTEGER + On entry, ldb specifies the Leading dimension of B + ldb >= max(1, m) [RowMajor: ldb >= max(1, n)]. + + \param[in] beta + beta is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, beta specifies the scalar beta. + If beta is zero, C need not be set on input + + \param[in, out] C + C is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + m-by-n , stored in an ldc-by-n array [RowMajor: m-by-ldc]. + + \param[in] ldc + ldc is INTEGER + On entry, ldc specifies the Leading dimension of C + ldc >= max(1, m) [RowMajor: ldc >= max(1, n)]. + */ +template< typename T > +void symm( + CBLAS_ORDER layout, + CBLAS_SIDE side, + CBLAS_UPLO uplo, + int64_t m, int64_t n, + T alpha, + T const *A, int64_t lda, + T const *B, int64_t ldb, + T beta, + T *C, int64_t ldc ) +{ + cblas_symm( layout, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +/*! \brief Solve the Symmetric rank-k operations for arbitrary data types + + \b Purpose: + + SYRK performs one of the symmetric rank k operations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + C := alpha*A*A**T + beta*C, + + or + + C := alpha*A**T*A + beta*C, + + where alpha and beta are scalars, C is an n by n symmetric matrix + and A is an n by k matrix in the first case and a k by n matrix + in the second case. + + \param[in] layout + layout is enum CBLAS_LAYOUT + layout specifies Matrix storage as follows: + layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies specifies whether the upper or lower + triangular part of the array C is to be referenced + as follows: + uplo = CBLAS_UPLO::CblasUpper Only the upper triangular part of C + is to be referenced. + uplo = CBLAS_UPLO::CblasLower Only the lower triangular part of C + is to be referenced. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the operation to be used as follows: + trans = CBLAS_TRANSPOSE::CblasNoTrans,C := alpha*A*A**T + beta*C. + trans = CBLAS_TRANSPOSE::CblasTrans,C := alpha*A**T*A + beta*C. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix C. n must be + at least zero. + + \param[in] k + k is INTEGER + If trans = CblasNoTrans: k is number of columns of the matrix A. + Otherwise: k is number of rows of the matrix A. + k must be at least zero. + + \param[in] alpha + alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + If transA = CblasNoTrans: + n-by-k , stored in an lda-by-k array [RowMajor: n-by-lda]. + Otherwise: + k-by-n , stored in an lda-by-n array [RowMajor: k-by-lda]. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + If transA = CblasNoTrans: lda >= max(1, n) [RowMajor: lda >= max(1, k)]. + Otherwise: lda >= max(1, k) [RowMajor: lda >= max(1, n)]. + + \param[in] beta + beta is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, beta specifies the scalar alpha.When beta is + supplied as zero then C need not be set on input. + + \param[in,out] C + C is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, dimension : + The n-by-n symmetric matrix C, + stored in an ldc-by-n array [RowMajor: n-by-ldc]. + On exit, the array C is overwritten by the lower/upper + triangular part of the updated matrix. + + \param[in] ldc + ldc is INTEGER + On entry, ldc specifies the first dimension of C + ldc >= max(1, n) + */ +template< typename T > +void syrk( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, + int64_t n, int64_t k, + T alpha, + T const *A, int64_t lda, + T beta, + T *C, int64_t ldc ) +{ + cblas_syrk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +/*! \brief Solve the Symmetric rank 2k operations for arbitrary data types + + \b Purpose: + + SYR2K performs one of the symmetric rank 2k operations for arbitrary data types + Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, + SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) + + C := alpha*A*B**T + alpha*B*A**T + beta*C, + + or + + C := alpha*A**T*B + alpha*B**T*A + beta*C, + + where alpha and beta are scalars, C is an n by n symmetric matrix + and A and B are n by k matrices in the first case and k by n matrices + in the second case. + + \param[in] layout + layout is enum CBLAS_LAYOUT + layout specifies Matrix storage as follows: + layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies specifies whether the upper or lower + triangular part of the array C is to be referenced + as follows: + uplo = CBLAS_UPLO::CblasUpper Only the upper triangular part of C + is to be referenced. + uplo = CBLAS_UPLO::CblasLower Only the lower triangular part of C + is to be referenced. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the operation to be used as follows: + trans = CBLAS_TRANSPOSE::CblasNoTrans,C := alpha*A*B**T + alpha*B*A**T + beta*C. + trans = CBLAS_TRANSPOSE::CblasTrans, C := alpha*A**T*B + alpha*B**T*A + beta*C. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix C. n must be + at least zero. + + \param[in] k + k is INTEGER + If trans = CblasNoTrans: k is number of columns of the matrices A & B. + Otherwise: k is number of rows of the matrices A & B. + k must be at least zero. + + \param[in] alpha + alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + If trans = CblasNoTrans: + n-by-k , stored in an lda-by-k array [RowMajor: n-by-lda]. + Otherwise: + k-by-n , stored in an lda-by-n array [RowMajor: k-by-lda]. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + If trans = CblasNoTrans: lda >= max(1, n) [RowMajor: lda >= max(1, k)]. + Otherwise: lda >= max(1, k) [RowMajor: lda >= max(1, n)]. + + \param[in] B + B is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : + If trans = CblasNoTrans: + n-by-k , stored in an ldb-by-k array [RowMajor: n-by-ldb]. + Otherwise: + k-by-n , stored in an ldb-by-n array [RowMajor: k-by-ldb] + + \param[in] ldb + ldb is INTEGER + On entry, ldb specifies the Leading dimension of B + If trans = CblasNoTrans: ldb >= max(1, n) [RowMajor: ldb >= max(1, k)]. + Otherwise: ldb >= max(1, k) [RowMajor: ldb >= max(1, n)]. + + \param[in] beta + beta is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 + On entry, beta specifies the scalar alpha.When beta is + supplied as zero then C need not be set on input. + + \param[in,out] C + C is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, dimension : + The n-by-n symmetric matrix C, + stored in an ldc-by-n array [RowMajor: n-by-ldc]. + On exit, the array C is overwritten by the lower/upper + triangular part of the updated matrix. + + \param[in] ldc + ldc is INTEGER + On entry, ldc specifies the first dimension of C + ldc >= max(1, n) + */ +template< typename T > +void syr2k( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, + int64_t n, int64_t k, + T alpha, + T const *A, int64_t lda, + T const *B, int64_t ldb, + T beta, + T *C, int64_t ldc ) +{ + cblas_syr2k( layout, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); +} + +/*! \brief Solve the Hermitian rank k operations for arbitrary data types + + \b Purpose: + + HERK performs one of the hermitian rank k operations for arbitrary data types + Data precisions supported include SINGLE PRECISION COMPLEX, + DOUBLE PRECISION COMPLEX(COMPLEX*16) + + C := alpha*A*B**H + conjg( alpha )*B*A**H + beta*C, + + or + + C := alpha*A**H*B + conjg( alpha )*B**H*A + beta*C, + + where alpha and beta are real scalars, C is an n by n hermitian + matrix and A is an n by k matrix in the first case and + k by n matrix in the second case. + + \param[in] layout + layout is enum CBLAS_LAYOUT + layout specifies Matrix storage as follows: + layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies specifies whether the upper or lower + triangular part of the array C is to be referenced + as follows: + uplo = CBLAS_UPLO::CblasUpper Only the upper triangular part of C + is to be referenced. + uplo = CBLAS_UPLO::CblasLower Only the lower triangular part of C + is to be referenced. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the operation to be used as follows: + trans = CBLAS_TRANSPOSE::CblasNoTrans, C := alpha*A*A**H + beta*C. + trans = CBLAS_TRANSPOSE::CblasConjTrans,C := alpha*A**H*A + beta*C. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix C. n must be + at least zero. + + \param[in] k + k is INTEGER + If trans = CblasNoTrans: k is number of columns of the matrix A. + Otherwise: k is number of rows of the matrix A. + k must be at least zero. + + \param[in] alpha + alpha is REAL/DOUBLE PRECISION + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is COMPLEX/COMPLEX*16 array,dimension : + If trans = CblasNoTrans: + n-by-k , stored in an lda-by-k array [RowMajor: n-by-lda]. + Otherwise: + k-by-n , stored in an lda-by-n array [RowMajor: k-by-lda]. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + If trans = CblasNoTrans: lda >= max(1, n) [RowMajor: lda >= max(1, k)]. + Otherwise: lda >= max(1, k) [RowMajor: lda >= max(1, n)]. + + \param[in] beta + beta is REAL/DOUBLE PRECISION + On entry, beta specifies the scalar alpha.When beta is + supplied as zero then C need not be set on input. + + \param[in,out] C + C is COMPLEX/COMPLEX*16 array, dimension : + The n-by-n Hermitian matrix C, + stored in an ldc-by-n array [RowMajor: n-by-ldc]. + On exit, the array C is overwritten by the lower/upper + triangular part of the updated matrix. + + \param[in] ldc + ldc is INTEGER + On entry, ldc specifies the first dimension of C + ldc >= max(1, n) + */ +template< typename T > +void herk( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, + int64_t n, int64_t k, + real_type alpha, + T const *A, int64_t lda, + real_type beta, + T *C, int64_t ldc ) +{ + cblas_herk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); +} + +/*! \brief Solve the Hermitian rank 2k operations for arbitrary data types + + \b Purpose: + + HER2K performs one of the hermitian rank 2k operations for arbitrary data types + Data precisions supported include SINGLE PRECISION COMPLEX, + DOUBLE PRECISION COMPLEX(COMPLEX*16) + + C := alpha*A*B**H + conjg( alpha )*B*A**H + beta*C, + + or + + C := alpha*A**H*B + conjg( alpha )*B**H*A + beta*C, + + where alpha and beta are scalars with beta real, C is an n by n + hermitian matrix and A and B are n by k matrices in the first case + and k by n matrices in the second case. + + \param[in] layout + layout is enum CBLAS_LAYOUT + layout specifies Matrix storage as follows: + layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. + + \param[in] uplo + uplo is enum CBLAS_UPLO + uplo specifies specifies whether the upper or lower + triangular part of the array C is to be referenced + as follows: + uplo = CBLAS_UPLO::CblasUpper Only the upper triangular part of C + is to be referenced. + uplo = CBLAS_UPLO::CblasLower Only the lower triangular part of C + is to be referenced. + + \param[in] trans + trans is CBLAS_TRANSPOSE + On entry, trans specifies the operation to be used as follows: + trans = CBLAS_TRANSPOSE::CblasNoTrans, C := alpha*A*B**H + conjg( alpha )*B*A**H + beta*C. + trans = CBLAS_TRANSPOSE::CblasConjTrans,C := alpha*A**H*B + conjg( alpha )*B**H*A + beta*C. + + \param[in] n + n is INTEGER + On entry, n specifies the order of the matrix C. n must be + at least zero. + + \param[in] k + k is INTEGER + If trans = CblasNoTrans: k is number of columns of the matrices A & B. + Otherwise: k is number of rows of the matrices A & B. + k must be at least zero. + + \param[in] alpha + alpha is COMPLEX/COMPLEX*16 + On entry, alpha specifies the scalar alpha. + + \param[in] A + A is COMPLEX/COMPLEX*16 array,dimension : + If trans = CblasNoTrans: + n-by-k , stored in an lda-by-k array [RowMajor: n-by-lda]. + Otherwise: + k-by-n , stored in an lda-by-n array [RowMajor: k-by-lda]. + + \param[in] lda + lda is INTEGER + On entry, lda specifies the Leading dimension of A + If trans = CblasNoTrans: lda >= max(1, n) [RowMajor: lda >= max(1, k)]. + Otherwise: lda >= max(1, k) [RowMajor: lda >= max(1, n)]. + + \param[in] B + B is COMPLEX/COMPLEX*16 array,dimension : + If trans = CblasNoTrans: + n-by-k , stored in an ldb-by-k array [RowMajor: n-by-ldb]. + Otherwise: + k-by-n , stored in an ldb-by-n array [RowMajor: k-by-ldb] + + \param[in] ldb + ldb is INTEGER + On entry, ldb specifies the Leading dimension of B + If trans = CblasNoTrans: ldb >= max(1, n) [RowMajor: ldb >= max(1, k)]. + Otherwise: ldb >= max(1, k) [RowMajor: ldb >= max(1, n)]. + + \param[in] beta + beta is REAL/DOUBLE PRECISION + On entry, beta specifies the scalar alpha.When beta is + supplied as zero then C need not be set on input. + + \param[in,out] C + C is COMPLEX/COMPLEX*16 array, dimension : + The n-by-n Hermitian matrix C, + stored in an ldc-by-n array [RowMajor: n-by-ldc]. + On exit, the array C is overwritten by the lower/upper + triangular part of the updated matrix. + + \param[in] ldc + ldc is INTEGER + On entry, ldc specifies the first dimension of C + ldc >= max(1, n) + */ +template< typename T > +void her2k( + CBLAS_ORDER layout, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, + int64_t n, int64_t k, + T alpha, + T const *A, int64_t lda, + T const *B, int64_t ldb, + real_type beta, + T *C, int64_t ldc ) +{ + cblas_her2k( layout, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); +} + +} // namespace blis +#endif // #ifndef BLIS_HH diff --git a/vendor/cpp/cblas.hh b/vendor/cpp/cblas.hh new file mode 100644 index 0000000000..b656ed28e1 --- /dev/null +++ b/vendor/cpp/cblas.hh @@ -0,0 +1,1705 @@ +/****************************************************************************** +* Copyright (c) 2019 - present Advanced Micro Devices, Inc. All rights reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +* THE SOFTWARE. +*******************************************************************************/ + +/*! @file cblas.hh + * cblas.hh defines all the overloaded CPP functions to be invoked from + * template interfaces + * */ +#ifndef CBLAS_HH +#define CBLAS_HH + +extern "C" { +#include +} + +#include + +namespace blis{ + +template< typename... Types > struct real_type_traits; + +//define real_type<> type alias +template< typename... Types > +using real_type = typename real_type_traits< Types... >::real_t; + +// for one type +template< typename T > +struct real_type_traits +{ + using real_t = T; +}; + +// for one complex type, strip complex +template< typename T > +struct real_type_traits< std::complex > +{ + using real_t = T; +}; + +// ============================================================================= +// Level 1 BLAS +// ----------------------------------------------------------------------------- +inline void +cblas_rotg( + float *a, float *b, + float *c, float *s ) +{ + cblas_srotg( a, b, c, s ); +} + +inline void +cblas_rotg( + double *a, double *b, + double *c, double *s ) +{ + cblas_drotg( a, b, c, s ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_rotmg( + float *d1, float *d2, float *x1, float y1, float param[5] ) +{ + cblas_srotmg( d1, d2, x1, y1, param ); +} + +inline void +cblas_rotmg( + double *d1, double *d2, double *x1, double y1, double param[5] ) +{ + cblas_drotmg( d1, d2, x1, y1, param ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_rot( + int n, + float *x, int incx, + float *y, int incy, + float c, float s ) +{ + cblas_srot( n, x, incx, y, incy, c, s ); +} + +inline void +cblas_rot( + int n, + double *x, int incx, + double *y, int incy, + double c, double s ) +{ + cblas_drot( n, x, incx, y, incy, c, s ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_rotm( + int n, + float *x, int incx, + float *y, int incy, + const float p[5] ) +{ + cblas_srotm( n, x, incx, y, incy, p ); +} + +inline void +cblas_rotm( + int n, + double *x, int incx, + double *y, int incy, + const double p[5] ) +{ + cblas_drotm( n, x, incx, y, incy, p ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_swap( + int n, + float* x, int incx, + float* y, int incy ) +{ + cblas_sswap( n, x, incx, y, incy ); +} + +inline void +cblas_swap( + int n, + double* x, int incx, + double* y, int incy ) +{ + cblas_dswap( n, x, incx, y, incy ); +} + +inline void +cblas_swap( + int n, + std::complex* x, int incx, + std::complex* y, int incy ) +{ + cblas_cswap( n, x, incx, y, incy ); +} + +inline void +cblas_swap( + int n, + std::complex* x, int incx, + std::complex* y, int incy ) +{ + cblas_zswap( n, x, incx, y, incy ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_scal( + int n, float alpha, + float* x, int incx ) +{ + cblas_sscal( n, alpha, x, incx ); +} + +inline void +cblas_scal( + int n, double alpha, + double* x, int incx ) +{ + cblas_dscal( n, alpha, x, incx ); +} + +inline void +cblas_scal( + int n, std::complex alpha, + std::complex* x, int incx ) +{ + cblas_cscal( n, &alpha, x, incx ); +} + +inline void +cblas_scal( + int n, std::complex alpha, + std::complex* x, int incx ) +{ + cblas_zscal( n, &alpha, x, incx ); +} + +inline void +cblas_scal( + int n, float alpha, + std::complex* x, int incx ) +{ + cblas_csscal( n, alpha, x, incx ); +} + +inline void +cblas_scal( + int n, double alpha, + std::complex* x, int incx ) +{ + cblas_zdscal( n, alpha, x, incx ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_copy( + int n, + float const *x, int incx, + float* y, int incy ) +{ + cblas_scopy( n, x, incx, y, incy ); +} + +inline void +cblas_copy( + int n, + double const *x, int incx, + double* y, int incy ) +{ + cblas_dcopy( n, x, incx, y, incy ); +} + +inline void +cblas_copy( + int n, + std::complex const *x, int incx, + std::complex* y, int incy ) +{ + cblas_ccopy( n, x, incx, y, incy ); +} + +inline void +cblas_copy( + int n, + std::complex const *x, int incx, + std::complex* y, int incy ) +{ + cblas_zcopy( n, x, incx, y, incy ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_axpy( + int n, float alpha, + float const *x, int incx, + float* y, int incy ) +{ + cblas_saxpy( n, alpha, x, incx, y, incy ); +} + +inline void +cblas_axpy( + int n, double alpha, + double const *x, int incx, + double* y, int incy ) +{ + cblas_daxpy( n, alpha, x, incx, y, incy ); +} + +inline void +cblas_axpy( + int n, std::complex alpha, + std::complex const *x, int incx, + std::complex* y, int incy ) +{ + cblas_caxpy( n, &alpha, x, incx, y, incy ); +} + +inline void +cblas_axpy( + int n, std::complex alpha, + std::complex const *x, int incx, + std::complex* y, int incy ) +{ + cblas_zaxpy( n, &alpha, x, incx, y, incy ); +} + +// ----------------------------------------------------------------------------- +inline float +cblas_dot( + int n, + float const *x, int incx, + float const *y, int incy ) +{ + return cblas_sdot( n, x, incx, y, incy ); +} + +inline double +cblas_dot( + int n, + double const *x, int incx, + double const *y, int incy ) +{ + return cblas_ddot( n, x, incx, y, incy ); +} +// ----------------------------------------------------------------------------- +inline std::complex +cblas_dotu( + int n, + std::complex const *x, int incx, + std::complex const *y, int incy ) +{ + std::complex result; + cblas_cdotu_sub( n, x, incx, y, incy, &result ); + return result; +} + +inline std::complex +cblas_dotu( + int n, + std::complex const *x, int incx, + std::complex const *y, int incy ) +{ + std::complex result; + cblas_zdotu_sub( n, x, incx, y, incy, &result ); + return result; +} + +// ----------------------------------------------------------------------------- +inline std::complex +cblas_dotc( + int n, + std::complex const *x, int incx, + std::complex const *y, int incy ) +{ + std::complex result; + cblas_cdotc_sub( n, x, incx, y, incy, &result ); + return result; +} + +inline std::complex +cblas_dotc( + int n, + std::complex const *x, int incx, + std::complex const *y, int incy ) +{ + std::complex result; + cblas_zdotc_sub( n, x, incx, y, incy, &result ); + return result; +} + +// ----------------------------------------------------------------------------- +inline int +cblas_iamax( + int n, float const *x, int incx ) +{ + return cblas_isamax( n, x, incx ); +} + +inline int +cblas_iamax( + int n, double const *x, int incx ) +{ + return cblas_idamax( n, x, incx ); +} + +inline int +cblas_iamax( + int n, std::complex const *x, int incx ) +{ + return cblas_icamax( n, x, incx ); +} + +inline int +cblas_iamax( + int n, std::complex const *x, int incx ) +{ + return cblas_izamax( n, x, incx ); +} + + +// ----------------------------------------------------------------------------- +inline float +cblas_nrm2( + int n, float const *x, int incx ) +{ + return cblas_snrm2( n, x, incx ); +} + +inline double +cblas_nrm2( + int n, double const *x, int incx ) +{ + return cblas_dnrm2( n, x, incx ); +} + +inline float +cblas_nrm2( + int n, std::complex const *x, int incx ) +{ + return cblas_scnrm2( n, x, incx ); +} + +inline double +cblas_nrm2( + int n, std::complex const *x, int incx ) +{ + return cblas_dznrm2( n, x, incx ); +} + +// ----------------------------------------------------------------------------- +inline float +cblas_asum( + int n, float const *x, int incx ) +{ + return cblas_sasum( n, x, incx ); +} + +inline double +cblas_asum( + int n, double const *x, int incx ) +{ + return cblas_dasum( n, x, incx ); +} + +inline float +cblas_asum( + int n, std::complex const *x, int incx ) +{ + return cblas_scasum( n, x, incx ); +} + +inline double +cblas_asum( + int n, std::complex const *x, int incx ) +{ + return cblas_dzasum( n, x, incx ); +} +// ============================================================================= +// Level 2 BLAS + +// ----------------------------------------------------------------------------- +inline void +cblas_gemv( + CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, int m, int n, + float alpha, + float const *A, int lda, + float const *x, int incx, + float beta, + float* y, int incy ) +{ + cblas_sgemv( layout, trans, m, n, + alpha, A, lda, x, incx, beta, y, incy ); +} + +inline void +cblas_gemv( + CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, int m, int n, + double alpha, + double const *A, int lda, + double const *x, int incx, + double beta, + double* y, int incy ) +{ + cblas_dgemv( layout, trans, m, n, + alpha, A, lda, x, incx, beta, y, incy ); +} + +inline void +cblas_gemv( + CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, int m, int n, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *x, int incx, + std::complex beta, + std::complex* y, int incy ) +{ + cblas_cgemv( layout, trans, m, n, + &alpha, A, lda, x, incx, + &beta, y, incy ); +} + +inline void +cblas_gemv( + CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, int m, int n, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *x, int incx, + std::complex beta, + std::complex* y, int incy ) +{ + cblas_zgemv( layout, trans, m, n, + &alpha, A, lda, x, incx, + &beta, y, incy ); +} +inline void +cblas_gbmv( + CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, + int m, int n, int kl, int ku, + float alpha, + float const *A, int lda, + float const *x, int incx, + float beta, + float* y, int incy ) +{ + cblas_sgbmv( layout, trans, m, n, kl, ku, + alpha, A, lda, x, incx, beta, y, incy ); +} + +inline void +cblas_gbmv( + CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, + int m, int n, int kl, int ku, + double alpha, + double const *A, int lda, + double const *x, int incx, + double beta, + double* y, int incy ) +{ + cblas_dgbmv( layout, trans, m, n, kl, ku, + alpha, A, lda, x, incx, beta, y, incy ); +} + +inline void +cblas_gbmv( + CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, + int m, int n, int kl, int ku, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *x, int incx, + std::complex beta, + std::complex* y, int incy ) +{ + cblas_cgbmv( layout, trans, m, n, kl, ku, + &alpha, A, lda, x, incx, + &beta, y, incy ); +} + +inline void +cblas_gbmv( + CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, + int m, int n, int kl, int ku, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *x, int incx, + std::complex beta, + std::complex* y, int incy ) +{ + cblas_zgbmv( layout, trans, m, n, kl, ku, + &alpha, A, lda, x, incx, + &beta, y, incy ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_hemv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *x, int incx, + std::complex beta, + std::complex* y, int incy ) +{ + cblas_chemv( layout, uplo, n, + &alpha, A, lda, x, incx, + &beta, y, incy ); +} + +inline void +cblas_hemv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *x, int incx, + std::complex beta, + std::complex* y, int incy ) +{ + cblas_zhemv( layout, uplo, n, + &alpha, A, lda, x, incx, + &beta, y, incy ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_hbmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, int k, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *x, int incx, + std::complex beta, + std::complex* y, int incy ) +{ + cblas_chbmv( layout, uplo, n, k, + &alpha, A, lda, x, incx, + &beta, y, incy ); +} + +inline void +cblas_hbmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, int k, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *x, int incx, + std::complex beta, + std::complex* y, int incy ) +{ + cblas_zhbmv( layout, uplo, n, k, + &alpha, A, lda, x, incx, + &beta, y, incy ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_hpmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + std::complex alpha, + std::complex const *Ap, + std::complex const *x, int incx, + std::complex beta, + std::complex* y, int incy ) +{ + cblas_chpmv( layout, uplo, n, + &alpha, Ap, x, incx, + &beta, y, incy ); +} + +inline void +cblas_hpmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + std::complex alpha, + std::complex const *Ap, + std::complex const *x, int incx, + std::complex beta, + std::complex* y, int incy ) +{ + cblas_zhpmv( layout, uplo, n, + &alpha, Ap, x, incx, + &beta, y, incy ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_symv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + float alpha, + float const *A, int lda, + float const *x, int incx, + float beta, + float* y, int incy ) +{ + cblas_ssymv( layout, uplo, n, + alpha, A, lda, x, incx, beta, y, incy ); +} + +inline void +cblas_symv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + double alpha, + double const *A, int lda, + double const *x, int incx, + double beta, + double* y, int incy ) +{ + cblas_dsymv( layout, uplo, n, + alpha, A, lda, x, incx, beta, y, incy ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_sbmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, int k, + float alpha, + float const *A, int lda, + float const *x, int incx, + float beta, + float* y, int incy ) +{ + cblas_ssbmv( layout, uplo, n, k, + alpha, A, lda, x, incx, beta, y, incy ); +} + +inline void +cblas_sbmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, int k, + double alpha, + double const *A, int lda, + double const *x, int incx, + double beta, + double* y, int incy ) +{ + cblas_dsbmv( layout, uplo, n, k, + alpha, A, lda, x, incx, beta, y, incy ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_spmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + float alpha, + float const *Ap, + float const *x, int incx, + float beta, + float* y, int incy ) +{ + cblas_sspmv( layout, uplo, n, + alpha, Ap, x, incx, beta, y, incy ); +} + +inline void +cblas_spmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + double alpha, + double const *Ap, + double const *x, int incx, + double beta, + double* y, int incy ) +{ + cblas_dspmv( layout, uplo, n, + alpha, Ap, x, incx, beta, y, incy ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_trmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + float const *A, int lda, + float* x, int incx ) +{ + cblas_strmv( layout, uplo, trans, diag, n, + A, lda, x, incx ); +} + +inline void +cblas_trmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + double const *A, int lda, + double* x, int incx ) +{ + cblas_dtrmv( layout, uplo, trans, diag, n, + A, lda, x, incx ); +} + +inline void +cblas_trmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + std::complex const *A, int lda, + std::complex* x, int incx ) +{ + cblas_ctrmv( layout, uplo, trans, diag, n, + A, lda, x, incx ); +} + +inline void +cblas_trmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + std::complex const *A, int lda, + std::complex* x, int incx ) +{ + cblas_ztrmv( layout, uplo, trans, diag, n, + A, lda, x, incx ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_tbmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int n, int k, + float const *A, int lda, + float* x, int incx ) +{ + cblas_stbmv( layout, uplo, trans, diag, n, k, + A, lda, x, incx ); +} + +inline void +cblas_tbmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int n, int k, + double const *A, int lda, + double* x, int incx ) +{ + cblas_dtbmv( layout, uplo, trans, diag, n, k, + A, lda, x, incx ); +} + +inline void +cblas_tbmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int n, int k, + std::complex const *A, int lda, + std::complex* x, int incx ) +{ + cblas_ctbmv( layout, uplo, trans, diag, n, k, + A, lda, x, incx ); +} + +inline void +cblas_tbmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int n, int k, + std::complex const *A, int lda, + std::complex* x, int incx ) +{ + cblas_ztbmv( layout, uplo, trans, diag, n, k, + A, lda, x, incx ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_tpmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + float const *Ap, + float* x, int incx ) +{ + cblas_stpmv( layout, uplo, trans, diag, n, + Ap, x, incx ); +} + +inline void +cblas_tpmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + double const *Ap, + double* x, int incx ) +{ + cblas_dtpmv( layout, uplo, trans, diag, n, + Ap, x, incx ); +} + +inline void +cblas_tpmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + std::complex const *Ap, + std::complex* x, int incx ) +{ + cblas_ctpmv( layout, uplo, trans, diag, n, + Ap, x, incx ); +} + +inline void +cblas_tpmv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + std::complex const *Ap, + std::complex* x, int incx ) +{ + cblas_ztpmv( layout, uplo, trans, diag, n, + Ap, x, incx ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_trsv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + float const *A, int lda, + float* x, int incx ) +{ + cblas_strsv( layout, uplo, trans, diag, n, + A, lda, x, incx ); +} + +inline void +cblas_trsv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + double const *A, int lda, + double* x, int incx ) +{ + cblas_dtrsv( layout, uplo, trans, diag, n, + A, lda, x, incx ); +} + +inline void +cblas_trsv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + std::complex const *A, int lda, + std::complex* x, int incx ) +{ + cblas_ctrsv( layout, uplo, trans, diag, n, + A, lda, x, incx ); +} + +inline void +cblas_trsv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + std::complex const *A, int lda, + std::complex* x, int incx ) +{ + cblas_ztrsv( layout, uplo, trans, diag, n, + A, lda, x, incx ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_tbsv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int n, int k, + float const *A, int lda, + float* x, int incx ) +{ + cblas_stbsv( layout, uplo, trans, diag, n, k, + A, lda, x, incx ); +} + +inline void +cblas_tbsv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int n, int k, + double const *A, int lda, + double* x, int incx ) +{ + cblas_dtbsv( layout, uplo, trans, diag, n, k, + A, lda, x, incx ); +} + +inline void +cblas_tbsv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int n, int k, + std::complex const *A, int lda, + std::complex* x, int incx ) +{ + cblas_ctbsv( layout, uplo, trans, diag, n, k, + A, lda, x, incx ); +} + +inline void +cblas_tbsv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int n, int k, + std::complex const *A, int lda, + std::complex* x, int incx ) +{ + cblas_ztbsv( layout, uplo, trans, diag, n, k, + A, lda, x, incx ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_tpsv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + float const *Ap, + float* x, int incx ) +{ + cblas_stpsv( layout, uplo, trans, diag, n, + Ap, x, incx ); +} + +inline void +cblas_tpsv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + double const *Ap, + double* x, int incx ) +{ + cblas_dtpsv( layout, uplo, trans, diag, n, + Ap, x, incx ); +} + +inline void +cblas_tpsv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + std::complex const *Ap, + std::complex* x, int incx ) +{ + cblas_ctpsv( layout, uplo, trans, diag, n, + Ap, x, incx ); +} + +inline void +cblas_tpsv( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, + std::complex const *Ap, + std::complex* x, int incx ) +{ + cblas_ztpsv( layout, uplo, trans, diag, n, + Ap, x, incx ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_ger( + CBLAS_ORDER layout, int m, int n, + float alpha, + float const *x, int incx, + float const *y, int incy, + float* A, int lda ) +{ + cblas_sger( layout, m, n, alpha, x, incx, y, incy, A, lda ); +} + +inline void +cblas_ger( + CBLAS_ORDER layout, int m, int n, + double alpha, + double const *x, int incx, + double const *y, int incy, + double* A, int lda ) +{ + cblas_dger( layout, m, n, alpha, x, incx, y, incy, A, lda ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_geru( + CBLAS_ORDER layout, int m, int n, + std::complex alpha, + std::complex const *x, int incx, + std::complex const *y, int incy, + std::complex* A, int lda ) +{ + cblas_cgeru( layout, m, n, &alpha, + x, incx, y, incy, A, lda ); +} + +inline void +cblas_geru( + CBLAS_ORDER layout, int m, int n, + std::complex alpha, + std::complex const *x, int incx, + std::complex const *y, int incy, + std::complex* A, int lda ) +{ + cblas_zgeru( layout, m, n, &alpha, + x, incx, y, incy, A, lda ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_gerc( + CBLAS_ORDER layout, int m, int n, + std::complex alpha, + std::complex const *x, int incx, + std::complex const *y, int incy, + std::complex* A, int lda ) +{ + cblas_cgerc( layout, m, n, &alpha, + x, incx, y, incy, A, lda ); +} + +inline void +cblas_gerc( + CBLAS_ORDER layout, int m, int n, + std::complex alpha, + std::complex const *x, int incx, + std::complex const *y, int incy, + std::complex* A, int lda ) +{ + cblas_zgerc( layout, m, n, &alpha, + x, incx, y, incy, A, lda ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_her( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + float alpha, + std::complex const *x, int incx, + std::complex* A, int lda ) +{ + cblas_cher( layout, uplo, n, alpha, x, incx, A, lda ); +} + +inline void +cblas_her( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + double alpha, + std::complex const *x, int incx, + std::complex* A, int lda ) +{ + cblas_zher( layout, uplo, n, alpha, x, incx, A, lda ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_hpr( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + float alpha, + std::complex const *x, int incx, + std::complex* Ap ) +{ + cblas_chpr( layout, uplo, n, alpha, x, incx, Ap ); +} + +inline void +cblas_hpr( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + double alpha, + std::complex const *x, int incx, + std::complex* Ap ) +{ + cblas_zhpr( layout, uplo, n, alpha, x, incx, Ap ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_her2( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + std::complex alpha, + std::complex const *x, int incx, + std::complex const *y, int incy, + std::complex* A, int lda ) +{ + cblas_cher2( layout, uplo, n, &alpha, x, incx, y, incy, A, lda ); +} + +inline void +cblas_her2( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + std::complex alpha, + std::complex const *x, int incx, + std::complex const *y, int incy, + std::complex* A, int lda ) +{ + cblas_zher2( layout, uplo, n, &alpha, x, incx, y, incy, A, lda ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_hpr2( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + std::complex alpha, + std::complex const *x, int incx, + std::complex const *y, int incy, + std::complex* Ap ) +{ + cblas_chpr2( layout, uplo, n, &alpha, x, incx, y, incy, Ap ); +} + +inline void +cblas_hpr2( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + std::complex alpha, + std::complex const *x, int incx, + std::complex const *y, int incy, + std::complex* Ap ) +{ + cblas_zhpr2( layout, uplo, n, &alpha, x, incx, y, incy, Ap ); +} +// ----------------------------------------------------------------------------- +inline void +cblas_syr( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + float alpha, + float const *x, int incx, + float* A, int lda ) +{ + cblas_ssyr( layout, uplo, n, alpha, x, incx, A, lda ); +} + +inline void +cblas_syr( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + double alpha, + double const *x, int incx, + double* A, int lda ) +{ + cblas_dsyr( layout, uplo, n, alpha, x, incx, A, lda ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_spr( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + float alpha, + float const *x, int incx, + float* Ap ) +{ + cblas_sspr( layout, uplo, n, alpha, x, incx, Ap ); +} + +inline void +cblas_spr( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + double alpha, + double const *x, int incx, + double* Ap ) +{ + cblas_dspr( layout, uplo, n, alpha, x, incx, Ap ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_syr2( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + float alpha, + float const *x, int incx, + float const *y, int incy, + float* A, int lda ) +{ + cblas_ssyr2( layout, uplo, n, alpha, x, incx, y, incy, A, lda ); +} + +inline void +cblas_syr2( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + double alpha, + double const *x, int incx, + double const *y, int incy, + double* A, int lda ) +{ + cblas_dsyr2( layout, uplo, n, alpha, x, incx, y, incy, A, lda ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_spr2( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + float alpha, + float const *x, int incx, + float const *y, int incy, + float* Ap ) +{ + cblas_sspr2( layout, uplo, n, alpha, x, incx, y, incy, Ap ); +} + +inline void +cblas_spr2( + CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, + double alpha, + double const *x, int incx, + double const *y, int incy, + double* Ap ) +{ + cblas_dspr2( layout, uplo, n, alpha, x, incx, y, incy, Ap ); +} + +// ============================================================================= +// Level 3 BLAS + +// ----------------------------------------------------------------------------- +inline void +cblas_gemm( + CBLAS_ORDER layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, + int m, int n, int k, + float alpha, + float const *A, int lda, + float const *B, int ldb, + float beta, + float* C, int ldc ) +{ + cblas_sgemm( layout, transA, transB, m, n, k, + alpha, A, lda, B, ldb, + beta, C, ldc ); +} + +inline void +cblas_gemm( + CBLAS_ORDER layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, + int m, int n, int k, + double alpha, + double const *A, int lda, + double const *B, int ldb, + double beta, + double* C, int ldc ) +{ + cblas_dgemm( layout, transA, transB, m, n, k, + alpha, A, lda, B, ldb, + beta, C, ldc ); +} + +inline void +cblas_gemm( + CBLAS_ORDER layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, + int m, int n, int k, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *B, int ldb, + std::complex beta, + std::complex* C, int ldc ) +{ + cblas_cgemm( layout, transA, transB, m, n, k, + &alpha, A, lda, B, ldb, + &beta, C, ldc ); +} + +inline void +cblas_gemm( + CBLAS_ORDER layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, + int m, int n, int k, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *B, int ldb, + std::complex beta, + std::complex* C, int ldc ) +{ + cblas_zgemm( layout, transA, transB, m, n, k, + &alpha, A, lda, B, ldb, + &beta, C, ldc ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_trmm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int m, int n, + float alpha, + float const *A, int lda, + float *B, int ldb ) +{ + cblas_strmm( layout, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +inline void +cblas_trmm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int m, int n, + double alpha, + double const *A, int lda, + double *B, int ldb ) +{ + cblas_dtrmm( layout, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +inline void +cblas_trmm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int m, int n, + std::complex alpha, + std::complex const *A, int lda, + std::complex *B, int ldb ) +{ + cblas_ctrmm( layout, side, uplo, trans, diag, m, n, &alpha, A, lda, B, ldb ); +} + +inline void +cblas_trmm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int m, int n, + std::complex alpha, + std::complex const *A, int lda, + std::complex *B, int ldb ) +{ + cblas_ztrmm( layout, side, uplo, trans, diag, m, n, &alpha, A, lda, B, ldb ); +} + + +// ----------------------------------------------------------------------------- +inline void +cblas_trsm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int m, int n, + float alpha, + float const *A, int lda, + float *B, int ldb ) +{ + cblas_strsm( layout, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +inline void +cblas_trsm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int m, int n, + double alpha, + double const *A, int lda, + double *B, int ldb ) +{ + cblas_dtrsm( layout, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +inline void +cblas_trsm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int m, int n, + std::complex alpha, + std::complex const *A, int lda, + std::complex *B, int ldb ) +{ + cblas_ctrsm( layout, side, uplo, trans, diag, m, n, &alpha, A, lda, B, ldb ); +} + +inline void +cblas_trsm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, + int m, int n, + std::complex alpha, + std::complex const *A, int lda, + std::complex *B, int ldb ) +{ + cblas_ztrsm( layout, side, uplo, trans, diag, m, n, &alpha, A, lda, B, ldb ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_hemm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + int m, int n, + float alpha, + float const *A, int lda, + float const *B, int ldb, + float beta, + float* C, int ldc ) +{ + cblas_ssymm( layout, side, uplo, m, n, + alpha, A, lda, B, ldb, + beta, C, ldc ); +} + +inline void +cblas_hemm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + int m, int n, + double alpha, + double const *A, int lda, + double const *B, int ldb, + double beta, + double* C, int ldc ) +{ + cblas_dsymm( layout, side, uplo, m, n, + alpha, A, lda, B, ldb, + beta, C, ldc ); +} + +inline void +cblas_hemm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + int m, int n, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *B, int ldb, + std::complex beta, + std::complex* C, int ldc ) +{ + cblas_chemm( layout, side, uplo, m, n, + &alpha, A, lda, B, ldb, + &beta, C, ldc ); +} + +inline void +cblas_hemm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + int m, int n, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *B, int ldb, + std::complex beta, + std::complex* C, int ldc ) +{ + cblas_zhemm( layout, side, uplo, m, n, + &alpha, A, lda, B, ldb, + &beta, C, ldc ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_symm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + int m, int n, + float alpha, + float const *A, int lda, + float const *B, int ldb, + float beta, + float* C, int ldc ) +{ + cblas_ssymm( layout, side, uplo, m, n, + alpha, A, lda, B, ldb, + beta, C, ldc ); +} + +inline void +cblas_symm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + int m, int n, + double alpha, + double const *A, int lda, + double const *B, int ldb, + double beta, + double* C, int ldc ) +{ + cblas_dsymm( layout, side, uplo, m, n, + alpha, A, lda, B, ldb, + beta, C, ldc ); +} + +inline void +cblas_symm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + int m, int n, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *B, int ldb, + std::complex beta, + std::complex* C, int ldc ) +{ + cblas_csymm( layout, side, uplo, m, n, + &alpha, A, lda, B, ldb, + &beta, C, ldc ); +} + +inline void +cblas_symm( + CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + int m, int n, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *B, int ldb, + std::complex beta, + std::complex* C, int ldc ) +{ + cblas_zsymm( layout, side, uplo, m, n, + &alpha, A, lda, B, ldb, + &beta, C, ldc ); +} + + +// ----------------------------------------------------------------------------- +inline void +cblas_syrk( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + float alpha, + float const *A, int lda, + float beta, + float* C, int ldc ) +{ + cblas_ssyrk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); +} + +inline void +cblas_syrk( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + double alpha, + double const *A, int lda, + double beta, + double* C, int ldc ) +{ + cblas_dsyrk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); +} + +inline void +cblas_syrk( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + std::complex alpha, + std::complex const *A, int lda, + std::complex beta, + std::complex* C, int ldc ) +{ + cblas_csyrk( layout, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc ); +} + +inline void +cblas_syrk( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + std::complex alpha, + std::complex const *A, int lda, + std::complex beta, + std::complex* C, int ldc ) +{ + cblas_zsyrk( layout, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_herk( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + float alpha, + float const *A, int lda, + float beta, + float* C, int ldc ) +{ + cblas_ssyrk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); +} + +inline void +cblas_herk( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + double alpha, + double const *A, int lda, + double beta, + double* C, int ldc ) +{ + cblas_dsyrk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); +} + +inline void +cblas_herk( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + float alpha, // note: real + std::complex const *A, int lda, + float beta, // note: real + std::complex* C, int ldc ) +{ + cblas_cherk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); +} + +inline void +cblas_herk( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + double alpha, // note: real + std::complex const *A, int lda, + double beta, // note: real + std::complex* C, int ldc ) +{ + cblas_zherk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_syr2k( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + float alpha, + float const *A, int lda, + float const *B, int ldb, + float beta, + float* C, int ldc ) +{ + cblas_ssyr2k( layout, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); +} + +inline void +cblas_syr2k( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + double alpha, + double const *A, int lda, + double const *B, int ldb, + double beta, + double* C, int ldc ) +{ + cblas_dsyr2k( layout, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); +} + +inline void +cblas_syr2k( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *B, int ldb, + std::complex beta, + std::complex* C, int ldc ) +{ + cblas_csyr2k( layout, uplo, trans, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc ); +} + +inline void +cblas_syr2k( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *B, int ldb, + std::complex beta, + std::complex* C, int ldc ) +{ + cblas_zsyr2k( layout, uplo, trans, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc ); +} + +// ----------------------------------------------------------------------------- +inline void +cblas_her2k( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + float alpha, + float const *A, int lda, + float const *B, int ldb, + float beta, + float* C, int ldc ) +{ + cblas_ssyr2k( layout, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); +} + +inline void +cblas_her2k( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + double alpha, + double const *A, int lda, + double const *B, int ldb, + double beta, + double* C, int ldc ) +{ + cblas_dsyr2k( layout, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); +} + +inline void +cblas_her2k( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *B, int ldb, + float beta, // note: real + std::complex* C, int ldc ) +{ + cblas_cher2k( layout, uplo, trans, n, k, &alpha, A, lda, B, ldb, beta, C, ldc ); +} + +inline void +cblas_her2k( + CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, + std::complex alpha, + std::complex const *A, int lda, + std::complex const *B, int ldb, + double beta, // note: real + std::complex* C, int ldc ) +{ + cblas_zher2k( layout, uplo, trans, n, k, &alpha, A, lda, B, ldb, beta, C, ldc ); +} +}//namespace blis + +#endif // #ifndef CBLAS_HH diff --git a/vendor/testcpp/Makefile b/vendor/testcpp/Makefile new file mode 100644 index 0000000000..01506c9966 --- /dev/null +++ b/vendor/testcpp/Makefile @@ -0,0 +1,208 @@ +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +# +# Makefile +# +# Field G. Van Zee +# +# Makefile for standalone BLIS test drivers. +# + +# +# --- Makefile PHONY target definitions ---------------------------------------- +# + +.PHONY: all \ + blis \ + clean cleanx + + + +# +# --- Determine makefile fragment location ------------------------------------- +# + +# Comments: +# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. +# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in +# the second case because CONFIG_NAME is not yet set. +ifneq ($(strip $(BLIS_INSTALL_PATH)),) +LIB_PATH := $(BLIS_INSTALL_PATH)/lib +INC_PATH := $(BLIS_INSTALL_PATH)/include/blis +SHARE_PATH := $(BLIS_INSTALL_PATH)/share/blis +else +DIST_PATH := ../.. +LIB_PATH = ../../lib/$(CONFIG_NAME) +INC_PATH = ../../include/$(CONFIG_NAME) +SHARE_PATH := ../.. +endif + + + +# +# --- Include common makefile definitions -------------------------------------- +# + +# Include the common makefile fragment. +-include $(SHARE_PATH)/common.mk + + + +# +# --- BLAS and LAPACK implementations ------------------------------------------ +# + +# BLIS library and header path. This is simply wherever it was installed. +#BLIS_LIB_PATH := $(INSTALL_PREFIX)/lib +#BLIS_INC_PATH := $(INSTALL_PREFIX)/include/blis + +# BLIS library. +#BLIS_LIB := $(BLIS_LIB_PATH)/libblis.a + +# BLAS library path(s). This is where the BLAS libraries reside. +BLAS_LIB_PATH := $(HOME)/flame/lib + + +# +# --- General build definitions ------------------------------------------------ +# + +TEST_SRC_PATH := . +CPP_SRC_PATH := ../cpp/ +TEST_OBJ_PATH := . + +# Gather all local object files. +TEST_OBJS := $(patsubst $(TEST_SRC_PATH)/%.c, \ + $(TEST_OBJ_PATH)/%.o, \ + $(wildcard $(TEST_SRC_PATH)/*.c)) + +# Override the value of CINCFLAGS so that the value of CFLAGS returned by +# get-user-cflags-for() is not cluttered up with include paths needed only +# while building BLIS. +CINCFLAGS := -I$(INC_PATH) + +CXX = g++ + +# Use the CFLAGS for the configuration family. +override CFLAGS += $(call get-sandbox-cxxflags-for,$(CONFIG_NAME)) + +# Add local header paths to CFLAGS +#CFLAGS = -O0 -g -Wall +#CFLAGS += -I$(INC_PATH) +override CFLAGS += -I$(TEST_SRC_PATH) +override CFLAGS += -I$(CPP_SRC_PATH) + +LINKER = $(CXX) + +# Locate the libblis library to which we will link. +LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) + + + +# +# --- Targets/rules ------------------------------------------------------------ +# + +# Complete list of possible targets when defining 'all': +# +# blis +# +all: blis + + +blis: test_asum_blis.x \ + test_axpy_blis.x \ + test_copy_blis.x \ + test_dot_blis.x \ + test_dotc_blis.x \ + test_gbmv_blis.x \ + test_gemm_blis.x \ + test_gemv_blis.x \ + test_ger_blis.x \ + test_gerc_blis.x \ + test_geru_blis.x \ + test_hemm_blis.x \ + test_hemv_blis.x \ + test_her2_blis.x \ + test_her_blis.x \ + test_herk_blis.x \ + test_hpr2_blis.x \ + test_hpr_blis.x \ + test_nrm2_blis.x \ + test_rot_blis.x \ + test_rotg_blis.x \ + test_rotm_blis.x \ + test_rotmg_blis.x \ + test_scal_blis.x \ + test_sdsdot_blis.x \ + test_spr2_blis.x \ + test_spr_blis.x \ + test_swap_blis.x \ + test_symm_blis.x \ + test_syr2_blis.x \ + test_syr2k_blis.x \ + test_syr_blis.x \ + test_syrk_blis.x \ + test_tbmv_blis.x \ + test_tbsv_blis.x \ + test_tpmv_blis.x \ + test_tpsv_blis.x \ + test_trmm_blis.x \ + test_trsm_blis.x \ + test_trsv_blis.x + + + +# --Object file rules -- + +$(TEST_OBJ_PATH)/%.o: $(TEST_SRC_PATH)/%.cc + $(CXX) $(CFLAGS) -c $< -o $@ + +test_%_blis.o: test_%.cc + @$(CXX) $(CFLAGS) -c $< -o $@ + + +# -- Executable file rules -- + +test_%_blis.x: test_%_blis.o $(LIBBLIS_LINK) + @$(LINKER) $^ $(LIBBLIS_LINK) $(LDFLAGS) -o $@ + ./$@ + +# -- Clean rules -- + +clean: cleanx + +cleanx: + - $(RM_F) *.o *.x + diff --git a/vendor/testcpp/test.hh b/vendor/testcpp/test.hh new file mode 100644 index 0000000000..b1be412d64 --- /dev/null +++ b/vendor/testcpp/test.hh @@ -0,0 +1,219 @@ +/* + * -------------------------------------------------------------------------- + * BLISLAB + * -------------------------------------------------------------------------- + * Copyright (C) 2016, The University of Texas at Austin + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * - Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * - Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * - Neither the name of The University of Texas nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * + * test.hh + * + * + * Purpose: + * this header file contains all function prototypes. + * + * Todo: + * + * + * Modification: + * + * + * */ + + +#ifndef TEST_HH +#define TEST_HH + +#include + +#include +#include + +using namespace std; +#define min( i, j ) ( (i)<(j) ? (i): (j) ) + +#define A( i, j ) A[ (j)*lda + (i) ] +#define A_ref( i, j ) A_ref[ (j)*lda_ref + (i) ] + +#define B( i, j ) B[ (j)*ldb + (i) ] +#define B_ref( i, j ) B_ref[ (j)*ldb_ref + (i) ] + +#define C( i, j ) C[ (j)*ldc + (i) ] +#define C_ref( i, j ) C_ref[ (j)*ldc_ref + (i) ] + +#define X( i ) X[ incx + (i) ] +#define X_ref( i, j ) X_ref[ (j)*incx_ref + (i) + +#define Y( i ) Y[ incy + (i) ] +#define Y_ref( i ) Y_ref[ incy_ref + (i) ]\ + +// Allocate memory and initialise memory with random values +void allocate_init_buffer(int *aIn, int m, int n) +{ + aIn = new int [m*n]; + for ( int i = 0; i < m*n; i ++ ) { + aIn[ i ] = ((int) rand() / ((int) RAND_MAX / 2.0)) - 1.0; + } +} + +void allocate_init_buffer(float *&aIn, int m, int n) +{ + aIn = new float [m*n]; + for ( int i = 0; i < m*n; i ++ ) { + aIn[ i ] = ((float) rand() / ((float) RAND_MAX / 2.0)) - 1.0; + } +} + +void allocate_init_buffer(double *&aIn, int m, int n) +{ + aIn = new double [m*n]; + for ( int i = 0; i < m*n; i ++ ) { + aIn[ i ] = ((double) rand() / ((double) RAND_MAX / 2.0)) - 1.0; + } +} +void allocate_init_buffer(complex *&aIn, int m, int n) +{ + aIn = new complex [m*n]; + for ( int i = 0; i < m*n; i ++ ) { + float real = ((float) rand() / ((float) RAND_MAX / 2.0)) - 1.0; + float imag = ((float) rand() / ((float) RAND_MAX / 2.0)) - 1.0; + aIn[i] = {real,imag}; + } +} +void allocate_init_buffer(complex *&aIn, int m, int n) +{ + aIn = new complex [m*n]; + for ( int i = 0; i < m*n; i ++ ) { + double real = ((double) rand() / ((double) RAND_MAX / 2.0)) - 1.0; + double imag = ((double) rand() / ((double) RAND_MAX / 2.0)) - 1.0; + aIn[i] = {real,imag}; + } +} + +template< typename T > +void copy_buffer(T *aSrc, T *&aDest, int m, int n) +{ + aDest = new T [m*n]; + for ( int i = 0; i < m*n; i ++ ) { + aDest[i] = aSrc[i]; + } +} + +template< typename T > +int computeErrorM( + int lda, + int lda_ref, + int m, + int n, + T *A, + T *A_ref + ) +{ + + int i, j; + int ret = 0; + for ( i = 0; i < m; i ++ ) { + for ( j = 0; j < n; j ++ ) { + if ( (fabs (A( i, j )) - fabs( A_ref( i, j ))) > 0.0000001 ) { + cout << A(i,j) << A_ref(i,j); + ret = 1; + break; + } + } + } + return ret; + +} + + + + template< typename T > + int computeErrorV( + int incy, + int incy_ref, + int n, + T *Y, + T *Y_ref + ) + { + int i; + int ret = 0; + for ( i = 0; i < n; i ++ ) { + if ( (fabs( Y_ref[ i ]) - fabs(Y[ i ] ) ) > 0.00001) { + cout << Y[i] << Y_ref[i]; + ret = 1; + break; + } + } + + return ret; + + } + +/* + *printing matix and vector + * + */ + +template +void printmatrix( + T *A, + int lda, + int m, + int n, + char *func_str + ) +{ + int i, j; + cout << func_str <<"\n"; + for ( i = 0; i < m; i ++ ) { + for ( j = 0; j < n; j ++ ) { + cout<< A[j * lda + i]<<" "; + } + printf("\n"); + } + printf("\n"); +} + +template +void printvector( + T *X, + int m, + char *func_str + ) + { + int i; + cout << func_str <<"\n"; + for ( i = 0; i < m; i ++ ) { + cout<< X[i]<<" "; + cout<<"\n"; + } + printf("\n"); + + } + + +#endif diff --git a/vendor/testcpp/test.sh b/vendor/testcpp/test.sh new file mode 100644 index 0000000000..6d06c867bb --- /dev/null +++ b/vendor/testcpp/test.sh @@ -0,0 +1,46 @@ + +echo Build BLIS CPP Template tests +make clean +make + +echo Run tests +./test_asum_blis.x +./test_axpy_blis.x +./test_copy_blis.x +./test_dot_blis.x +./test_dotc_blis.x +./test_gbmv_blis.x +./test_gemm_blis.x +./test_gemv_blis.x +./test_ger_blis.x +./test_gerc_blis.x +./test_geru_blis.x +./test_hemm_blis.x +./test_hemv_blis.x +./test_her2_blis.x +./test_her_blis.x +./test_herk_blis.x +./test_hpr2_blis.x +./test_hpr_blis.x +./test_nrm2_blis.x +./test_rot_blis.x +./test_rotg_blis.x +./test_rotm_blis.x +./test_rotmg_blis.x +./test_scal_blis.x +./test_sdsdot_blis.x +./test_spr2_blis.x +./test_spr_blis.x +./test_swap_blis.x +./test_symm_blis.x +./test_syr2_blis.x +./test_syr2k_blis.x +./test_syr_blis.x +./test_syrk_blis.x +./test_tbmv_blis.x +./test_tbsv_blis.x +./test_tpmv_blis.x +./test_tpsv_blis.x +./test_trmm_blis.x +./test_trsm_blis.x +./test_trsv_blis.x diff --git a/vendor/testcpp/test_asum.cc b/vendor/testcpp/test_asum.cc new file mode 100644 index 0000000000..948f4250fd --- /dev/null +++ b/vendor/testcpp/test_asum.cc @@ -0,0 +1,127 @@ +/* + + BLISPP + C++ test driver for BLIS CPP asum routine and reference blis asum routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define N 6 +#define ALPHA 0.5 + +template< typename T, typename TR> +void ref_asum(int64_t n, + T *X, + TR *asum + ) +{ + obj_t obj_x; + obj_t obj_asum; + num_t dt, dtR; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + if(is_same::value) + dtR = BLIS_FLOAT; + else if(is_same::value) + dtR = BLIS_DOUBLE; + + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); + bli_obj_create_with_attached_buffer( dtR, 1, 1, asum, 1, 1,&obj_asum ); + + bli_asumv(&obj_x, &obj_asum); + +} +template< typename T, typename TR> +void test_asum() +{ + + T *X, *X_ref; + TR asum, asum_ref; + int n; + int incx; + + n = N; + incx = 1; + srand (time(NULL)); + allocate_init_buffer(X , n , 1); + copy_buffer(X, X_ref , n ,1); + +#ifdef PRINT + printvector(X, n,(char *) "X"); +#endif + + asum = blis::asum( + n, + X, + incx + ); + +#ifdef PRINT + cout<< "Sum of all values in Vector X: " << asum << "\n"; +#endif + + ref_asum(n, X_ref, &asum_ref ); + +#ifdef PRINT + cout<< "Ref Sum of all values in Vector X: " << asum_ref << "\n"; +#endif + if(computeErrorV(incx, incx, 1, &asum, &asum_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + delete[]( X ); + delete[]( X_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_asum( ); + test_asum( ); + test_asum, float>( ); + test_asum, double>( ); + return 0; + +} diff --git a/vendor/testcpp/test_axpy.cc b/vendor/testcpp/test_axpy.cc new file mode 100644 index 0000000000..45035198c3 --- /dev/null +++ b/vendor/testcpp/test_axpy.cc @@ -0,0 +1,138 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define N 6 +#define ALPHA 1.0 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T> +void ref_axpy(int64_t n, + T * alpha, + T *X, + T *Y + ) + +{ + obj_t obj_x, obj_y, obj_alpha; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + + bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); + bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); + + bli_axpyv( &obj_alpha, + &obj_x, + &obj_y + ); + +} +template< typename T > +void test_axpy( ) +{ + T *X, *Y,*Y_ref; + T alpha = ALPHA; + int n; + int incx, incy; + + n = N; + + incx = 1; + incy = 1; + + srand (time(NULL)); + allocate_init_buffer(X , n , 1); + allocate_init_buffer(Y , n , 1); + copy_buffer(Y, Y_ref , n ,1); + +#ifdef PRINT + printvector(X, n,(char *) "X"); + printvector(Y, n, (char *) "Y"); +#endif + blis::axpy( + n, + alpha, + X, + incx, + Y, + incy + ); + +#ifdef PRINT + printvector(Y, n,(char *) "Y output"); +#endif + ref_axpy(n , &alpha , X, Y_ref ); + +#ifdef PRINT + printvector(Y_ref, n, (char *) "Y ref output"); +#endif + if(computeErrorV(incy, incy , n, Y, Y_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( X ); + delete[]( Y ); + delete[]( Y_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_axpy( ); + test_axpy( ); + test_axpy>( ); + test_axpy>( ); + return 0; + +} diff --git a/vendor/testcpp/test_copy.cc b/vendor/testcpp/test_copy.cc new file mode 100644 index 0000000000..a1042d1c9b --- /dev/null +++ b/vendor/testcpp/test_copy.cc @@ -0,0 +1,132 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define N 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T> +void ref_copy(int64_t n, + T *X, + T *Y + ) + +{ + obj_t obj_x, obj_y; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); + bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); + + bli_copyv( &obj_x, + &obj_y + ); + +} +template< typename T > +void test_copy( ) +{ + T *X, *X_ref, *Y,*Y_ref; + int n; + int incx, incy; + + n = N; + + incx = 1; + incy = 1; + + Y = new T[n]; + Y_ref = new T[n]; + srand (time(NULL)); + allocate_init_buffer(X , n , 1); + copy_buffer(X, X_ref , n ,1); + +#ifdef PRINT + printvector(X, n,(char *) "X"); +#endif + blis::copy( + n, + X, + incx, + Y, + incy + ); + +#ifdef PRINT + printvector(Y, n,(char *) "Y output"); +#endif + ref_copy(n , X_ref, Y_ref ); + +#ifdef PRINT + printvector(Y_ref, n,(char *) "Y ref output"); +#endif + if(computeErrorV(incy , incy , n, Y, Y_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + delete[]( X ); + delete[]( X_ref ); + delete[]( Y ); + delete[]( Y_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_copy( ); + test_copy( ); + test_copy>(); + test_copy>(); + return 0; + +} diff --git a/vendor/testcpp/test_dot.cc b/vendor/testcpp/test_dot.cc new file mode 100644 index 0000000000..553287784a --- /dev/null +++ b/vendor/testcpp/test_dot.cc @@ -0,0 +1,131 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define N 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T, typename TR> +void ref_dot(int64_t n, + T *X, + T *Y, + TR *res_ref + ) + +{ + obj_t obj_x; + obj_t obj_y; + obj_t obj_res; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); + bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); + bli_obj_create_with_attached_buffer( dt, 1, 1, res_ref, 1, 1,&obj_res ); + + bli_dotv(&obj_x, + &obj_y, + &obj_res ); + +} +template< typename T, typename TR> +void test_dot() +{ + T *X, *Y; + int n; + int incx, incy; + TR res = 0, res_ref = 0; + + n = N; + + incx = 1; + incy = 1; + + srand (time(NULL)); + allocate_init_buffer(X , n , 1); + allocate_init_buffer(Y , n , 1); + +#ifdef PRINT + printvector(X, n, (char *)"X"); + printvector(Y, n, (char *)"Y"); +#endif + res = blis::dot( + n, + X, + incx, + Y, + incy + ); + +#ifdef PRINT + printf("Dot product = %E \n", res); + +#endif + ref_dot(n, X, Y , &res_ref ); + +#ifdef PRINT + printf("Dot product ref_dot %E \n", res_ref); + +#endif + if(res != res_ref ) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( X ); + delete[]( Y ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_dot( ); + //test_dot( ); + test_dot( ); + return 0; + +} diff --git a/vendor/testcpp/test_dotc.cc b/vendor/testcpp/test_dotc.cc new file mode 100644 index 0000000000..88ffe19c4d --- /dev/null +++ b/vendor/testcpp/test_dotc.cc @@ -0,0 +1,127 @@ +/* + + BLISPP + C++ test driver for BLIS CPP dotc routine and reference blis dotc routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define N 16 + +template< typename T > +void ref_dotc(int64_t n, + T *X, + T *Y, + T *res_ref + ) +{ + obj_t obj_x; + obj_t obj_y; + obj_t obj_res; + num_t dt; + + if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n, &obj_x ); + bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n, &obj_y ); + bli_obj_set_conj(BLIS_CONJUGATE,&obj_x); + bli_obj_create_with_attached_buffer( dt, 1, 1, res_ref, 1, 1,&obj_res ); + + bli_dotv(&obj_x, + &obj_y, + &obj_res ); + +} + +template< typename T > +void test_dotc() +{ + T *X, *Y; + int n; + int incx, incy; + T res = 0, res_ref = 0; + + n = N; + + incx = 1; + incy = 1; + + srand (time(NULL)); + allocate_init_buffer(X , n , 1); + allocate_init_buffer(Y , n , 1); + +#ifdef PRINT + printvector(X, n,(char *) "X"); + printvector(Y, n,(char *) "Y"); +#endif + + res = blis::dotc( + n, + X, + incx, + Y, + incy + ); + +#ifdef PRINT + cout<< "Dot product \n" << res << "\n"; +#endif + ref_dotc(n, X, Y , &res_ref ); + +#ifdef PRINT + cout<< "Dot product ref\n" << res_ref << "\n";; +#endif + + if(res != res_ref ) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + delete[]( X ); + delete[]( Y ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_dotc>( ); + test_dotc>( ); + return 0; + +} diff --git a/vendor/testcpp/test_gbmv.cc b/vendor/testcpp/test_gbmv.cc new file mode 100644 index 0000000000..6d64f42ee3 --- /dev/null +++ b/vendor/testcpp/test_gbmv.cc @@ -0,0 +1,109 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA -1.0 +#define BETA -1.0 +#define M 3 +#define N 4 + +template< typename T > +void test_gbmv( ) +{ +// int i, j, p; + T alpha, beta; + int m,n; + int KL = 1; + int KU = 1; + int lda = 4; + T A[] = { 0.423f, -0.143f, -0.182f, -0.076f, -0.855f, 0.599f, 0.389f, -0.473f, 0.493f, -0.902f, -0.889f, -0.256f, 0.112f, 0.128f, -0.277f, -0.777f }; + T X[] = { 0.488f, 0.029f, -0.633f, 0.84f }; + int incX = -1; + T Y[] = { 0.874f, 0.322f, -0.477f }; + int incY = -1; + T Y_ref[] = { -0.656261f, 0.19575f, 0.055905f }; + alpha = ALPHA; + beta = BETA; + m = M; + n = N; + + +#ifdef PRINT + printmatrix(A, lda ,m,n,(char *) "A"); + printvector(Y, m, (char *)"m"); +#endif + blis::gbmv( + CblasColMajor, + CblasNoTrans, + m, + n,KL,KU, + alpha, + A, + lda, + X, + incX, + beta, + Y, + incY + ); + +#ifdef PRINT + printvector(Y, m,(char *)"Y blis:gbmv"); + printvector(Y_ref, m, (char *) "Y_ref blis:gbmv" ); + +#endif + + if(computeErrorV(incY,incY, m, Y, Y_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ ); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ ); + +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_gbmv( ); + test_gbmv( ); + test_gbmv>( ); + test_gbmv>( ); + return 0; + +} diff --git a/vendor/testcpp/test_gemm.cc b/vendor/testcpp/test_gemm.cc new file mode 100644 index 0000000000..2fe6e55a7c --- /dev/null +++ b/vendor/testcpp/test_gemm.cc @@ -0,0 +1,163 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define BETA 0.0 +#define M 5 +#define N 6 +#define K 4 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_gemm(int64_t m, int64_t n, int64_t k, + T * alpha, + T *A, + T *B, + T * beta, + T *C ) + +{ + obj_t obj_a, obj_b, obj_c; + obj_t obj_alpha, obj_beta; + num_t dt; + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); + bli_obj_create_with_attached_buffer( dt, m, k, A, 1,m,&obj_a ); + bli_obj_create_with_attached_buffer( dt, k, n, B,1,k,&obj_b ); + bli_obj_create_with_attached_buffer( dt, m, n, C, 1,m,&obj_c ); + + bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_a ); + bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_b ); + bli_gemm( &obj_alpha, + &obj_a, + &obj_b, + &obj_beta, + &obj_c ); + +} +template< typename T > +void test_gemm( ) +{ + T *A, *B, *C, *C_ref; + T alpha, beta; + int m,n,k; + int lda, ldb, ldc, ldc_ref; + + alpha = ALPHA; + beta = BETA; + m = M; + k = K; + n = N; + + lda = m; + ldb = k; + ldc = m; + ldc_ref = m; + srand (time(NULL)); + allocate_init_buffer(A , m , k); + allocate_init_buffer(B , k , n); + allocate_init_buffer(C , m , n); + copy_buffer(C, C_ref , m ,n); + +#ifdef PRINT + printmatrix(A, lda ,m,k , (char *)"A"); + printmatrix(B, ldb ,k,n, (char *)"B"); + printmatrix(C, ldc ,m,n, (char *)"C"); +#endif + blis::gemm( + CblasColMajor, + CblasNoTrans, + CblasNoTrans, + m, + n, + k, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc + ); + +#ifdef PRINT + printmatrix(C,ldc ,m,n , (char *)"C output"); +#endif + ref_gemm(m, n, k, &alpha, A, B, &beta, C_ref); + +#ifdef PRINT + printmatrix(C_ref, ldc_ref ,m,n, (char *)"C ref output"); +#endif + if(computeErrorM(ldc, ldc_ref, m, n, C, C_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ ); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ ); + + + + delete[]( A ); + delete[]( B ); + delete[]( C ); + delete[]( C_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_gemm( ); + test_gemm( ); + test_gemm>( ); + test_gemm>( ); + return 0; + +} diff --git a/vendor/testcpp/test_gemm.hh b/vendor/testcpp/test_gemm.hh new file mode 100644 index 0000000000..876ac16658 --- /dev/null +++ b/vendor/testcpp/test_gemm.hh @@ -0,0 +1,110 @@ +/* + * -------------------------------------------------------------------------- + * BLISLAB + * -------------------------------------------------------------------------- + * Copyright (C) 2016, The University of Texas at Austin + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * - Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * - Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * - Neither the name of The University of Texas nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * + * test_gemm.hh + * + * + * Purpose: + * this header file contains all function prototypes. + * + * Todo: + * + * + * Modification: + * + * + * */ + + +#ifndef TEST_GEMM_HH +#define TEST_GEMM_HH + +#include + +#include +#include + +using namespace std; +#define min( i, j ) ( (i)<(j) ? (i): (j) ) + +#define A( i, j ) A[ (j)*lda + (i) ] +#define B( i, j ) B[ (j)*ldb + (i) ] +#define C( i, j ) C[ (j)*ldc + (i) ] +#define C_ref( i, j ) C_ref[ (j)*ldc_ref + (i) ] + +template< typename T > +int computeError( + int ldc, + int ldc_ref, + int m, + int n, + T *C, + T *C_ref + ) +{ + int i, j; + int ret = 0; + for ( i = 0; i < m; i ++ ) { + for ( j = 0; j < n; j ++ ) { + if ( C( i, j ) != C_ref( i, j ) ) { + printf( "C[ %d ][ %d ] != C_ref, %E, %E\n", i, j, C( i, j ), C_ref( i, j ) ); + ret = 1; + break; + } + } + } + return ret; + +} + +/* + * + * + */ +template +void bl_dgemm_printmatrix( + T *A, + int lda, + int m, + int n + ) +{ + int i, j; + for ( i = 0; i < m; i ++ ) { + for ( j = 0; j < n; j ++ ) { + cout<< A[j * lda + i]<<" "; + } + printf("\n"); + } + printf("\n"); +} + +#endif diff --git a/vendor/testcpp/test_gemv.cc b/vendor/testcpp/test_gemv.cc new file mode 100644 index 0000000000..ca36a61d29 --- /dev/null +++ b/vendor/testcpp/test_gemv.cc @@ -0,0 +1,162 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define BETA 0.0 +#define M 5 +#define N 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_gemv(int64_t m, int64_t n, + T * alpha, + T *A, + T *X, + T * beta, + T *Y ) + +{ + obj_t obj_a, obj_x, obj_y; + obj_t obj_alpha, obj_beta; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); + bli_obj_create_with_attached_buffer( dt, m, n, A, 1,m,&obj_a ); + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1,n,&obj_x ); + bli_obj_create_with_attached_buffer( dt, m, 1, Y, 1,m,&obj_y ); + + bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_a ); + bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_x); + bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_y); + bli_gemv( &obj_alpha, + &obj_a, + &obj_x, + &obj_beta, + &obj_y ); + +} +template< typename T > +void test_gemv( ) +{ + T *A, *Y, *Y_ref, *X; + T alpha, beta; + int m,n; + int lda, incx, incy, incy_ref; + + alpha = ALPHA; + beta = BETA; + m = M; + n = N; + + lda = m; + incx = 1; + incy = 1; + incy_ref = 1; + + srand (time(NULL)); + allocate_init_buffer(A , m , n); + allocate_init_buffer(X , m , 1); + allocate_init_buffer(Y , m , 1); + copy_buffer(Y, Y_ref , m ,1); + +#ifdef PRINT + printmatrix(A, lda ,m,n,(char *) "A"); + printvector(X, m,(char *) "X"); + printvector(Y, m, (char *)"Y"); +#endif + blis::gemv( + CblasColMajor, + CblasNoTrans, + m, + n, + alpha, + A, + lda, + X, + incx, + beta, + Y, + incy + ); + +#ifdef PRINT + printvector(Y, m, (char *)"Y output"); +#endif + ref_gemv(m, n, &alpha, A, X, &beta, Y_ref); + +#ifdef PRINT + printvector(Y_ref, m, (char *) "Y_Ref output"); +#endif + if(computeErrorV(incy,incy_ref, m , Y, Y_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ ); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ ); + + + + delete[]( A ); + delete[]( X ); + delete[]( Y ); + delete[]( Y_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_gemv( ); + test_gemv( ); + test_gemv>( ); + test_gemv>( ); + return 0; + +} diff --git a/vendor/testcpp/test_ger.cc b/vendor/testcpp/test_ger.cc new file mode 100644 index 0000000000..15b018ce60 --- /dev/null +++ b/vendor/testcpp/test_ger.cc @@ -0,0 +1,150 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define M 5 +#define N 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_ger(int64_t m, int64_t n, + T * alpha, + T *X, + T *Y, + T *A ) + +{ + obj_t obj_a; + obj_t obj_x; + obj_t obj_y; + obj_t obj_alpha; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + + bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); + + bli_obj_create_with_attached_buffer( dt, m, n, A, 1, m, &obj_a ); + bli_obj_create_with_attached_buffer( dt, m, 1, X, 1, m,&obj_x ); + bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); + + //bli_obj_set_struc( BLIS_HERMITIAN, &obj_a ); + //bli_obj_set_uplo( BLIS_LOWER, &obj_a); + bli_ger( &obj_alpha, + &obj_x, + &obj_y, + &obj_a ); + +} +template< typename T > +void test_ger( ) +{ + T *A, *X, *Y, *A_ref; + T alpha; + int m,n; + int lda, incx, incy, lda_ref; + + alpha = ALPHA; + m = M; + n = N; + + lda = m; + lda_ref = m; + incx = 1; + incy = 1; + + srand (time(NULL)); + allocate_init_buffer(A , m , n); + allocate_init_buffer(X , m , 1); + allocate_init_buffer(Y , n , 1); + copy_buffer(A, A_ref , m ,n); + +#ifdef PRINT + printmatrix(A, lda ,m,n,(char *) "A"); + printvector(X, m,(char *) "X"); + printvector(Y, n,(char *) "Y"); +#endif + blis::ger( + CblasColMajor, + m, + n, + alpha, + X, + incx, + Y, + incy, + A, + lda + ); + +#ifdef PRINT + printmatrix(A, lda , m ,n ,(char *) "A output"); +#endif + ref_ger(m, n, &alpha, X, Y, A_ref); + +#ifdef PRINT + printmatrix(A_ref, lda ,m,n, (char *)"A_ref output"); +#endif + if(computeErrorM(lda, lda_ref, m, n, A, A_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( A ); + delete[]( X ); + delete[]( Y ); + delete[]( A_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_ger( ); + test_ger( ); + return 0; + +} diff --git a/vendor/testcpp/test_gerc.cc b/vendor/testcpp/test_gerc.cc new file mode 100644 index 0000000000..332405b7c1 --- /dev/null +++ b/vendor/testcpp/test_gerc.cc @@ -0,0 +1,174 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define M 5 +#define N 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_gerc(int64_t m, int64_t n, + T * alpha, + T *X, + T *Y, + T *A ) + +{ +obj_t obj_a; +obj_t obj_x; +obj_t obj_y; +obj_t obj_alpha; +num_t dt; + +if(is_same::value) + dt = BLIS_FLOAT; +else if(is_same::value) + dt = BLIS_DOUBLE; +else if(is_same>::value) + dt = BLIS_SCOMPLEX; +else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + +if(dt == BLIS_FLOAT){ + bli_obj_create_with_attached_buffer( BLIS_FLOAT, 1, 1, alpha, 1,1,&obj_alpha ); + } +else if(dt == BLIS_DOUBLE){ + bli_obj_create_with_attached_buffer( BLIS_DOUBLE, 1, 1, alpha, 1,1,&obj_alpha ); + } + +if(dt == BLIS_SCOMPLEX){ + bli_obj_create_with_attached_buffer( BLIS_SCOMPLEX, 1, 1, alpha, 1,1,&obj_alpha ); + } +else if(dt == BLIS_DCOMPLEX){ + bli_obj_create_with_attached_buffer( BLIS_DCOMPLEX, 1, 1, alpha, 1,1,&obj_alpha ); + } + +bli_obj_create_with_attached_buffer( dt, m, n, A, 1, m, &obj_a ); +bli_obj_create_with_attached_buffer( dt, m, 1, X, 1, m,&obj_x ); +bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); + + bli_obj_set_conj(BLIS_CONJUGATE,&obj_y); +bli_ger( &obj_alpha, + &obj_x, + &obj_y, + &obj_a ); +} + + +template< typename T > +void test_gerc( ) +{ + T *A, *X, *Y, *A_ref; + T alpha; + int m,n; + int lda, incx, incy, lda_ref; + + alpha = ALPHA; + m = M; + n = N; + + lda = m; + lda_ref = m; + incx = 1; + incy = 1; + + srand (time(NULL)); + allocate_init_buffer(A , m , n); + allocate_init_buffer(X , m , 1); + allocate_init_buffer(Y , n , 1); + copy_buffer(A, A_ref , m ,n); + + + +#ifdef PRINT + printmatrix(A, lda ,m,n,(char *)"A"); + printvector(X, m, (char *)"X"); + +#endif + blis::gerc( + CblasColMajor, + m, + n, + alpha, + X, + incx, + Y, + incy, + A, + lda + ); + +#ifdef PRINT + printmatrix (A, lda ,m , n,(char *)"A blis::gerc\n"); + +#endif + ref_gerc(m, n, &alpha, X, Y, A_ref); + +#ifdef PRINT + printmatrix(A_ref, lda_ref, m, n, (char *)"A_ref output\n"); + + + + +#endif + if(computeErrorM(lda, lda_ref, m, n, A, A_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( A ); + delete[]( X ); + delete[]( Y ); + delete[]( A_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_gerc>( ); + test_gerc>( ); + return 0; + +} diff --git a/vendor/testcpp/test_geru.cc b/vendor/testcpp/test_geru.cc new file mode 100644 index 0000000000..03e3e6a271 --- /dev/null +++ b/vendor/testcpp/test_geru.cc @@ -0,0 +1,169 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define M 5 +#define N 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_geru(int64_t m, int64_t n, + T * alpha, + T *X, + T *Y, + T *A ) + +{ +obj_t obj_a; +obj_t obj_x; +obj_t obj_y; +obj_t obj_alpha; +num_t dt; + +if(is_same::value) + dt = BLIS_FLOAT; +else if(is_same::value) + dt = BLIS_DOUBLE; +else if(is_same>::value) + dt = BLIS_SCOMPLEX; +else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + +if(dt == BLIS_FLOAT){ + bli_obj_create_with_attached_buffer( BLIS_FLOAT, 1, 1, alpha, 1,1,&obj_alpha ); + } +else if(dt == BLIS_DOUBLE){ + bli_obj_create_with_attached_buffer( BLIS_DOUBLE, 1, 1, alpha, 1,1,&obj_alpha ); + } + +if(dt == BLIS_SCOMPLEX){ + bli_obj_create_with_attached_buffer( BLIS_SCOMPLEX, 1, 1, alpha, 1,1,&obj_alpha ); + } +else if(dt == BLIS_DCOMPLEX){ + bli_obj_create_with_attached_buffer( BLIS_DCOMPLEX, 1, 1, alpha, 1,1,&obj_alpha ); + } + +bli_obj_create_with_attached_buffer( dt, m, n, A, 1, m, &obj_a ); +bli_obj_create_with_attached_buffer( dt, m, 1, X, 1, m,&obj_x ); +bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); + +bli_ger( &obj_alpha, + &obj_x, + &obj_y, + &obj_a ); +} + + +template< typename T > +void test_geru( ) +{ + T *A, *X, *Y, *A_ref; + T alpha; + int m,n; + int lda, incx, incy, lda_ref; + + alpha = ALPHA; + m = M; + n = N; + + lda = m; + lda_ref = m; + incx = 1; + incy = 1; + + srand (time(NULL)); + allocate_init_buffer(A , m , n); + allocate_init_buffer(X , m , 1); + allocate_init_buffer(Y , n , 1); +copy_buffer(A, A_ref , m ,n); + + +#ifdef PRINT + printmatrix(A, lda ,m,n,(char *)"A"); + printvector(X, m,(char *) "X"); +#endif + blis::geru( + CblasColMajor, + m, + n, + alpha, + X, + incx, + Y, + incy, + A, + lda + ); + +#ifdef PRINT + printmatrix (A, lda ,m,n,(char *)"A output"); + printvector (X, m,(char *) "X"); + +#endif + ref_geru(m, n, &alpha, X, Y, A_ref); + +#ifdef PRINT + printmatrix(A_ref, lda_ref, m,n,(char *)"A_ref output" ); + +#endif + if(computeErrorM(lda, lda_ref, m, n, A, A_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( A ); + delete[]( X ); + delete[]( Y ); + delete[]( A_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_geru>( ); + test_geru>( ); + return 0; + +} diff --git a/vendor/testcpp/test_hemm.cc b/vendor/testcpp/test_hemm.cc new file mode 100644 index 0000000000..8b88bcad35 --- /dev/null +++ b/vendor/testcpp/test_hemm.cc @@ -0,0 +1,164 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define BETA 0.0 +#define M 5 +#define N 5 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_hemm(int64_t m, int64_t n, + T * alpha, + T *A, + T *B, + T * beta, + T *C ) + +{ + obj_t obj_a, obj_b, obj_c; + obj_t obj_alpha, obj_beta; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); + bli_obj_create_with_attached_buffer( dt, m, m, A, 1,m,&obj_a ); + bli_obj_create_with_attached_buffer( dt, m, n, B, 1,n,&obj_b ); + bli_obj_create_with_attached_buffer( dt, m, n, C, 1,m,&obj_c ); + + bli_obj_set_struc( BLIS_HERMITIAN, &obj_a ); + bli_obj_set_uplo( BLIS_LOWER, &obj_a ); + bli_mkherm(&obj_a); + bli_mktrim(&obj_a); + bli_hemm( BLIS_LEFT, + &obj_alpha, + &obj_a, + &obj_b, + &obj_beta, + &obj_c ); + +} +template< typename T > +void test_hemm( ) +{ + T *A, *B, *C, *C_ref; + T alpha, beta; + int m,n; + int lda, ldb, ldc, ldc_ref; + + alpha = ALPHA; + beta = BETA; + m = M; + n = N; + + lda = m; + ldb = n; + ldc = m; + ldc_ref = m; + + srand48 (time(NULL)); + srand (time(NULL)); + allocate_init_buffer(A , m , m); + allocate_init_buffer(B , m , n); + allocate_init_buffer(C , m , n); + copy_buffer(C, C_ref , m ,n); + +#ifdef PRINT + printmatrix(A, lda ,m,m,(char *) "A"); + printmatrix(B, ldb ,m,n,(char *) "B"); + printmatrix(C, ldc ,m,n,(char *) "C"); +#endif + blis::hemm( + CblasColMajor, + CblasLeft, + CblasLower, + m, + n, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc + ); + +#ifdef PRINT + printmatrix(C, ldc ,m,n,(char *) "C output"); +#endif + ref_hemm(m, n, &alpha, A, B, &beta, C_ref); + +#ifdef PRINT + printmatrix(C_ref, ldc_ref ,m,n,(char *) "C ref output"); +#endif + if(computeErrorM(ldc, ldc_ref, m, n, C, C_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ ); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ ); + + + + delete[]( A ); + delete[]( B ); + delete[]( C ); + delete[]( C_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_hemm>( ); + test_hemm>( ); + return 0; + +} diff --git a/vendor/testcpp/test_hemv.cc b/vendor/testcpp/test_hemv.cc new file mode 100644 index 0000000000..463fdf557f --- /dev/null +++ b/vendor/testcpp/test_hemv.cc @@ -0,0 +1,157 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define BETA 0.0 +#define N 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_hemv(int64_t n, + T * alpha, + T *A, + T *X, + T * beta, + T *Y ) + +{ + obj_t obj_a, obj_x, obj_y; + obj_t obj_alpha, obj_beta; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); + bli_obj_create_with_attached_buffer( dt, n, n, A, 1,n,&obj_a ); + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1,n,&obj_x ); + bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1,n,&obj_y ); + + bli_obj_set_struc( BLIS_HERMITIAN, &obj_a ); + bli_obj_set_uplo( BLIS_LOWER, &obj_a ); + + bli_hemv( &obj_alpha, + &obj_a, + &obj_x, + &obj_beta, + &obj_y ); + +} +template< typename T > +void test_hemv( ) +{ + T *A, *Y, *Y_ref, *X; + T alpha, beta; + int n; + int lda, incx, incy, incy_ref; + + alpha = ALPHA; + beta = BETA; + n = N; + + lda = n; + incx = 1; + incy = 1; + incy_ref = 1; + + srand (time(NULL)); + allocate_init_buffer(A , n , n); + allocate_init_buffer(X , n , 1); + allocate_init_buffer(Y , n , 1); + copy_buffer(Y, Y_ref , n ,1); + +#ifdef PRINT + printmatrix(A, lda ,n,n, (char *)"A"); + printvector(X, n, (char *)"X"); + printvector(Y, n, (char *)"Y"); +#endif + blis::hemv( + CblasColMajor, + CblasLower, + n, + alpha, + A, + lda, + X, + incx, + beta, + Y, + incy + ); + +#ifdef PRINT + printvector(Y, n, (char *)"Y output"); +#endif + ref_hemv(n, &alpha, A, X, &beta, Y_ref); + +#ifdef PRINT + printvector(Y_ref, n,(char *) "Y_ref output"); +#endif + if(computeErrorV(incy,incy_ref, n, Y, Y_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ ); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ ); + + + + delete[]( A ); + delete[]( X ); + delete[]( Y ); + delete[]( Y_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_hemv>( ); + test_hemv>( ); + return 0; + +} diff --git a/vendor/testcpp/test_her.cc b/vendor/testcpp/test_her.cc new file mode 100644 index 0000000000..687d1e90d8 --- /dev/null +++ b/vendor/testcpp/test_her.cc @@ -0,0 +1,141 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define N 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_her(int64_t n, + real_type * alpha, + T *X, + T *A ) + +{ + obj_t obj_a; + obj_t obj_x; + obj_t obj_alpha; + num_t dt; + + if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + if(dt == BLIS_SCOMPLEX){ + bli_obj_create_with_attached_buffer( BLIS_FLOAT, 1, 1, alpha, 1,1,&obj_alpha ); + } + else if(dt == BLIS_DCOMPLEX){ + bli_obj_create_with_attached_buffer( BLIS_DOUBLE, 1, 1, alpha, 1,1,&obj_alpha ); + } + + bli_obj_create_with_attached_buffer( dt, n, n, A, 1, n, &obj_a ); + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); + + bli_obj_set_struc( BLIS_HERMITIAN, &obj_a ); + bli_obj_set_uplo( BLIS_LOWER, &obj_a); + bli_her( &obj_alpha, + &obj_x, + &obj_a ); + +} +template< typename T > +void test_her( ) +{ + T *A, *X, *A_ref; + real_type alpha; + int n; + int lda, incx, lda_ref; + + alpha = ALPHA; + n = N; + + lda = n; + lda_ref = n; + incx = 1; + srand (time(NULL)); + allocate_init_buffer(A , n , n); + allocate_init_buffer(X , n , 1); + copy_buffer(A, A_ref , n ,n); + +#ifdef PRINT + printmatrix(A, lda ,n,n,(char *) "A"); + printvector(X, n,(char *) "X"); +#endif + blis::her( + CblasColMajor, + CblasLower, + n, + alpha, + X, + incx, + A, + lda + ); + +#ifdef PRINT + printmatrix(A, lda ,n,n, (char *)"A output"); +#endif + ref_her(n, &alpha, X, A_ref); +#ifdef PRINT + printmatrix(A_ref, lda_ref, n,n ,(char *) "A refoutput"); +#endif + if(computeErrorM(lda, lda_ref, n, n, A, A_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( A ); + delete[]( X ); + delete[]( A_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_her>( ); + test_her>( ); + return 0; + +} diff --git a/vendor/testcpp/test_her2.cc b/vendor/testcpp/test_her2.cc new file mode 100644 index 0000000000..2f3ca253ac --- /dev/null +++ b/vendor/testcpp/test_her2.cc @@ -0,0 +1,147 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define N 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_her2(int64_t n, + T * alpha, + T *X, + T *Y, + T *A ) + +{ + obj_t obj_a; + obj_t obj_x, obj_y; + obj_t obj_alpha; + num_t dt; + + if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer(dt, 1, 1, alpha, 1,1,&obj_alpha ); + + bli_obj_create_with_attached_buffer( dt, n, n, A, 1, n, &obj_a ); + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); + bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); + + bli_obj_set_struc( BLIS_HERMITIAN, &obj_a ); + bli_obj_set_uplo( BLIS_LOWER, &obj_a); + bli_her2( &obj_alpha, + &obj_x, + &obj_y, + &obj_a ); + +} +template< typename T > +void test_her2( ) +{ + T *A, *X, *Y, *A_ref; + T alpha; + int n; + int lda, incx, incy, lda_ref; + + alpha = ALPHA; + n = N; + + lda = n; + lda_ref = n; + incx = 1; + incy = 1; + + + srand (time(NULL)); + allocate_init_buffer(A , n , n); + allocate_init_buffer(X , n , 1); + allocate_init_buffer(Y , n , 1); + copy_buffer(A, A_ref , n ,n); + +#ifdef PRINT + printmatrix(A, lda ,n,n,(char *) "A"); + printvector(X, n,(char *) "X"); + printvector(Y, n, (char *)"Y"); +#endif + blis::her2( + CblasColMajor, + CblasLower, + n, + alpha, + X, + incx, + Y, + incy, + A, + lda + ); + +#ifdef PRINT + printmatrix(A, lda , n , n,(char *) "A output"); +#endif + ref_her2(n, &alpha, X, Y, A_ref); +#ifdef PRINT + printmatrix(A_ref, lda , n, n, (char *)"A_ref output"); +#endif + if(computeErrorM(lda, lda_ref, n, n, A, A_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( A ); + delete[]( X ); + delete[]( Y ); + delete[]( A_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_her2>( ); + test_her2>( ); + return 0; + +} diff --git a/vendor/testcpp/test_herk.cc b/vendor/testcpp/test_herk.cc new file mode 100644 index 0000000000..3febf3e6f1 --- /dev/null +++ b/vendor/testcpp/test_herk.cc @@ -0,0 +1,155 @@ +/* + + BLISPP + C++ test driver for BLIS CPP herk routine and reference blis herk routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define BETA 0.0 +#define N 6 +#define K 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_herk(int64_t n, int64_t k, + real_type * alpha, + T *A, + real_type * beta, + T *C ) + +{ + obj_t obj_a,obj_c; + obj_t obj_alpha, obj_beta; + num_t dt; + + if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + if(dt == BLIS_SCOMPLEX){ + bli_obj_create_with_attached_buffer( BLIS_FLOAT, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( BLIS_FLOAT, 1, 1, beta, 1,1,&obj_beta ); + } + else if(dt == BLIS_DCOMPLEX){ + bli_obj_create_with_attached_buffer( BLIS_DOUBLE, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( BLIS_DOUBLE, 1, 1, beta, 1,1,&obj_beta ); + } + + bli_obj_create_with_attached_buffer( dt, n, k, A, 1,n,&obj_a ); + bli_obj_create_with_attached_buffer( dt, n, n, C, 1,n,&obj_c ); + + bli_obj_set_struc( BLIS_HERMITIAN, &obj_c ); + bli_obj_set_uplo( BLIS_LOWER, &obj_c ); + bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_c ); + bli_herk( &obj_alpha, + &obj_a, + &obj_beta, + &obj_c ); + +} +template< typename T > +void test_herk( ) +{ + T *A, *C, *C_ref; + real_type alpha; + real_type beta; + int n,k; + int lda, ldc, ldc_ref; + + alpha = ALPHA; + beta = BETA; + k = K; + n = N; + + + lda = k; + ldc = n; + ldc_ref = n; + srand (time(NULL)); + allocate_init_buffer(A , n , k); + allocate_init_buffer(C , n , n); + copy_buffer(C, C_ref , n ,n); + +#ifdef PRINT + printmatrix(A, lda ,n,k, (char *)"A"); + printmatrix(C, ldc ,n,n, (char *)"C"); +#endif + blis::herk( + CblasColMajor, + CblasLower, + CblasNoTrans, + n, + k, + alpha, + A, + lda, + beta, + C, + ldc + ); + +#ifdef PRINT + printmatrix(C, ldc ,n,n, (char *)"C output"); +#endif + ref_herk(n, k, &alpha, A, &beta, C_ref); + +#ifdef PRINT + printmatrix(C_ref, ldc_ref ,n,n, (char *)"C ref output"); +#endif + if(computeErrorM(ldc, ldc_ref, n, n, C, C_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + + delete[]( A ); + delete[]( C ); + delete[]( C_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_herk>( ); + test_herk>( ); + return 0; + +} diff --git a/vendor/testcpp/test_hpr.cc b/vendor/testcpp/test_hpr.cc new file mode 100644 index 0000000000..dfc7bdd4a9 --- /dev/null +++ b/vendor/testcpp/test_hpr.cc @@ -0,0 +1,112 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define N 2 + +/* + * Test application assumes matrices to be column major, non-transposed + */ + +template< typename T > +void test_hpr( ) +{ +int n; +real_type alpha; +int incX = -1; + +alpha = 1.0; +n = N; + + +T A[4]; + A[0] = { 0.265, 0.362}; + A[1] = {-0.855, 0.035}; + A[2] = {0.136, 0.133 }; + A[3] = { 0.00, 0.00}; + +T X[2]; + X[0] = { -0.278, -0.686}; + X[1] = {-0.736, -0.918 }; + +T A_ref[4]; + A_ref[0] = { 1.64942, 0.0}; + A_ref[1] = {-0.020644, 0.284692}; + A_ref[2] = {0.68388, 0.0 }; + A_ref[3] = {0.00, 0.00 }; + + + +#ifdef PRINT + printmatrix(A, n,n, n,(char *) "A"); + printvector(X, n, (char *)"X"); +#endif + blis::hpr( + CblasColMajor, + CblasLower, + n, + alpha, + X, + incX, + A + ); + +#ifdef PRINT + printmatrix(A, n , n, n,(char *)"A blis:hpr\n"); + + printmatrix(A_ref, n, n, n,(char *)"A_ref output\n"); +#endif + + if(computeErrorM(n, n, n, n, A, A_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_hpr>( ); + test_hpr>( ); + return 0; + +} diff --git a/vendor/testcpp/test_hpr2.cc b/vendor/testcpp/test_hpr2.cc new file mode 100644 index 0000000000..1b8b9b2b4f --- /dev/null +++ b/vendor/testcpp/test_hpr2.cc @@ -0,0 +1,93 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +#define N 1 + +/* + * Test application assumes matrices to be column major, non-transposed + */ + +template< typename T > +void test_hpr2( ) +{ +int n; +int incX = -1; +int incY = -1; + n = N; + +T alpha = {-0.3, 0.1}; + +T A[1]; + A[0] = { 0.772, 0.997 }; +T X[1]; + X[0] = { -0.173, -0.839 }; +T Y[1]; + Y[0] = { 0.941, -0.422 }; +T A_ref[1]; + A_ref[0] = { 0.829742, 0.0 }; + + blis::hpr2( + CblasColMajor, + CblasLower, + n, + alpha, + X, + incX, + Y, + incY, + A + ); + + + if(computeErrorM(1, 1, n, n, A, A_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_hpr2>( ); + printf("**************\n"); + test_hpr2>( ); + return 0; + +} diff --git a/vendor/testcpp/test_nrm2.cc b/vendor/testcpp/test_nrm2.cc new file mode 100644 index 0000000000..d29ec77788 --- /dev/null +++ b/vendor/testcpp/test_nrm2.cc @@ -0,0 +1,100 @@ +/* + + BLISPP + C++ test driver for BLIS CPP nrm2 routine and reference blis nrm2 routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define N 2 +#define ALPHA 0.5 + +#define TOLERANCE 0.0000001 +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T> +void test_nrm2() +{ + + T X[N]; + T nrm2, nrm2_ref; + int n; + int incx; + + n = N; + incx = 1; + + if(is_same::value) + { + X[0] = 0.14f; + X[1] = -0.632f; + nrm2_ref = 0.647320631527f; + } + else if(is_same::value) + { + X[0] = 0.696; + X[1] = -0.804; + nrm2_ref = 1.06340584915; + } + +#ifdef PRINT + printvector(X, n,(char *) "Vector X after blis::nrm2"); +#endif + nrm2 = blis::nrm2( + n, + X, + incx + ); +#ifdef PRINT + printf("Norm of a Vector %E \n", nrm2); + printf("Ref Norm of a Vector %E \n", nrm2_ref); +#endif + + if (fabs(nrm2 - nrm2_ref) > TOLERANCE) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_nrm2( ); + test_nrm2( ); + return 0; + +} diff --git a/vendor/testcpp/test_rot.cc b/vendor/testcpp/test_rot.cc new file mode 100644 index 0000000000..a2e3fb7086 --- /dev/null +++ b/vendor/testcpp/test_rot.cc @@ -0,0 +1,102 @@ +/* + + BLISPP + C++ test driver for BLIS CPP rot routine and reference blis rot routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define N 1 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T> +void test_rot() +{ + + T c, s; + T X[N], X_ref[N]; + T Y[N], Y_ref[N]; + int n; + int incx, incy; + + n = N; + incx = 1; + incy = 1; + if(is_same::value){ + c = -1.0f; + s = 0.0f; + X[0] = { -0.314f }; + Y[0] = { -0.406f }; + X_ref[0] = { 0.314f }; + Y_ref[0] = { 0.406f }; + }else{ + c = -1; + s = 0; + X[0] = { -0.176 }; + Y[0] = { -0.165 }; + X_ref[0] = { 0.176 }; + Y_ref[0] = { 0.165 }; + } + +#ifdef PRINT + printvector(X, n, (char *)"Before blis::rot\nVector X"); + printvector(Y, n, (char *)"Vector Y"); +#endif + blis::rot( N, X, incx, Y, incy, c, s); +#ifdef PRINT + printvector(X, n, (char *)"After blis::rot\nVector X"); + printvector(Y, n, (char *) "Vector Y"); + printvector(X, n, (char *) "Expected Output from blis::rot\nVector X"); + printvector(Y, n, (char *)"Vector Y"); +#endif + + if((computeErrorV(incx, incx , n, X, X_ref )==1) || (computeErrorV(incy, incy , n, Y, Y_ref )==1)) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_rot( ); + test_rot( ); + return 0; + +} diff --git a/vendor/testcpp/test_rotg.cc b/vendor/testcpp/test_rotg.cc new file mode 100644 index 0000000000..e11571ae3c --- /dev/null +++ b/vendor/testcpp/test_rotg.cc @@ -0,0 +1,108 @@ +/* + + BLISPP + C++ test driver for BLIS CPP rotg routine and reference blis rotg routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T> +void test_rotg() +{ + + T a, b, c, s; + T a_ref, b_ref, c_ref, s_ref; + + if(is_same::value) + { + a = 1.0f; + b = 1.0f; + a_ref = 1.41421356237f; + b_ref = 1.41421356237f; + c_ref = 0.707106781187f; + s_ref = 0.707106781187f; + }else{ + a = 1; + b = 0; + a_ref = 1; + b_ref = 0; + c_ref = 1; + s_ref = 0; + } + +#ifdef PRINT + cout<< "Before blis::rotg \na Value : " << a << "\n" ; + cout<< "b Value : " << b << "\n" ; +#endif + blis::rotg( + &a, + &b, + &c, + &s + ); + +#ifdef PRINT + cout<< "After blis::rotg \na Value : " << a << "\n" ; + cout<< "b Value : " << b << "\n" ; + cout<< "c Value : " << c << "\n" ; + cout<< "s Value : " << s << "\n" ; +#endif + +#ifdef PRINT + cout<< "Expected Output\na Value : " << a_ref << "\n" ; + cout<< "b Value : " << b_ref << "\n" ; + cout<< "c Value : " << c_ref << "\n" ; + cout<< "s Value : " << s_ref << "\n" ; +#endif + if( (a != a_ref ) || (b != b_ref ) || (c != c_ref ) || (s != s_ref )) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_rotg( ); + test_rotg( ); + return 0; + +} diff --git a/vendor/testcpp/test_rotm.cc b/vendor/testcpp/test_rotm.cc new file mode 100644 index 0000000000..aad4504b83 --- /dev/null +++ b/vendor/testcpp/test_rotm.cc @@ -0,0 +1,106 @@ +/* + + BLISPP + C++ test driver for BLIS CPP rotm routine and reference blis rotm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define N 1 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T> +void test_rotm() +{ + + T X[N], X_ref[N]; + T Y[N], Y_ref[N]; + int n; + int incx, incy; + const T P[5] = { -1.0f, -4.44982e+03f, -15.5826f, 7.091334e+04f, 2.95912e+04f }; + const T P_double[5] = { 1.0, -1.244580625511e+03, 1.11154682624, + 2.269384716089e-05, -0.0143785338883 }; + n = N; + incx = 1; + incy = 1; + if(is_same::value) + { + X[0] = { -0.034f }; + Y[0] = { -0.56f }; + X_ref[0] = { -3.956017e+04f }; + Y_ref[0] = { -1.657054e+04f }; + }else{ + X[0] = { 0.84 }; + Y[0] = { -0.711 }; + X_ref[0] = { -1.046158725429e+03 }; + Y_ref[0] = { -0.829776862405 }; + } + +#ifdef PRINT + printvector(X, n, (char *)"Before blis::rot\nVector X"); + printvector(Y, n, (char *)"Vector Y"); +#endif + if(is_same::value) + { + blis::rotm( N, X, incx, Y, incy, P); + }else{ + blis::rotm( N, X, incx, Y, incy, P_double); + } +#ifdef PRINT + printvector(X, n, (char *)"After blis::rot\nVector X"); + printvector(Y, n, (char *)"Vector Y"); + printvector(X, n, (char *)"Expected Output from blis::rot\nVector X"); + printvector(Y, n, (char *)"Vector Y"); +#endif + + if((computeErrorV(incx, incx , n, X, X_ref )==1) + || (computeErrorV(incy, incy , n, Y, Y_ref )==1)) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_rotm( ); + test_rotm( ); + return 0; + +} diff --git a/vendor/testcpp/test_rotmg.cc b/vendor/testcpp/test_rotmg.cc new file mode 100644 index 0000000000..b2325bb241 --- /dev/null +++ b/vendor/testcpp/test_rotmg.cc @@ -0,0 +1,137 @@ +/* + + BLISPP + C++ test driver for BLIS CPP rotmg routine and reference blis rotmg routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T> +void test_rotmg() +{ + T d1, d2, b1, b2; + T d1_ref, d2_ref, b1_ref; + T h[5] = { -999.0f, -999.1f, -999.2f, -999.3f, -999.4f }; + T h_ref[5] = {-1.0f, 0.0f, 0.0f, 0.0f,0.0f}; + T h_double[5] = { -999.0, -999.1, -999.2, -999.3, -999.4 }; + T h_ref_double[5] = { 1, 0, 0, 0}; + + if(is_same::value) + { + d1 = -1630.28519312f; + d2 = 44320.1964703f; + b1 = 1274.7681352f; + b2 = 0.983006912864f; + d1_ref= 0.0f; + d2_ref= 0.0f; + b1_ref= 0.0f; + }else{ + d1 = -49.1978123005; + d2 = 0.228703451277; + b1 = 1.8901039144; + b2 = 7081.47754386; + d1_ref= 0; + d2_ref= 0; + b1_ref= 0; + } + +#ifdef PRINT + cout<< "Before blis::rotmg \nd1 Value : " << d1 << "\n" ; + cout<< "d2 Value : " << d2 << "\n" ; + cout<< "b1 Value : " << b1 << "\n" ; + printvector(h, 5,(char *) "param"); +#endif + if(is_same::value) + { + blis::rotmg( + &d1, + &d2, + &b1, + b2, + h + ); + }else{ + blis::rotmg( + &d1, + &d2, + &b1, + b2, + h_double + ); + } + +#ifdef PRINT + cout<< "After blis::rotmg \nd1 Value : " << d1 << "\n" ; + cout<< "d2 Value : " << d2 << "\n" ; + cout<< "b1 Value : " << b1 << "\n" ; + printvector(h, 5,(char *) "param"); +#endif + +#ifdef PRINT + cout<< "Expected Output from blis::rotmg \nd1 Value : " << d1_ref << "\n" ; + cout<< "d2 Value : " << d2_ref << "\n" ; + cout<< "b1 Value : " << b1_ref << "\n" ; + printvector(h_ref, 5,(char *) "param"); +#endif + if( (d1 != d1_ref ) || (d2 != d2_ref ) || (b1 != b1_ref ) ) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else if(is_same::value){ + if(computeErrorV(1, 1 , 5, h, h_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + }else if(is_same::value){ + if(computeErrorV(1, 1 , 5, h_double, h_ref_double )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + }else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_rotmg( ); + test_rotmg( ); + return 0; + +} diff --git a/vendor/testcpp/test_scal.cc b/vendor/testcpp/test_scal.cc new file mode 100644 index 0000000000..82b2821a66 --- /dev/null +++ b/vendor/testcpp/test_scal.cc @@ -0,0 +1,138 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define N 6 +#define ALPHA 0.5 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename TA,typename TB> +void ref_scal(int64_t n, + TA * alpha, + TB *X + ) + +{ + obj_t obj_x; + obj_t obj_alpha; + num_t dt_x , dt_alpha; + if(is_same::value) + dt_x = BLIS_FLOAT; + else if(is_same::value) + dt_x = BLIS_DOUBLE; + else if(is_same>::value) + dt_x = BLIS_SCOMPLEX; + else if(is_same>::value) + dt_x = BLIS_DCOMPLEX; + + if(is_same::value) + dt_alpha = BLIS_FLOAT; + else if(is_same::value) + dt_alpha = BLIS_DOUBLE; + else if(is_same>::value) + dt_alpha = BLIS_SCOMPLEX; + else if(is_same>::value) + dt_alpha = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt_alpha, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt_x, n, 1, X, 1, n,&obj_x ); + + bli_scalv(&obj_alpha, + &obj_x + ); + +} +template< typename TA, typename TB> +void test_scal() +{ + TB *X, *X_ref; + TA alpha = ALPHA; + int n; + int incx; + + n = N; + + incx = 1; + srand (time(NULL)); + allocate_init_buffer(X , n , 1); + copy_buffer(X, X_ref , n ,1); + +#ifdef PRINT + printvector(X, n, (char *)"X"); +#endif + blis::scal( + n, + alpha, + X, + incx + ); + +#ifdef PRINT + printvector(X, n, (char *)"X output"); +#endif + ref_scal(n , &alpha , X_ref ); + +#ifdef PRINT + printvector(X_ref, n, (char *)"X ref output"); +#endif + if(computeErrorV(incx, incx , n, X, X_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( X ); + delete[]( X_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_scal( ); + test_scal( ); + test_scal , std::complex>( ); + test_scal , std::complex>( ); + test_scal>( ); + test_scal>( ); + return 0; + +} diff --git a/vendor/testcpp/test_sdsdot.cc b/vendor/testcpp/test_sdsdot.cc new file mode 100644 index 0000000000..c903c97d33 --- /dev/null +++ b/vendor/testcpp/test_sdsdot.cc @@ -0,0 +1,134 @@ +/* + + BLISPP + C++ test driver for BLIS CPP sdsdot routine and reference blis sdsdot routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define N 1 +#define ALPHA 0 + +/* + * Test application assumes matrices to be column major, non-transposed + */ + + #if 0 +template< typename T > +void ref_sdsot(int64_t n, + T alpha, + T *X, + T *Y, + T *res_ref + ) + +{ + obj_t obj_x; + obj_t obj_y; + obj_t obj_res; + obj_t obj_alpha; + num_t dt; + + if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); + bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); + bli_obj_create_with_attached_buffer( dt, 1, 1, &alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt, 1, 1, res_ref, 1, 1,&obj_res ); + + bli_ddots( &obj_x, + &obj_y, + &obj_res ); + +} +#endif + +template< typename T > +void test_sdsdot() +{ + + T X[N], Y[N]; + int n; + int incx, incy; + T res = 0, res_ref = 0; + + n = N; + + incx = 1; + incy = 1; + + //srand (time(NULL)); + //allocate_init_buffer(X , n , 1); + //allocate_init_buffer(Y , n , 1); + + X[0] = { 0.733f }; + Y[0] = { 0.825f }; + res_ref = 0.604725f; + res = blis::sdsdot( + n, + ALPHA, + X, + incx, + Y, + incy + ); + +#ifdef PRINT + printf("Dot product = %E \n", res); + +#endif + //ref_sdsot(n, aplha, X, Y , &res_ref ); + +#ifdef PRINT + printf("Ref Dot product %E \n", res_ref); +#endif + if(res != res_ref ) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_sdsdot( ); + return 0; + +} diff --git a/vendor/testcpp/test_spr.cc b/vendor/testcpp/test_spr.cc new file mode 100644 index 0000000000..edb7aa81a9 --- /dev/null +++ b/vendor/testcpp/test_spr.cc @@ -0,0 +1,97 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define N 2 + +/* + * Test application assumes matrices to be column major, non-transposed + */ + +template< typename T > +void test_spr( ) +{ + int n; + int incX = -1; + T alpha = -1; + + n = N; + + + T A[] = { 0.819, 0.175, -0.809 }; + T X[] = { -0.645, -0.222 }; + T A_ref[] = { 0.769716, 0.03181, -1.225025 }; + + +#ifdef PRINT + printmatrix(A, n, n, n,(char *) "A"); + printvector(X, n,(char *) "X"); +#endif + blis::spr( + CblasColMajor, + CblasLower, + n, + alpha, + X, + incX, + A + ); + +#ifdef PRINT + printmatrix (A, n ,n, n, (char *)"A blis:spr\n"); + printmatrix(A_ref, n, n, n,(char *)"A_ref blis:spr \n"); +#endif + + if(computeErrorM(1, 1, n, n, A, A_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_spr( ); + test_spr( ); + return 0; + +} diff --git a/vendor/testcpp/test_spr2.cc b/vendor/testcpp/test_spr2.cc new file mode 100644 index 0000000000..24f364b8e1 --- /dev/null +++ b/vendor/testcpp/test_spr2.cc @@ -0,0 +1,107 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA -1.0f +#define N 2 + +/* + * Test application assumes matrices to be column major, non-transposed + */ + +template< typename T > +void test_spr2( ) +{ + int n; + int incX = -1; + int incY = -1; + T alpha; + + alpha = ALPHA; + n = N; + + T A[] = { 0.493f, -0.175f, -0.831f }; + T X[] = { -0.163f, 0.489f }; + T Y[] = { 0.154f, 0.769f }; + T A_ref[]= { -0.259082f, -0.124959f, -0.780796f }; + + + +#ifdef PRINT + printf("Matrix A\n"); + printmatrix(A, incX, n,n,(char *)"A"); + printf("Vector X \n"); + printvector(X, n, (char *)"X"); +#endif + blis::spr2( + CblasColMajor, + CblasLower, + n, + alpha, + X, + incX, + Y, + incY, + A + ); + +#ifdef PRINT + printf("Matrix A after blis:spr2\n"); + printmatrix (A,1 ,n, n,(char *)"A"); + printf("A_ref \n"); + printmatrix(A_ref, 1, n,n,(char *)"A_ref output"); +#endif + + if(computeErrorM(1, 1, n, n, A, A_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_spr2( ); + test_spr2( ); + return 0; + +} diff --git a/vendor/testcpp/test_swap.cc b/vendor/testcpp/test_swap.cc new file mode 100644 index 0000000000..8979d90bdf --- /dev/null +++ b/vendor/testcpp/test_swap.cc @@ -0,0 +1,136 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define N 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T> +void ref_swap(int64_t n, + T *X, + T *Y + ) + +{ + obj_t obj_x, obj_y; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); + bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); + + bli_swapv( &obj_x, + &obj_y + ); + +} +template< typename T > +void test_swap( ) +{ + T *X, *X_ref, *Y,*Y_ref; + int n; + int incx, incy; + + n = N; + + incx = 1; + incy = 1; + + srand (time(NULL)); + allocate_init_buffer(X , n , 1); + allocate_init_buffer(Y , n , 1); + copy_buffer(X, X_ref , n ,1); + copy_buffer(Y, Y_ref , n ,1); + +#ifdef PRINT + printvector(X, n, (char *)"X"); + printvector(Y, n, (char *)"Y"); +#endif + blis::swap( + n, + X, + incx, + Y, + incy + ); + +#ifdef PRINT + printvector(X, n, (char *)"X output"); + printvector(Y, n, (char *)"Y output"); +#endif + ref_swap(n , X_ref, Y_ref ); + +#ifdef PRINT + printvector(X_ref, n, (char *)"X ref output"); + printvector(Y_ref, n, (char *)"Y ref output"); +#endif + if((computeErrorV(incy, incy,n, Y, Y_ref )==1)||(computeErrorV(incx, incx, n, X, X_ref )==1)) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( X ); + delete[]( Y ); + delete[]( Y_ref ); + delete[]( X_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_swap( ); + test_swap( ); + test_swap>( ); + test_swap>( ); + return 0; + +} diff --git a/vendor/testcpp/test_symm.cc b/vendor/testcpp/test_symm.cc new file mode 100644 index 0000000000..b4e10398ff --- /dev/null +++ b/vendor/testcpp/test_symm.cc @@ -0,0 +1,164 @@ +/* + + BLISPP + C++ test driver for BLIS CPP symm routine and reference blis symm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define BETA 0.0 +#define M 5 +#define N 5 +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_symm(int64_t m, int64_t n, + // side_t side, + T * alpha, + T *A, + T *B, + T * beta, + T *C ) + +{ + obj_t obj_a, obj_b, obj_c; + obj_t obj_alpha, obj_beta; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + + bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); + bli_obj_create_with_attached_buffer( dt, m, m, A, 1,m,&obj_a ); + bli_obj_create_with_attached_buffer( dt, m, n, B, 1,n,&obj_b ); + bli_obj_create_with_attached_buffer( dt, m, n, C, 1,m,&obj_c ); + + bli_obj_set_struc( BLIS_SYMMETRIC, &obj_a ); + bli_obj_set_uplo( BLIS_LOWER, &obj_a ); + bli_symm( BLIS_LEFT, + &obj_alpha, + &obj_a, + &obj_b, + &obj_beta, + &obj_c ); + +} +template< typename T > +void test_symm( ) +{ + T *A, *B, *C, *C_ref; + T alpha, beta; + int m,n; + int lda, ldb, ldc, ldc_ref; + + alpha = ALPHA; + beta = BETA; + m = M; + n = N; + + lda = m; + ldb = n; + ldc = m; + ldc_ref = m; + + srand (time(NULL)); + allocate_init_buffer(A , m , m); + allocate_init_buffer(B , m , n); + allocate_init_buffer(C , m , n); + copy_buffer(C, C_ref , m ,n); + +#ifdef PRINT + printmatrix(A, lda ,m,m, (char *)"A"); + printmatrix(B, ldb ,m,n, (char *)"B"); + printmatrix(C, ldc ,m,n, (char *)"C"); +#endif + blis::symm( + CblasColMajor, + CblasLeft, + CblasLower, + m, + n, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc + ); + +#ifdef PRINT + printmatrix(C, ldc ,m,n, (char *)"C output"); +#endif + // ref_symm(m, n, side, &alpha, A, B, &beta, C_ref); + ref_symm(m, n, &alpha, A, B, &beta, C_ref); + +#ifdef PRINT + printmatrix(C_ref, ldc_ref ,m,n, (char *)"C ref output"); +#endif + if(computeErrorM(ldc, ldc_ref, m, n, C, C_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ ); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ ); + + + delete[]( A ); + delete[]( B ); + delete[]( C ); + delete[]( C_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_symm( ); + test_symm( ); + test_symm>( ); + test_symm>( ); + return 0; + +} diff --git a/vendor/testcpp/test_syr.cc b/vendor/testcpp/test_syr.cc new file mode 100644 index 0000000000..327cd93947 --- /dev/null +++ b/vendor/testcpp/test_syr.cc @@ -0,0 +1,140 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define N 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_syr(int64_t n, + T * alpha, + T *X, + T *A ) + +{ + obj_t obj_a; + obj_t obj_x; + obj_t obj_alpha; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt, n, n, A, 1, n, &obj_a ); + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); + + bli_obj_set_struc( BLIS_SYMMETRIC, &obj_a ); + bli_obj_set_uplo( BLIS_LOWER, &obj_a); + bli_syr( &obj_alpha, + &obj_x, + &obj_a ); + +} +template< typename T > +void test_syr( ) +{ + T *A, *X, *A_ref; + T alpha; + int n; + int lda, incx, lda_ref; + + alpha = ALPHA; + n = N; + + lda = n; + lda_ref = n; + incx = 1; + + srand (time(NULL)); + allocate_init_buffer(A , n , n); + allocate_init_buffer(X , n , 1); + copy_buffer(A, A_ref , n ,n); + +#ifdef PRINT + printmatrix(A, lda ,n,n, (char *)"A"); + printvector(X, n,(char *) "X"); +#endif + blis::syr( + CblasColMajor, + CblasLower, + n, + alpha, + X, + incx, + A, + lda + ); + +#ifdef PRINT + printmatrix(A, lda , n , n,(char *) "A output"); +#endif + ref_syr(n, &alpha, X, A_ref); +#ifdef PRINT + printmatrix(A_ref, lda , n, n, (char *)"A ref output"); +#endif + if(computeErrorM(lda, lda_ref, n, n, A, A_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( A ); + delete[]( X ); + delete[]( A_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_syr( ); + test_syr( ); + return 0; + +} diff --git a/vendor/testcpp/test_syr2.cc b/vendor/testcpp/test_syr2.cc new file mode 100644 index 0000000000..165ca146f6 --- /dev/null +++ b/vendor/testcpp/test_syr2.cc @@ -0,0 +1,149 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define N 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_syr2(int64_t n, + T * alpha, + T *X, + T *Y, + T *A ) + +{ + obj_t obj_a; + obj_t obj_x, obj_y; + obj_t obj_alpha; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt, n, n, A, 1, n, &obj_a ); + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); + bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); + + bli_obj_set_struc( BLIS_SYMMETRIC, &obj_a ); + bli_obj_set_uplo( BLIS_LOWER, &obj_a); + bli_syr2( &obj_alpha, + &obj_x, + &obj_y, + &obj_a ); + +} +template< typename T > +void test_syr2( ) +{ + T *A, *X, *Y, *A_ref; + T alpha; + int n; + int lda, incx, incy, lda_ref; + + alpha = ALPHA; + n = N; + + lda = n; + lda_ref = n; + incx = 1; + incy = 1; + srand (time(NULL)); + allocate_init_buffer(A , n , n); + allocate_init_buffer(X , n , 1); + allocate_init_buffer(Y , n , 1); + copy_buffer(A, A_ref , n ,n); + +#ifdef PRINT + printmatrix(A, lda ,n,n,(char *) "A"); + printvector(X, n, (char *)"X"); + printvector(Y, n, (char *)"Y"); +#endif + blis::syr2( + CblasColMajor, + CblasLower, + n, + alpha, + X, + incx, + Y, + incy, + A, + lda + ); + +#ifdef PRINT + printmatrix(A, lda , n , n,(char *) "A output"); +#endif + ref_syr2(n, &alpha, X, Y, A_ref); + +#ifdef PRINT + printmatrix(A_ref, lda , n, n, (char *)"A_ref output"); +#endif + if(computeErrorM(lda, lda_ref, n, n, A, A_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( A ); + delete[]( X ); + delete[]( Y ); + delete[]( A_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_syr2( ); + test_syr2( ); + return 0; + +} diff --git a/vendor/testcpp/test_syr2k.cc b/vendor/testcpp/test_syr2k.cc new file mode 100644 index 0000000000..d56ff97a31 --- /dev/null +++ b/vendor/testcpp/test_syr2k.cc @@ -0,0 +1,163 @@ +/* + + BLISPP + C++ test driver for BLIS CPP syr2k routine and reference blis syr2k routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define BETA 0.0 +#define N 6 +#define K 6 + +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_syr2k(int64_t n, int64_t k, + T * alpha, + T *A, + T *B, + T * beta, + T *C ) + +{ + obj_t obj_a, obj_b, obj_c; + obj_t obj_alpha, obj_beta; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); + bli_obj_create_with_attached_buffer( dt, n, k, A, 1,n,&obj_a ); + bli_obj_create_with_attached_buffer( dt, k, n, B,1,k,&obj_b ); + bli_obj_create_with_attached_buffer( dt, n, n, C, 1,n,&obj_c ); + + bli_obj_set_struc( BLIS_SYMMETRIC, &obj_c ); + bli_obj_set_uplo( BLIS_LOWER, &obj_c ); + bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_c ); + bli_syr2k( &obj_alpha, + &obj_a, + &obj_b, + &obj_beta, + &obj_c ); + +} +template< typename T > +void test_syr2k( ) +{ + T *A, *B, *C, *C_ref; + T alpha; + T beta; + int n,k; + int ldb, lda, ldc, ldc_ref; + + alpha = ALPHA; + beta = BETA; + k = K; + n = N; + + lda = n; + ldb = k; + ldc = n; + ldc_ref = n; + srand (time(NULL)); + allocate_init_buffer(A , n , k); + allocate_init_buffer(B , k , n); + allocate_init_buffer(C , n , n); + copy_buffer(C, C_ref , n ,n); + +#ifdef PRINT + printmatrix(A, lda ,n,k,(char *) "A"); + printmatrix(B, ldb ,k,n,(char *) "B"); + printmatrix(C, ldc ,n,n,(char *) "C"); +#endif + blis::syr2k( + CblasColMajor, + CblasLower, + CblasNoTrans, + n, + k, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc + ); + +#ifdef PRINT + printmatrix(C, ldc ,n,n,(char *) "C output"); +#endif + ref_syr2k(n, k, &alpha, A, B, &beta, C_ref); + +#ifdef PRINT + printmatrix(C_ref, ldc_ref ,n,n,(char *) "C ref output"); +#endif + + if(computeErrorM(ldc, ldc_ref, n, n, C, C_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( A ); + delete[]( B ); + delete[]( C ); + delete[]( C_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_syr2k( ); + test_syr2k( ); + test_syr2k>( ); + test_syr2k>( ); + return 0; + +} diff --git a/vendor/testcpp/test_syrk.cc b/vendor/testcpp/test_syrk.cc new file mode 100644 index 0000000000..3defc22519 --- /dev/null +++ b/vendor/testcpp/test_syrk.cc @@ -0,0 +1,152 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define BETA 0.0 +#define N 6 +#define K 4 +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_syrk(int64_t n, int64_t k, + T * alpha, + T *A, + T * beta, + T *C ) + +{ + obj_t obj_a,obj_c; + obj_t obj_alpha, obj_beta; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); + bli_obj_create_with_attached_buffer( dt, n, k, A, 1,n,&obj_a ); + bli_obj_create_with_attached_buffer( dt, n, n, C, 1,n,&obj_c ); + + bli_obj_set_struc( BLIS_SYMMETRIC, &obj_c ); + bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_c ); + bli_obj_set_uplo( BLIS_LOWER, &obj_c ); + bli_syrk( &obj_alpha, + &obj_a, + &obj_beta, + &obj_c ); + +} +template< typename T > +void test_syrk( ) +{ + T *A, *C, *C_ref; + T alpha, beta; + int n,k; + int lda, ldc, ldc_ref; + + alpha = ALPHA; + beta = BETA; + k = K; + n = N; + + lda = n; + ldc = n; + ldc_ref = n; + + srand (time(NULL)); + allocate_init_buffer(A , n , k); + allocate_init_buffer(C , n , n); + copy_buffer(C, C_ref , n ,n); + +#ifdef PRINT + printmatrix(A, lda ,n,k, (char *)"A"); + printmatrix(C, ldc ,n,n, (char *)"C"); +#endif + blis::syrk( + CblasColMajor, + CblasLower, + CblasNoTrans, + n, + k, + alpha, + A, + lda, + beta, + C, + ldc + ); + +#ifdef PRINT + printmatrix(C, ldc ,n,n, (char *)"C output"); +#endif + ref_syrk(n, k, &alpha, A, &beta, C_ref); + +#ifdef PRINT + printmatrix(C_ref, ldc_ref ,n,n, (char *)"C ref output"); +#endif + if(computeErrorM(ldc, ldc_ref, n, n, C, C_ref )==1) + printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( A ); + delete[]( C ); + delete[]( C_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_syrk( ); + test_syrk( ); + test_syrk>( ); + test_syrk>( ); + return 0; + +} diff --git a/vendor/testcpp/test_tbmv.cc b/vendor/testcpp/test_tbmv.cc new file mode 100644 index 0000000000..ba9d565232 --- /dev/null +++ b/vendor/testcpp/test_tbmv.cc @@ -0,0 +1,103 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +//#define PRINT +#define N 3 +#define K 1 +/* + * Test application assumes matrices to be column major, non-transposed + */ + +template< typename T > +void test_tbmv( ) +{ + int n,k,lda; + + k = K; + n = N; + + + lda = n; + T A[] = { 0.439f, -0.484f, -0.952f, -0.508f, 0.381f, -0.889f, -0.192f, -0.279f, -0.155f }; + T X[] = { -0.089f, -0.688f, -0.203f }; + int incX = -1; + T X_ref[] = { -0.24504f, 0.447756f, -0.089117f }; + + +#ifdef PRINT + printmatrix(A, lda ,n,n,(char *)"A"); + printvector(X, n,(char *)"X"); +#endif + blis::tbmv( + CblasColMajor, + CblasLower, + CblasNoTrans, + CblasNonUnit, + n, + k, + A, + lda, + X, + incX + ); + +#ifdef PRINT + printvector(X, n,(char *)"X"); + printvector(X_ref ,n,(char *) "X output"); +#endif + if(computeErrorV(incX, incX, n, X, X_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_tbmv( ); + test_tbmv( ); + test_tbmv>( ); + test_tbmv>( ); + return 0; + +} diff --git a/vendor/testcpp/test_tbsv.cc b/vendor/testcpp/test_tbsv.cc new file mode 100644 index 0000000000..85bcdb4ffd --- /dev/null +++ b/vendor/testcpp/test_tbsv.cc @@ -0,0 +1,104 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +//#define PRINT +#define K 1 +#define N 3 +/* + * Test application assumes matrices to be column major, non-transposed + */ + +template< typename T > +void test_tbsv( ) +{ + int n,k,lda; + + k = K; + n = N; + lda = n; + + T A[] = { -0.681f, 0.209f, 0.436f, -0.369f, 0.786f, -0.84f, 0.86f, -0.233f, 0.734f }; + T X[] = { -0.305f, 0.61f, -0.831f }; + int incX = -1; + T X_ref[] = { 0.524539f, -0.961964f, 1.22026f }; + + +#ifdef PRINT + printmatrix(A, lda ,n,n,(char *)"A"); + printvector(X, n,(char *) "X"); +#endif + blis::tbsv( + CblasColMajor, + CblasLower, + CblasNoTrans, + CblasNonUnit, + n, + k, + A, + lda, + X, + incX + ); + +#ifdef PRINT + printvector(X, n, (char *)"X blis::tbsv\n"); + printvector(X_ref, n,(char *) "X_ref blis::tbsv output"); + +#endif + + if(computeErrorV(1,1, n, X, X_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_tbsv( ); + test_tbsv( ); + test_tbsv>( ); + test_tbsv>( ); + return 0; + +} diff --git a/vendor/testcpp/test_tpmv.cc b/vendor/testcpp/test_tpmv.cc new file mode 100644 index 0000000000..e2a41d34aa --- /dev/null +++ b/vendor/testcpp/test_tpmv.cc @@ -0,0 +1,84 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +#define N 2 +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void test_tpmv( ) +{ + int n; + + n = N; + + T A[] = { -0.587f, 0.14f, 0.841f }; + T X[] = { -0.213f, 0.885f }; + int incX = -1; + T X_ref[] = { -0.055233f, -0.519495f }; + + blis::tpmv( + CblasColMajor, + CblasLower, + CblasNoTrans, + CblasNonUnit, + n, + A, + X, + incX + ); + + if(computeErrorV(incX, incX, n, X, X_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_tpmv( ); + test_tpmv( ); + test_tpmv>( ); + test_tpmv>( ); + return 0; + +} diff --git a/vendor/testcpp/test_tpsv.cc b/vendor/testcpp/test_tpsv.cc new file mode 100644 index 0000000000..a9c3c2109f --- /dev/null +++ b/vendor/testcpp/test_tpsv.cc @@ -0,0 +1,87 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +#define N 2 +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void test_tpsv( ) +{ + int n; + n = N; + + T A[] = { -0.381f, 0.53f, 0.451f }; + T X[] = { 0.144f, 0.032f }; + int incX = -1; + T X_ref[] = { 0.417992f, -0.0839895f }; + + + + blis::tpsv( + CblasColMajor, + CblasLower, + CblasNoTrans, + CblasNonUnit, + n, + A, + X, + incX + ); + + + if(computeErrorV(1,1, n, X, X_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_tpsv( ); + test_tpsv( ); + test_tpsv>( ); + test_tpsv>( ); + return 0; + +} diff --git a/vendor/testcpp/test_trmm.cc b/vendor/testcpp/test_trmm.cc new file mode 100644 index 0000000000..c6301f0134 --- /dev/null +++ b/vendor/testcpp/test_trmm.cc @@ -0,0 +1,153 @@ +/* + + BLISPP + C++ test driver for BLIS CPP trmm routine and reference blis trmm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define M 6 +#define N 4 +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_trmm(int64_t m, int64_t n, + T * alpha, + T *A, + T *B + ) + +{ + obj_t obj_a, obj_b; + obj_t obj_alpha; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt, m, m, A, 1,m,&obj_a ); + bli_obj_create_with_attached_buffer( dt, m, n, B, 1,m,&obj_b ); + + bli_obj_set_struc( BLIS_TRIANGULAR, &obj_a ); + bli_obj_set_uplo( BLIS_LOWER, &obj_a ); + bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_a ); + bli_obj_set_diag( BLIS_NONUNIT_DIAG, &obj_a ); + bli_trmm( BLIS_LEFT, + &obj_alpha, + &obj_a, + &obj_b + ); + +} +template< typename T > +void test_trmm( ) +{ + T *A, *B, *B_ref; + T alpha; + int m,n; + int lda, ldb, ldb_ref; + + alpha = ALPHA; + m = M; + n = N; + + lda = m; + ldb = m; + ldb_ref = m; + + srand (time(NULL)); + allocate_init_buffer(A , m , m); + allocate_init_buffer(B , m , n); + copy_buffer(B, B_ref , m ,n); + +#ifdef PRINT + printmatrix(A, lda ,m,m, (char *)"A"); + printmatrix(B, ldb ,m,n, (char *)"B"); +#endif + blis::trmm( + CblasColMajor, + CblasLeft, + CblasLower, + CblasNoTrans, + CblasNonUnit, + m, + n, + alpha, + A, + lda, + B, + ldb + ); + +#ifdef PRINT + printmatrix(B, ldb ,m,n, (char *)"B output"); +#endif + ref_trmm(m, n, &alpha, A, B_ref); + +#ifdef PRINT + printmatrix(B_ref, ldb_ref ,m,n, (char *)"B ref output"); +#endif + if(computeErrorM(ldb, ldb_ref, m, n, B, B_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + + delete[]( A ); + delete[]( B ); + delete[]( B_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_trmm( ); + test_trmm( ); + test_trmm>( ); + test_trmm>( ); + return 0; + +} diff --git a/vendor/testcpp/test_trsm.cc b/vendor/testcpp/test_trsm.cc new file mode 100644 index 0000000000..4c5ead3bcf --- /dev/null +++ b/vendor/testcpp/test_trsm.cc @@ -0,0 +1,154 @@ +/* + + BLISPP + C++ test driver for BLIS CPP trsm routine and reference blis trsm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +#define ALPHA 1.0 +#define M 5 +#define N 4 +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_trsm(int64_t m, int64_t n, + T * alpha, + T *A, + T *B + ) + +{ + obj_t obj_a, obj_b; + obj_t obj_alpha; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); + bli_obj_create_with_attached_buffer( dt, m, m, A, 1,m,&obj_a ); + bli_obj_create_with_attached_buffer( dt, m, n, B, 1,m,&obj_b ); + + bli_obj_set_struc( BLIS_TRIANGULAR, &obj_a ); + bli_obj_set_uplo( BLIS_LOWER, &obj_a ); + bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_a ); + bli_obj_set_diag( BLIS_NONUNIT_DIAG, &obj_a ); + bli_trsm( BLIS_LEFT, + &obj_alpha, + &obj_a, + &obj_b + ); + +} +template< typename T > +void test_trsm( ) +{ + T *A, *B, *B_ref; + T alpha; + int m,n; + int lda, ldb, ldb_ref; + + alpha = ALPHA; + m = M; + n = N; + + lda = m; + ldb = m; + ldb_ref = m; + + srand (time(NULL)); + allocate_init_buffer(A , m , m); + allocate_init_buffer(B , m , n); + copy_buffer(B, B_ref , m ,n); + +#ifdef PRINT + printmatrix(A, lda ,m,m, (char *)"A"); + printmatrix(B, ldb ,m,n, (char *)"B"); +#endif + + blis::trsm( + CblasColMajor, + CblasLeft, + CblasLower, + CblasNoTrans, + CblasNonUnit, + m, + n, + alpha, + A, + lda, + B, + ldb + ); + +#ifdef PRINT + printmatrix(B, ldb ,m,n, (char *)"B output"); +#endif + ref_trsm(m, n, &alpha, A, B_ref); + +#ifdef PRINT + printmatrix(B_ref, ldb_ref ,m,n, (char *)"B ref output"); +#endif + if(computeErrorM(ldb, ldb_ref, m, n, B, B_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + + delete[]( A ); + delete[]( B ); + delete[]( B_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_trsm( ); + test_trsm( ); + test_trsm>( ); + test_trsm>( ); + return 0; + +} diff --git a/vendor/testcpp/test_trsv.cc b/vendor/testcpp/test_trsv.cc new file mode 100644 index 0000000000..d194f097b7 --- /dev/null +++ b/vendor/testcpp/test_trsv.cc @@ -0,0 +1,142 @@ +/* + + BLISPP + C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.hh" +#include "test.hh" + +using namespace blis; +using namespace std; +//#define PRINT +//#define PRINT +#define M 5 +#define N 6 +/* + * Test application assumes matrices to be column major, non-transposed + */ +template< typename T > +void ref_trsv(int64_t n, + T *A, + T *X + ) + +{ + obj_t obj_a, obj_x; + num_t dt; + + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; + + bli_obj_create_with_attached_buffer( dt, n, n, A, 1,n,&obj_a ); + bli_obj_create_with_attached_buffer( dt, n, 1, X, 1,n,&obj_x ); + + bli_obj_set_struc( BLIS_TRIANGULAR, &obj_a ); + bli_obj_set_uplo( BLIS_LOWER, &obj_a ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &obj_a ); + bli_obj_set_diag( BLIS_NONUNIT_DIAG, &obj_a ); + bli_trsv( &BLIS_ONE, + &obj_a, + &obj_x + ); + +} +template< typename T > +void test_trsv( ) +{ + T *A, *X, *X_ref; + int n; + int lda, incx, incx_ref; + + n = N; + + lda = n; + incx = 1; + incx_ref = 1; + + srand (time(NULL)); + allocate_init_buffer(A , n , n); + allocate_init_buffer(X , n , 1); + copy_buffer(X, X_ref , n ,1); + +#ifdef PRINT + printmatrix(A, lda ,n,n,(char *) "A"); + printvector(X, n,(char *) "X"); +#endif + blis::trsv( + CblasColMajor, + CblasLower, + CblasNoTrans, + CblasNonUnit, + n, + A, + lda, + X, + incx + ); + +#ifdef PRINT + printvector(X, n,(char *) "X output"); +#endif + ref_trsv(n, A, X_ref); + +#ifdef PRINT + printvector(X_ref, n,(char *) "X ref output"); +#endif + if(computeErrorV(incx, incx_ref, n, X, X_ref )==1) + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); + else + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); + + + delete[]( A ); + delete[]( X ); + delete[]( X_ref ); +} + +// ----------------------------------------------------------------------------- +int main( int argc, char** argv ) +{ + test_trsv( ); + test_trsv( ); + test_trsv>( ); + test_trsv>( ); + return 0; + +} diff --git a/version b/version index 4b9fcbec10..ac39a106c4 100644 --- a/version +++ b/version @@ -1 +1 @@ -0.5.1 +0.9.0 diff --git a/windows/Makefile b/windows/Makefile deleted file mode 100644 index f015fe14ff..0000000000 --- a/windows/Makefile +++ /dev/null @@ -1,341 +0,0 @@ -# -# -# BLIS -# An object-based framework for developing high-performance BLAS-like -# libraries. -# -# Copyright (C) 2014, The University of Texas at Austin -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# - Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# - Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# - Neither the name(s) of the copyright holder(s) nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -# - - - -# -# --- Include variables determined at configure-time -------------------------- -# -CONFIGURE_DEFS = config\config.mk - -!if exist ( $(CONFIGURE_DEFS) ) -!include $(CONFIGURE_DEFS) -!else -!error nmake: $(CONFIGURE_DEFS) does not exist! Run configure.cmd first. -!endif - - - -# -# --- Include environment- and build-specific definitions ---------------------- -# - -MAKE_DEFS = build\defs.mk - -# Include build definitions -!if exist ( $(MAKE_DEFS) ) -!include $(MAKE_DEFS) -!else -!error nmake: $(MAKE_DEFS) does not exist! Your libblis distribution may be incomplete. -!endif - - - -# -# --- Variable modifications --------------------------------------------------- -# - - - -# -# --- High-level rules --------------------------------------------------------- -# - -all: libblis - -libblis: libblis-lib - -libblis-objs: $(BLIS_OBJS) - -libblis-lib: $(LIB_LIBBLIS_DIRPATH)\$(LIBBLIS_LIB) - -libblis-dll: $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS_DLL) - -lib: libblis-lib - -dll: libblis-dll - -install: install-lib install-headers - -install-lib: $(INSTALL_PREFIX_LIB)\$(LIBBLIS).lib - -install-dll: $(INSTALL_PREFIX_DLL)\$(LIBBLIS).dll \ - $(INSTALL_PREFIX_DLL)\$(LIBBLIS).lib \ - $(INSTALL_PREFIX_DLL)\$(LIBBLIS).exp - -install-headers: $(INSTALL_PREFIX_INC)\$(BLIS_H) - -clean: clean-build clean-log - -distclean: clean-config clean-build clean-log - - - -# -# --- Source code (inference) rules -------------------------------------------- -# - -# --- C source files in flamec directory --- -{$(SRC_BLI_DIRPATH)}.c{$(OBJ_BLI_DIRPATH)}.obj: -!ifdef VERBOSE - if not exist $(OBJ_BLI_DIRPATH) \ - ( $(MKDIR) $(OBJ_BLI_DIRPATH) ) - $(CC) $(CFLAGS) /c $< /Fo$@ -!else - @if not exist $(OBJ_BLI_DIRPATH) \ - ( ( $(ECHO) nmake: Creating $(OBJ_BLI_DIRPATH) directory ) & \ - ( $(MKDIR) $(OBJ_BLI_DIRPATH) ) ) - @$(ECHO) nmake: Compiling $< - @$(CC) $(CFLAGS) /c $< /Fo$@ >> $(CC_LOG_FILE) -!endif - - - -# -# --- Library generation rules ------------------------------------------------- -# - -# --- Static library --- -$(LIB_LIBBLIS_DIRPATH)\$(LIBBLIS_LIB): libblis-objs -!ifdef VERBOSE - if not exist $(LIB_LIBBLIS_DIRPATH) \ - ( $(MKDIR) $(LIB_LIBBLIS_DIRPATH) ) - $(COPY) $(OBJ_BLI_DIRPATH)\*.obj $(LIB_LIBBLIS_DIRPATH) - $(CD) $(LIB_LIBBLIS_DIRPATH) - $(LIB) $(LIB_OPTIONS) $(LIB_BLI_OUTPUT_ARG) $(LIB_BLI_INPUT_ARGS) - $(DEL) *.obj - $(CD) $(TOP_BUILD_DIR_ABS) -!else - @if not exist $(LIB_LIBBLIS_DIRPATH) \ - ( ( $(ECHO) nmake: Creating $(LIB_LIBBLIS_DIRPATH) directory ) & \ - ( $(MKDIR) $(LIB_LIBBLIS_DIRPATH) ) ) - @$(ECHO) nmake: Creating static library $@ - @$(COPY) $(OBJ_BLI_DIRPATH)\*.obj $(LIB_LIBBLIS_DIRPATH) >> $(COPY_LOG_FILE) - @$(CD) $(LIB_LIBBLIS_DIRPATH) - @$(LIB) /VERBOSE $(LIB_OPTIONS) $(LIB_BLI_OUTPUT_ARG) $(LIB_BLI_INPUT_ARGS) - @$(DEL) *.obj - @$(CD) $(TOP_BUILD_DIR_ABS) -!endif - -# --- Dynamic library (object code file, import library, and export file) --- -$(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS_DLL): libblis-objs -!ifdef VERBOSE - if not exist $(DLL_LIBBLIS_DIRPATH) \ - ( $(MKDIR) $(DLL_LIBBLIS_DIRPATH) ) - $(COPY) $(OBJ_BLI_DIRPATH)\*.obj $(DLL_LIBBLIS_DIRPATH) >> $(COPY_LOG_FILE) - $(CD) $(DLL_LIBBLIS_DIRPATH) - $(DIR) /B *.obj > $(OBJ_LIST_FILE) - $(GENDLL) $(LIBBLIS) $(LIBBLIS) $(CC) $(LINKARGS_FILEPATH) $(SYM_DEF_FILEPATH) /objlist $(OBJ_LIST_FILE) - $(DEL) $(OBJ_LIST_FILE) - $(DEL) *.obj - $(CD) $(TOP_BUILD_DIR_ABS) -!else - @if not exist $(DLL_LIBBLIS_DIRPATH) \ - ( ( $(ECHO) nmake: Creating $(DLL_LIBBLIS_DIRPATH) directory ) & \ - ( $(MKDIR) $(DLL_LIBBLIS_DIRPATH) ) ) - @$(ECHO) nmake: Creating dynamic library $@ - @$(COPY) $(OBJ_BLI_DIRPATH)\*.obj $(DLL_LIBBLIS_DIRPATH) >> $(COPY_LOG_FILE) - @$(CD) $(DLL_LIBBLIS_DIRPATH) - @$(DIR) /B *.obj > $(OBJ_LIST_FILE) - @$(GENDLL) $(LIBBLIS) $(LIBBLIS) $(CC) $(LINKARGS_FILEPATH) $(SYM_DEF_FILEPATH) /objlist $(OBJ_LIST_FILE) - @$(DEL) $(OBJ_LIST_FILE) - @$(DEL) *.obj - @$(CD) $(TOP_BUILD_DIR_ABS) -!endif - - - -# -# --- Install rules ------------------------------------------------------------ -# - -# --- Header files --- -$(INSTALL_PREFIX_INC)\$(BLIS_H): $(INC_BLI_DIRPATH)\$(BLIS_H) \ - $(BUILD_DIRNAME)\$(BLI_CONFIG_H) -!ifdef VERBOSE - if not exist $(INSTALL_PREFIX_INC) \ - ( $(MKDIR) $(INSTALL_PREFIX_INC) ) - $(COPY) $(BUILD_DIRNAME)\$(BLI_CONFIG_H) $(INSTALL_PREFIX_INC) >> $(COPY_LOG_FILE) - $(COPY) $(INC_BLI_DIRPATH)\*.h $(INSTALL_PREFIX_INC) >> $(COPY_LOG_FILE) -!else - @if not exist $(INSTALL_PREFIX_INC) \ - ( $(MKDIR) $(INSTALL_PREFIX_INC) ) - @$(ECHO) nmake: Installing libblis header files to $(INSTALL_PREFIX_INC) - @$(COPY) $(BUILD_DIRNAME)\$(BLI_CONFIG_H) $(INSTALL_PREFIX_INC) >> $(COPY_LOG_FILE) - @$(COPY) $(INC_BLI_DIRPATH)\*.h $(INSTALL_PREFIX_INC) >> $(COPY_LOG_FILE) -!endif - -# --- Static library --- -$(INSTALL_PREFIX_LIB)\$(LIBBLIS).lib: $(LIB_LIBBLIS_DIRPATH)\$(LIBBLIS).lib -!ifdef VERBOSE - if not exist $(INSTALL_PREFIX_LIB) ( $(MKDIR) $(INSTALL_PREFIX_LIB) ) - if exist $(LIB_LIBBLIS_DIRPATH)\$(LIBBLIS).lib \ - ( $(COPY) $(LIB_LIBBLIS_DIRPATH)\$(LIBBLIS).lib $(INSTALL_PREFIX_LIB) >> $(COPY_LOG_FILE) ) -!else - @if not exist $(INSTALL_PREFIX_LIB) ( $(MKDIR) $(INSTALL_PREFIX_LIB) ) - @if exist $(LIB_LIBBLIS_DIRPATH)\$(LIBBLIS).lib \ - ( ( $(ECHO) nmake: Installing $(LIB_LIBBLIS_DIRPATH)\$(LIBBLIS).lib to $(INSTALL_PREFIX_LIB) ) & \ - ( $(COPY) $(LIB_LIBBLIS_DIRPATH)\$(LIBBLIS).lib $(INSTALL_PREFIX_LIB) >> $(COPY_LOG_FILE) ) ) -!endif - -# --- Dynamic library (object code) --- -$(INSTALL_PREFIX_DLL)\$(LIBBLIS).dll: $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).dll -!ifdef VERBOSE - if not exist $(INSTALL_PREFIX_DLL) ( $(MKDIR) $(INSTALL_PREFIX_DLL) ) - if exist $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).dll \ - ( $(COPY) $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).dll $(INSTALL_PREFIX_DLL) >> $(COPY_LOG_FILE) ) -!else - @if not exist $(INSTALL_PREFIX_DLL) ( $(MKDIR) $(INSTALL_PREFIX_DLL) ) - @if exist $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).dll \ - ( ( $(ECHO) nmake: Installing $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).dll to $(INSTALL_PREFIX_DLL) ) & \ - ( $(COPY) $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).dll $(INSTALL_PREFIX_DLL) >> $(COPY_LOG_FILE) ) ) -!endif - -# --- Dynamic library (import library) --- -$(INSTALL_PREFIX_DLL)\$(LIBBLIS).lib: $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).lib -!ifdef VERBOSE - if not exist $(INSTALL_PREFIX_DLL) ( $(MKDIR) $(INSTALL_PREFIX_DLL) ) - if exist $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).lib \ - ( $(COPY) $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).lib $(INSTALL_PREFIX_DLL) >> $(COPY_LOG_FILE) ) -!else - @if not exist $(INSTALL_PREFIX_DLL) ( $(MKDIR) $(INSTALL_PREFIX_DLL) ) - @if exist $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).lib \ - ( ( $(ECHO) nmake: Installing $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).lib to $(INSTALL_PREFIX_DLL) ) & \ - ( $(COPY) $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).lib $(INSTALL_PREFIX_DLL) >> $(COPY_LOG_FILE) ) ) -!endif - -# --- Dynamic library (export file) --- -$(INSTALL_PREFIX_DLL)\$(LIBBLIS).exp: $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).exp -!ifdef VERBOSE - if not exist $(INSTALL_PREFIX_DLL) ( $(MKDIR) $(INSTALL_PREFIX_DLL) ) - if exist $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).exp \ - ( $(COPY) $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).exp $(INSTALL_PREFIX_DLL) >> $(COPY_LOG_FILE) ) -!else - @if not exist $(INSTALL_PREFIX_DLL) ( $(MKDIR) $(INSTALL_PREFIX_DLL) ) - @if exist $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).exp \ - ( ( $(ECHO) nmake: Installing $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).exp to $(INSTALL_PREFIX_DLL) ) & \ - ( $(COPY) $(DLL_LIBBLIS_DIRPATH)\$(LIBBLIS).exp $(INSTALL_PREFIX_DLL) >> $(COPY_LOG_FILE) ) ) -!endif - - - -# -# --- Clean rules -------------------------------------------------------------- -# - -clean-log: -!ifdef VERBOSE - if exist $(CC_LOG_FILE) \ - ( $(DEL) $(CC_LOG_FILE) ) - if exist $(FC_LOG_FILE) \ - ( $(DEL) $(FC_LOG_FILE) ) - if exist $(COPY_LOG_FILE) \ - ( $(DEL) $(COPY_LOG_FILE) ) -!else - @if exist $(CC_LOG_FILE) \ - ( ( $(ECHO) nmake: Deleting $(CC_LOG_FILE) ) & \ - ( $(DEL) $(CC_LOG_FILE) ) ) - @if exist $(FC_LOG_FILE) \ - ( ( $(ECHO) nmake: Deleting $(FC_LOG_FILE) ) & \ - ( $(DEL) $(FC_LOG_FILE) ) ) - @if exist $(COPY_LOG_FILE) \ - ( ( $(ECHO) nmake: Deleting $(COPY_LOG_FILE) ) & \ - ( $(DEL) $(COPY_LOG_FILE) ) ) -!endif - -clean-config: -!ifdef VERBOSE - if exist $(CNF_DIRNAME) \ - ( $(RMDIR) $(CNF_DIRNAME) ) - if exist $(INC_DIRNAME) \ - ( $(RMDIR) $(INC_DIRNAME) ) - if exist $(SRC_DIRNAME) \ - ( $(RMDIR) $(SRC_DIRNAME) ) -!else - @if exist $(CNF_DIRNAME) \ - ( ( $(ECHO) nmake: Deleting $(CNF_DIRNAME) directory ) & \ - ( $(RMDIR) $(CNF_DIRNAME) ) ) - @if exist $(INC_DIRNAME) \ - ( ( $(ECHO) nmake: Deleting $(INC_DIRNAME) directory ) & \ - ( $(RMDIR) $(INC_DIRNAME) ) ) - @if exist $(SRC_DIRNAME) \ - ( ( $(ECHO) nmake: Deleting $(SRC_DIRNAME) directory ) & \ - ( $(RMDIR) $(SRC_DIRNAME) ) ) -!endif - -clean-build: -!ifdef VERBOSE - if exist $(OBJ_DIRNAME) \ - ( $(RMDIR) $(OBJ_DIRNAME) ) - if exist $(LIB_DIRNAME) \ - ( $(RMDIR) $(LIB_DIRNAME) ) - if exist $(DLL_DIRNAME) \ - ( $(RMDIR) $(DLL_DIRNAME) ) -!else - @if exist $(OBJ_DIRNAME) \ - ( ( $(ECHO) nmake: Deleting $(OBJ_DIRNAME) directory ) & \ - ( $(RMDIR) $(OBJ_DIRNAME) ) ) - @if exist $(LIB_DIRNAME) \ - ( ( $(ECHO) nmake: Deleting $(LIB_DIRNAME) directory ) & \ - ( $(RMDIR) $(LIB_DIRNAME) ) ) - @if exist $(DLL_DIRNAME) \ - ( ( $(ECHO) nmake: Deleting $(DLL_DIRNAME) directory ) & \ - ( $(RMDIR) $(DLL_DIRNAME) ) ) -!endif - -# Useful for developing when all we want to do is remove the library products. -clean-lib: -!ifdef VERBOSE - if exist $(LIB_DIRNAME) \ - ( $(RMDIR) $(LIB_DIRNAME) ) - if exist $(DLL_DIRNAME) \ - ( $(RMDIR) $(DLL_DIRNAME) ) -!else - @if exist $(LIB_DIRNAME) \ - ( ( $(ECHO) nmake: Deleting $(LIB_DIRNAME) directory ) & \ - ( $(RMDIR) $(LIB_DIRNAME) ) ) - @if exist $(DLL_DIRNAME) \ - ( ( $(ECHO) nmake: Deleting $(DLL_DIRNAME) directory ) & \ - ( $(RMDIR) $(DLL_DIRNAME) ) ) -!endif - - - -# -# --- Help target -------------------------------------------------------------- -# - -help: - @$(NMAKE_HELP) - diff --git a/windows/build/bli_config.h b/windows/build/bli_config.h deleted file mode 100644 index aced5d1b74..0000000000 --- a/windows/build/bli_config.h +++ /dev/null @@ -1,141 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef BLIS_CONFIG_H -#define BLIS_CONFIG_H - - -// -- OPERATING SYSTEM --------------------------------------------------------- - - - -// -- FLOATING-POINT PROPERTIES ------------------------------------------------ - -#define BLIS_NUM_FP_TYPES 4 -#define BLIS_MAX_TYPE_SIZE sizeof(dcomplex) - -// Enable use of built-in C99 "float complex" and "double complex" types and -// associated overloaded operations and functions? Disabling results in -// scomplex and dcomplex being defined in terms of simple structs. -//#define BLIS_ENABLE_C99_COMPLEX - - - -// -- MULTITHREADING ----------------------------------------------------------- - -// The maximum number of BLIS threads that will run concurrently. -#define BLIS_MAX_NUM_THREADS 24 - - - -// -- MEMORY ALLOCATION -------------------------------------------------------- - -// -- Contiguous (static) memory allocator -- - -// The number of MC x KC, KC x NC, and MC x NC blocks to reserve in the -// contiguous memory pools. -#define BLIS_NUM_MC_X_KC_BLOCKS BLIS_MAX_NUM_THREADS -#define BLIS_NUM_KC_X_NC_BLOCKS 1 -#define BLIS_NUM_MC_X_NC_BLOCKS 1 - -// The maximum preload byte offset is used to pad the end of the contiguous -// memory pools so that the micro-kernel, when computing with the end of the -// last block, can exceed the bounds of the usable portion of the memory -// region without causing a segmentation fault. -#define BLIS_MAX_PRELOAD_BYTE_OFFSET 128 - -// -- Memory alignment -- - -// It is sometimes useful to define the various memory alignments in terms -// of some other characteristics of the system, such as the cache line size -// and the page size. -#define BLIS_CACHE_LINE_SIZE 64 -#define BLIS_PAGE_SIZE 4096 - -// Alignment size used to align local stack buffers within macro-kernel -// functions. -#define BLIS_STACK_BUF_ALIGN_SIZE 16 - -// Alignment size used when allocating memory dynamically from the operating -// system (eg: posix_memalign()). To disable heap alignment and just use -// malloc() instead, set this to 1. -#define BLIS_HEAP_ADDR_ALIGN_SIZE 16 - -// Alignment size used when sizing leading dimensions of dynamically -// allocated memory. -#define BLIS_HEAP_STRIDE_ALIGN_SIZE BLIS_CACHE_LINE_SIZE - -// Alignment size used when allocating entire blocks of contiguous memory -// from the contiguous memory allocator. -#define BLIS_CONTIG_ADDR_ALIGN_SIZE BLIS_PAGE_SIZE - - - -// -- MIXED DATATYPE SUPPORT --------------------------------------------------- - -// Basic (homogeneous) datatype support always enabled. - -// Enable mixed domain operations? -//#define BLIS_ENABLE_MIXED_DOMAIN_SUPPORT - -// Enable extra mixed precision operations? -//#define BLIS_ENABLE_MIXED_PRECISION_SUPPORT - - - -// -- MISCELLANEOUS OPTIONS ---------------------------------------------------- - -// Stay initialized after auto-initialization, unless and until the user -// explicitly calls bli_finalize(). -#define BLIS_ENABLE_STAY_AUTO_INITIALIZED - - - -// -- BLAS-to-BLIS COMPATIBILITY LAYER ----------------------------------------- - -// Enable the BLAS compatibility layer? -#define BLIS_ENABLE_BLAS2BLIS - -// Enable 64-bit integers in the BLAS compatibility layer? If disabled, -// these integers will be defined as 32-bit. -#define BLIS_ENABLE_BLAS2BLIS_INT64 - -// Fortran-77 name-mangling macros. -#define PASTEF770(name) name ## _ -#define PASTEF77(ch1,name) ch1 ## name ## _ -#define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name ## _ - - -#endif - diff --git a/windows/build/defs.mk b/windows/build/defs.mk deleted file mode 100644 index 84b52b9aee..0000000000 --- a/windows/build/defs.mk +++ /dev/null @@ -1,240 +0,0 @@ -# -# -# BLIS -# An object-based framework for developing high-performance BLAS-like -# libraries. -# -# Copyright (C) 2014, The University of Texas at Austin -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# - Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# - Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# - Neither the name(s) of the copyright holder(s) nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -# - - -# -# --- General build system options -------------------------------------------- -# - -# Uncomment this for verbose output from nmake. -# VERBOSE = 1 - -# Assign this varible to be the full path to the directory to which you would -# like the BLIS build products to be installed upon running "nmake install". -# The nmake install target will create the install directory and all requisite -# subdirectories if they do not already exist (in which case the user must have -# permission to create these directories). -INSTALL_PREFIX = c:\field\lib - - -# -# --- Important build system filenames ---------------------------------------- -# - -# DLL link arguments. The contents of this file should be customized when -# building a dynamically-linked library. The lines of the file should contain -# linker options, library names, and library paths. Note that the library -# paths must be declared in the following form: -# -# /link /LIBPATH: -# /link /LIBPATH: -# /link /LIBPATH: -# -# where , , and are library paths to add to the list -# of paths to search when the linker attempts to locate other libraries -# listed in the file. -LINKARGS_FILENAME = linkargs.txt -LINKARGS_FILEPATH = $(PWD)\$(LINKARGS_FILENAME) - -# Various log file names that capture standard output when VERBOSE is undefined. -CC_LOG_FILE = nmake-cc.log -FC_LOG_FILE = nmake-fc.log -COPY_LOG_FILE = nmake-copy.log - - -# -# --- General name and directory definitions ----------------------------------- -# - -# The relative and absolute locations of the top-level Windows build directory. -# This is the directory in which nmake is run (not the directory named "build"). -TOP_BUILD_DIR_REL = . -TOP_BUILD_DIR_ABS = $(PWD) - -# The revision string. -REV_STR = r$(REVISION) - -# The names of the libraries. -LIBBLIS_NAME_ONLY = libblis -LIBBLIS = $(LIBBLIS_NAME_ONLY)-$(ARCH_STR)-$(REV_STR) - -# Directories that reside within the top-level Windows directory. -CNF_DIRNAME = config -INC_DIRNAME = include -SRC_DIRNAME = frame -OBJ_DIRNAME = obj -LIB_DIRNAME = lib -DLL_DIRNAME = dll - -# Leaves of interest for Windows. - -# Relative directory paths to each of the above subdirectories. -INC_DIRPATH = $(TOP_BUILD_DIR_REL)\$(INC_DIRNAME) -SRC_DIRPATH = $(TOP_BUILD_DIR_REL)\$(SRC_DIRNAME) -OBJ_DIRPATH = $(TOP_BUILD_DIR_REL)\$(OBJ_DIRNAME) -LIB_DIRPATH = $(TOP_BUILD_DIR_REL)\$(LIB_DIRNAME) -DLL_DIRPATH = $(TOP_BUILD_DIR_REL)\$(DLL_DIRNAME) - -# We only have header files for flamec leaves. -INC_BLI_DIRPATH = $(INC_DIRPATH) - -# We have source code for flamec and lapack2flamec leaves. -SRC_BLI_DIRPATH = $(SRC_DIRPATH) - - -# And we have object file paths corresponding to those source leaves defined -# above. -OBJ_BLI_DIRPATH = $(OBJ_DIRPATH)\$(ARCH_STR)\$(BUILD_STR) - -# Separate directories into which we'll move object files when we create the -# static libraries. -LIB_LIBBLIS_DIRPATH = $(LIB_DIRPATH)\$(ARCH_STR)\$(BUILD_STR) - -# Separate directories into which we'll move object files when we create the -# dynamic libraries. -DLL_LIBBLIS_DIRPATH = $(DLL_DIRPATH)\$(ARCH_STR)\$(BUILD_STR) - -# The install subdirectories. -INSTALL_PREFIX_LIB = $(INSTALL_PREFIX)\libblis\lib -INSTALL_PREFIX_DLL = $(INSTALL_PREFIX)\libblis\dll -INSTALL_PREFIX_INC = $(INSTALL_PREFIX)\libblis\include-$(ARCH_STR)-$(REV_STR) - -# Definitions for important header files used in the install-headers rule. -BUILD_DIRNAME = build -BLIS_H = blis.h - - -# -# --- General shell definitions ------------------------------------------------ -# - -CD = cd -DIR = dir -COPY = copy -DEL = del /F /Q -MKDIR = mkdir -RMDIR = rd /S /Q -ECHO = echo - - -# -# --- Helper scripts ----------------------------------------------------------- -# - -NMAKE_HELP = .\build\nmake-help.cmd - - - -# -# --- Compiler-related definitions --------------------------------------------- -# - -#!include $(VERSION_FILE) - -# --- C compiler definitions --- - -WINDOWS_BUILD = BLIS_ENABLE_WINDOWS_BUILD -VERS_STR = 0.0.9 -VERSION = BLIS_VERSION_STRING=\"$(VERS_STR)\" - -!if "$(CCOMPILER_STR)"=="icl" - -!if "$(BUILD_STR)"=="debug" -CDEBUG = /Zi -COPTIM = /Od -!elseif "$(BUILD_STR)"=="release" -CDEBUG = -COPTIM = /Ox -!endif - -CC = icl.exe -CMISCFLAGS = /nologo -CLANGFLAGS = -CPPROCFLAGS = /I.\build /I$(INC_BLI_DIRPATH) /D$(WINDOWS_BUILD) /D$(VERSION) -CWARNFLAGS = /w -CDBGFLAGS = $(CDEBUG) -COPTFLAGS = $(COPTIM) -CRTIMEFLAGS = /MT -CMTHREADFLAGS = /Qopenmp -CFLAGS = $(CMISCFLAGS) $(CLANGFLAGS) $(CPPROCFLAGS) $(CWARNFLAGS) \ - $(CDBGFLAGS) $(COPTFLAGS) $(CRTIMEFLAGS) $(CMTHREADFLAGS) - -!elseif "$(CCOMPILER_STR)"=="cl" - -!if "$(BUILD_STR)"=="debug" -CDEBUG = /Zi -COPTIM = /Od -!elseif "$(BUILD_STR)"=="release" -CDEBUG = -COPTIM = /Ox -!endif - -CC = cl.exe -CMISCFLAGS = /nologo -CLANGFLAGS = -CPPROCFLAGS = /I.\build /I$(INC_BLI_DIRPATH) /D$(WINDOWS_BUILD) /D$(VERSION) -CWARNFLAGS = /w -CDBGFLAGS = $(CDEBUG) -COPTFLAGS = $(COPTIM) -CRTIMEFLAGS = /MT -CMTHREADFLAGS = /openmp -CFLAGS = $(CMISCFLAGS) $(CLANGFLAGS) $(CPPROCFLAGS) $(CWARNFLAGS) \ - $(CDBGFLAGS) $(COPTFLAGS) $(CRTIMEFLAGS) $(CMTHREADFLAGS) - -!endif - - - -# -# --- Library-related definitions ---------------------------------------------- -# - -# --- Static library definitions --- - -LIBBLIS_LIB = $(LIBBLIS).lib - -LIB = lib -LIB_OPTIONS = /nologo -LIB_BLI_OUTPUT_ARG = /out:$(LIBBLIS_LIB) -LIB_BLI_INPUT_ARGS = *.obj - -# --- Dynamic library definitions --- - -LIBBLIS_DLL = $(LIBBLIS).dll - -GENDLL = $(TOP_BUILD_DIR_ABS)\gendll.cmd -OBJ_LIST_FILE = libblis-objects.txt - -SYM_DEF_FILEPATH = $(TOP_BUILD_DIR_ABS)\$(BUILD_DIRNAME)\libblis-symbols.def - diff --git a/windows/build/gather-src-for-windows.py b/windows/build/gather-src-for-windows.py deleted file mode 100644 index e3b589b5b1..0000000000 --- a/windows/build/gather-src-for-windows.py +++ /dev/null @@ -1,351 +0,0 @@ -#! /usr/bin/env python -# -# BLIS -# An object-based framework for developing high-performance BLAS-like -# libraries. -# -# Copyright (C) 2014, The University of Texas at Austin -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# - Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# - Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# - Neither the name(s) of the copyright holder(s) nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -# - -# ------------------------------------------------------------------------------ - -# Import modules -import sys -import os -import os.path -import getopt -import shutil -import string - -# Global variables for command line options, with default settings. -script_name = "" -dry_run_flag = False -verbose_flag = False - -# Global constants -flat_config_dirname = "config" -flat_header_dirname = "include" -flat_source_dirname = "frame" -leaf_list_path = "build/leaf_list" -ignore_list_path = "build/ignore_list" -ignore_list_win_path = "build/ignore_list.windows" - -# ------------------------------------------------------------------------------ - -def print_usage(): - - # Print help information. - print " " - print " %s" % script_name - print " " - print " Field G. Van Zee" - print " " - print " Walk the BLIS source tree and copy all sources necessary for" - print " building BLIS under Windows into a single flat directory with" - print " no subdirectory hierarchy." - print " " - print " Usage:" - print " %s [options] tree_dir flat_dir" % script_name - print " " - print " The following options are accepted:" - print " " - print " -d dry-run" - print " Go through all the motions, but don't actually copy any" - print " files." - print " -v verbose" - print " Be verbose about actions (one line of output her action)." - print " " - - # Exit the script. - sys.exit() - -# ------------------------------------------------------------------------------ - -def main(): - - # Extern our global veriables. - global script_name - global dry_run_flag - global verbose_flag - - # Get the script name so we can use it in our output. - ( script_dir, script_name ) = os.path.split( sys.argv[0] ) - - try: - - # Get the command line options. - options, args = getopt.getopt( sys.argv[1:], "dv") - - except getopt.GetoptError, err: - - # print help information and exit: - print str( err ) # will print something like "option -a not recognized" - print_usage() - - # Parse our expected command line options. - print 'checking options' - for o, a in options: - - if o == "-d": - print 'found dry run' - dry_run_flag = True - elif o == "-v": - verbose_flag = True - else: - assert False, "unhandled option" - - # Check the number of arguments after command line option processing. - n_args = len( args ) - if n_args != 2: - print_usage() - - # Acquire the non-optional arguments. - tree_dir = args[0] - flat_dir = args[1] - - # Acquire the list of directories we will ignore. - ignore_list = read_ignore_list() - - # Acquire the list of leaf-type directories we will descend into. - leaf_list = read_leaf_list() - - # Create strings for each of the base subdirectories in the flat - # destination directory. - flat_config_base_dirpath = os.path.join( flat_dir, flat_config_dirname ) - flat_header_base_dirpath = os.path.join( flat_dir, flat_header_dirname ) - flat_source_base_dirpath = os.path.join( flat_dir, flat_source_dirname ) - - # Start a list of directories to create. - dirs_to_create = [] - - # Append the config directory. We do this outside of the for loop because - # we don't need subdirectories for each leaf type. - dirs_to_create.append( flat_config_base_dirpath ) - - # For each of the leaf specifications, make the full pathnames of the - # subdirectories that will reside within the root destination directory. - for leaf_spec in leaf_list: - - # Unpack the leaf_spec tuple. - src_exts, hdr_exts = leaf_spec - - # Append the directory path name to our list. - dirs_to_create.append( flat_header_base_dirpath ) - dirs_to_create.append( flat_source_base_dirpath ) - - # Iterate over the directory list we just created. - for dirpath in dirs_to_create: - - # Make the subdirectories within the root destination directory, but - # only if they are not existing directories. - if os.path.isdir( dirpath ) == False: - - # Take action only if this is not a dry run. - if dry_run_flag == False: - - # Be verbose if verbosity was requested. - if verbose_flag == True: - print "%s: creating directory %s" % ( script_name, dirpath ) - - # Make the directory, and parent directories, for dirpath. - os.makedirs( dirpath ) - - else: - - # Be verbose if verbosity was requested. - if verbose_flag == True: - print "%s: (dry-run) creating directory %s" % ( script_name, dirpath ) - - - # Walk the directory structure top-down. - for dirpath, dirnames, filenames in os.walk( tree_dir ): - - # Remove directories that appear in the ignore list. - for item in ignore_list: - if item in dirnames: - dirnames.remove( item ) - - # Consider each leaf specification. If we find the name in the directory - # path, then copy the files with its designated extensions into the flat - # source directory. - for leaf_spec in leaf_list: - - # Unpack the leaf_spec tuple. - src_exts, hdr_exts = leaf_spec - - # At this point following line can probably be removed - type_dir_name = os.sep + '' - - flat_source_leaf_dirpath = flat_source_base_dirpath - flat_header_leaf_dirpath = flat_header_base_dirpath - - if dirpath.find( type_dir_name ) != -1: - copy_files_to_flat_subdirs( dirpath, filenames, src_exts, hdr_exts, - flat_source_leaf_dirpath, - flat_header_leaf_dirpath ) - -# ------------------------------------------------------------------------------ - -def copy_files_to_flat_subdirs( dirpath, filenames, src_exts, hdr_exts, src_dirpath, hdr_dirpath ): - - # Consider all files in dirpath. - for filename in filenames: - - # Construct the full file path for the current file. - filepath = os.path.join( dirpath, filename ) - - # Iterate over the valid source extensions for the current directory - # path. - for src_ext in src_exts: - - # If the filename/filepath ends with the source extension, copy it - # to the source subdirectory within the flat destination directory. - if filepath.endswith( src_ext ): - - # Take action only if this is not a dry run. - if dry_run_flag == False: - - # Be verbose if verbosity was requested. - if verbose_flag == True: - print "%s: copying to %s from %s" % ( script_name, src_dirpath, filepath ) - - # Copy the source file to the source subdirectory. - shutil.copy2( filepath, src_dirpath ) - - else: - - # Be verbose if verbosity was requested. - if verbose_flag == True: - print "%s: (dry-run) copying to %s from %s" % ( script_name, src_dirpath, filepath ) - - # Iterate over the valid header extensions for the current directory - # path. - for hdr_ext in hdr_exts: - - # If the filename/filepath ends with the header extension, copy it - # to the include subdirectory within the flat destination directory. - if filepath.endswith( hdr_ext ): - - # Take action only if this is not a dry run. - if dry_run_flag == False: - - # Be verbose if verbosity was requested. - if verbose_flag == True: - print "%s: copying to %s from %s" % ( script_name, hdr_dirpath, filepath ) - - # Copy the header file to the header subdirectory. - shutil.copy2( filepath, hdr_dirpath ) - - else: - - # Be verbose if verbosity was requested. - if verbose_flag == True: - print "%s: (dry-run) copying to %s from %s" % ( script_name, hdr_dirpath, filepath ) - -# ------------------------------------------------------------------------------ - -def read_ignore_list(): - - # Open the ignore list files as read-only. - ignore_file = open( ignore_list_path, 'r' ) - ignore_file_win = open( ignore_list_win_path, 'r' ) - - # Read all lines in the ignore list files. The items in these lists contain - # newlines, which we'll strip out shortly. - raw_list = ignore_file.readlines() - raw_win_list = ignore_file_win.readlines() - - # Close the files. - ignore_file.close() - ignore_file_win.close() - - # Initialize an empty ignore list for the stripped version of the raw list. - ignore_list = [] - - # Iterate over the first raw list. - for line in raw_list: - - # Append the stripped line to a new list. - ignore_list.append( line.strip() ) - - # Iterate over the second raw list. - for line in raw_win_list: - - # Append the stripped line to a new list. - ignore_list.append( line.strip() ) - - # Return the list of stripped lines. - return ignore_list - -# ------------------------------------------------------------------------------ - -def read_leaf_list(): - - # Open the leaf list file. - leaf_file = open( leaf_list_path, 'r' ) - - # Read the lines in the file. - line_list = leaf_file.readlines() - - # Start with a blank list. - leaf_list = [] - - # Iterate over the lines. - for line in line_list: - - # Split the specification by colon to separate the fields. - fields = string.split( string.strip( line ), ':' ) - - # Get the individual fields of the specification. - src_exts = string.split( fields[0], ',' ) - hdr_exts = string.split( fields[1], ',' ) - - # If it's a singleton list of an empty string, make it an empty list. - if len(src_exts) == 1: - if src_exts[0] == '': - src_exts = [] - - # If it's a singleton list of an empty string, make it an empty list. - if len(hdr_exts) == 1: - if hdr_exts[0] == '': - hdr_exts = [] - - # Pack the fields into a tuple. - leaf_spec = ( src_exts, hdr_exts ) - - - # Append the tuple to our list. - leaf_list.append( leaf_spec ) - - # Return the list. - return leaf_list - -# ------------------------------------------------------------------------------ - -# Begin by executing main(). -main() diff --git a/windows/build/gen-check-rev-file.py b/windows/build/gen-check-rev-file.py deleted file mode 100644 index 20593f76b7..0000000000 --- a/windows/build/gen-check-rev-file.py +++ /dev/null @@ -1,252 +0,0 @@ -#! /usr/bin/env python -# -# BLIS -# An object-based framework for developing high-performance BLAS-like -# libraries. -# -# Copyright (C) 2014, The University of Texas at Austin -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# - Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# - Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# - Neither the name(s) of the copyright holder(s) nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -# - -# ------------------------------------------------------------------------------ - -# Import modules -import sys -import os -import os.path -import getopt - -# Global variables for command line options, with default settings. -script_name = "" -verbose_flag = False - -# Global constants -toplevel_dirpath = "." -svn_dirname = ".svn" -entries_filename = "entries" -revision_filename = "revision" -dummy_rev_string = "unknown" - - -# ------------------------------------------------------------------------------ - -def print_usage(): - - # Print help information. - print " " - print " %s" % script_name - print " " - print " Field G. Van Zee" - print " " - print " This script ensures that a revision file exists so nmake can include the" - print " revision number in the subdirectory paths to the build products." - print " " - print " If a .svn directory exists, the revision file is created (or updated)" - print " to contain the revision number contained in .svn\entries file." - print " Otherwise, if a .svn directory does not exist, the revision file is" - print " left untouched if it exists, and created with a dummy value if it does" - print " not." - print " " - print " This script is typically invoked by configure.cmd, but it can also be" - print " run manually." - print " " - print " Usage:" - print " %s" % script_name - print " " - print " The following options are accepted:" - print " " - print " -v verbose" - print " Be verbose. Output what's happening." - print " " - - # Exit the script. - sys.exit() - -# ------------------------------------------------------------------------------ - -def main(): - - # Extern our global veriables. - global script_name - global verbose_flag - - # Get the script name so we can use it in our output. - ( script_dir, script_name ) = os.path.split( sys.argv[0] ) - - try: - - # Get the command line options. - options, args = getopt.getopt( sys.argv[1:], "v") - - except getopt.GetoptError, err: - - # print help information and exit: - print str( err ) # will print something like "option -a not recognized" - print_usage() - - # Parse our expected command line options. - for o, a in options: - - if o == "-v": - verbose_flag = True - else: - assert False, "unhandled option" - - # Check the number of arguments after command line option processing. - n_args = len( args ) - if n_args != 0: - print_usage() - - # Construct the filepaths to the entries and revision files. - entries_filepath = os.path.join( toplevel_dirpath, svn_dirname, entries_filename ) - revision_filepath = os.path.join( toplevel_dirpath, revision_filename ) - - # Test for the existence of the entries file (and by proxy, a working copy). - entries_file_exists = file_exists( entries_filepath ) - - # If the entries file exists, we are in a working copy, and thus we can - # overwrite the revision file with a potentially new value. - if entries_file_exists == True: - - # Read the revision number from the entries file. - rev_num_str = read_revision_from_entries( entries_filepath ) - - # Be verbose if verbosity was requested. - if verbose_flag == True: - print "%s: Found working copy; writing revision string \"%s\" to %s" % ( script_name, rev_num_str, revision_filepath ) - - # Write the revision number to the revision file. - write_revision_to_file( rev_num_str, revision_filepath ) - - # If we can't find the entries file, we probably are in an exported - # copy: either an official snapshot, or a copy that someone exported - # manually--hopefully (and likely) the former. - else: - - # Be verbose if verbosity was requested. - if verbose_flag == True: - print "%s: Found export. Checking for revision file..." % ( script_name ) - - # Test for the existence of the revision file. - rev_file_exists = file_exists( revision_filepath ) - - # If the revision file does not exist, create a dummy file so the - # configure script has something to work with. - if rev_file_exists == False: - - # Be verbose if verbosity was requested. - if verbose_flag == True: - print "%s: Revision file not found. Writing dummy revision string \"%s\" to %s" % ( script_name, dummy_rev_string, revision_filepath ) - - # Write the dummy string to the revision file. - write_revision_to_file( dummy_rev_string, revision_filepath ) - - else: - - # Get the revision number from the file just for the purposes of - # being verbose, if it was requested. - rev_num_str = read_revision_file( revision_filepath ) - - # Be verbose if verbosity was requested. - if verbose_flag == True: - print "%s: Revision file found containing revision string \"%s\". Export is valid snapshot!" % ( script_name, rev_num_str ) - - -# ------------------------------------------------------------------------------ - -def file_exists( filepath ): - - # Try to open the file read-only. - try: - - fp = open( filepath, 'r' ) - fp.close() - exists = True - - except IOError, err: - - exists = False - - return exists - - -# ------------------------------------------------------------------------------ - -def read_revision_from_entries( entries_filepath ): - - # Open the ignore list files as read-only. - entries_file = open( entries_filepath, 'r' ) - - # Read all lines in the entries file. - raw_list = entries_file.readlines() - - # Close the file. - entries_file.close() - - # Grab the fourth line, which is where the revision number lives, and strip - # it of whitespace (probably just a newline). - rev_num_str = raw_list[3].strip() - - # Return the revision number string. - return rev_num_str - -# ------------------------------------------------------------------------------ - -def write_revision_to_file( rev_string, revision_filepath ): - - # Open the revision file for writing. - revision_file = open( revision_filepath, 'w' ) - - # Write the revision string to the file. - revision_file.write( rev_string ) - - # Close the file. - revision_file.close() - -# ------------------------------------------------------------------------------ - -def read_revision_file( revision_filepath ): - - # Open the revision file. - revision_file = open( revision_filepath, 'r' ) - - # Read the first (and only) line. - line = revision_file.readline() - - # Close the file. - revision_file.close() - - # Grab the string and strip the it of whitespace (should just be a newline). - rev_num_str = line.strip() - - # Return the revision number string. - return rev_num_str - -# ------------------------------------------------------------------------------ - -# Begin by executing main(). -main() diff --git a/windows/build/gen-config-file.py b/windows/build/gen-config-file.py deleted file mode 100644 index 557083276e..0000000000 --- a/windows/build/gen-config-file.py +++ /dev/null @@ -1,360 +0,0 @@ -#! /usr/bin/env python -# -# BLIS -# An object-based framework for developing high-performance BLAS-like -# libraries. -# -# Copyright (C) 2014, The University of Texas at Austin -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# - Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# - Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# - Neither the name(s) of the copyright holder(s) nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -# - -# ------------------------------------------------------------------------------ - -# Import modules -import sys -import os -import os.path -import getopt -import re -import string - -# Global variables for command line options, with default settings. -script_name = "" -dry_run_flag = False -verbose_flag = False - -# Global constants -config_dirname = "config" -source_dirname = "frame" -object_dirname = "obj" -object_extension = ".obj" -leaf_list_path = "build/leaf_list" -revision_filename = "revision" -rev_varname = "REVISION" -pwd_varname = "PWD" -arch_varname = "ARCH_STR" -build_varname = "BUILD_STR" -ccompiler_varname = "CCOMPILER_STR" - - -# ------------------------------------------------------------------------------ - -def print_usage(): - - # Print help information. - print " " - print " %s" % script_name - print " " - print " Field G. Van Zee" - print " " - print " Create a config.mk file that is to be included by the nmake Makefile." - print " This config.mk file is based on a template, but also includes variable" - print " definitions that are needed for the specific build were are performing." - print " The variables which are currently appended to config.mk at runtime are:" - print " - the revision string" - print " - the path to the current working directory" - print " - the build string (e.g. debug, release)" - print " - the architecture string (e.g. x86, x64)" - print " - the C compiler to use (e.g. icl, cl)" - print " - a list of paths to the object files to be compiled" - print " The config.mk file is placed within the config subdirectory." - print " " - print " Usage:" - print " %s [options] flat_dir arch build ccompiler path\\to\\config.mk.in" % script_name - print " " - print " The following options are accepted:" - print " " - print " -d dry-run" - print " Go through all the motions, but don't actually output" - print " the nmake definition file." - print " -v verbose" - print " Be verbose about actions (one line of output her action)." - print " " - - # Exit the script. - sys.exit() - -# ------------------------------------------------------------------------------ - -def main(): - - # Extern our global veriables. - global script_name - global dry_run_flag - global verbose_flag - - # Get the script name so we can use it in our output. - ( script_dir, script_name ) = os.path.split( sys.argv[0] ) - - try: - - # Get the command line options. - options, args = getopt.getopt( sys.argv[1:], "dv") - - except getopt.GetoptError, err: - - # print help information and exit: - print str( err ) # will print something like "option -a not recognized" - print_usage() - - # Parse our expected command line options. - for o, a in options: - - if o == "-d": - dry_run_flag = True - elif o == "-v": - verbose_flag = True - else: - assert False, "unhandled option" - - # Check the number of arguments after command line option processing. - n_args = len( args ) - if n_args != 5: - print_usage() - - # Acquire the non-optional arguments. - flat_dir = args[0] - arch_string = args[1] - build_string = args[2] - ccompiler_string = args[3] - input_filepath = args[4] - - # Acquire the list of leaf-type directories we will descend into. - leaf_list = read_leaf_list() - - # Read the contents of the template file. - template_file_line_list = read_template_file( input_filepath ) - - # Initialize a new list for the lines to be output - output_file_line_list = template_file_line_list - - # Read the revision number from the revision file. - rev_num_str = read_revision_file( revision_filename ) - - # Add a variable for the revision number of the code we're working with. - rev_var_value = rev_varname + " = " + rev_num_str + "\n" - output_file_line_list.append( rev_var_value ) - - # Add a variable for the path to the current working directory and append - # it to our list. - pwd_var_value = pwd_varname + " = " + os.getcwd() + "\n" - output_file_line_list.append( pwd_var_value ) - - # Add a variable for the architecture string and append it to our list. - arch_var_value = arch_varname + " = " + arch_string + "\n" - output_file_line_list.append( arch_var_value ) - - # Add a variable for the build type string and append it to our list. - build_var_value = build_varname + " = " + build_string + "\n" - output_file_line_list.append( build_var_value ) - - # Add a variable for the C compiler string and append it to our list. - ccompiler_var_value = ccompiler_varname + " = " + ccompiler_string + "\n" - output_file_line_list.append( ccompiler_var_value ) - - # Walk the flat subdirectories for each of the leaves. - for leaf_spec in leaf_list: - - # Unpack the leaf_spec tuple. - src_exts, hdr_exts = leaf_spec - - # Create the paths to the source and object subdirectories. - src_dirpath = os.path.join( flat_dir, source_dirname ) - obj_dirpath = os.path.join( flat_dir, object_dirname, arch_string, build_string ) - - # Get a list of files from the leaf subdirectory. - src_filenames = os.listdir( src_dirpath ) - - # This will be the nmake variable name to which we will assign the list - # of source files. - nmake_varname = "BLIS_OBJS" - - # Generate the line to output. - leaf_line = generate_object_list( nmake_varname, src_filenames, src_exts, obj_dirpath ) - - # Accumulate the lines. - output_file_line_list.append( leaf_line ) - - # Get the filename part of the input filepath. - input_filedir, input_filename = os.path.split( input_filepath ) - - # Remove the .in extension in the output filename. - output_filename = re.sub( '.mk.in', '.mk', input_filename ) - - # Construct the filepath for the output file. - output_filepath = os.path.join( flat_dir, config_dirname, output_filename ) - - # Write the output lines. - write_output_file( output_filepath, output_file_line_list ) - -# ------------------------------------------------------------------------------ - -def read_revision_file( filepath ): - - # Try to open the revision file. - try: - - revision_file = open( filepath, 'r' ) - - except IOError, err: - - print "%s: Couldn't open revision file %s" % ( script_name, filepath ) - sys.exit(1) - - # Read the first (and only) line. - line = revision_file.readline() - - # Close the file. - revision_file.close() - - # Grab the string and strip the it of whitespace (should just be a newline). - rev_num_str = line.strip() - - # Return the revision number string. - return rev_num_str - -# ------------------------------------------------------------------------------ - -def generate_object_list( nmake_varname, src_filenames, src_exts, obj_dirpath ): - - # Initialize the string as an assignment operation. - the_line = nmake_varname + " = " - - # Return early if there are no source extensions for this leaf spec. - if src_exts == []: - return "" - - # Construct a pattern to match any file ending with any of the source file - # extensions given. This string is going to look something like ".[cf]". - src_pattern = '\.[' - for src_ext in src_exts: - src_pattern = src_pattern + src_ext - src_pattern = src_pattern + ']' - - # Consider all source files. - for src_filename in src_filenames: - - obj_filename = re.sub( src_pattern, '.obj', src_filename ) - - # Create the full path to the file. - obj_filepath = os.path.join( obj_dirpath, obj_filename ) - - # Be verbose if verbosity was requested. - if verbose_flag == True: - print "%s: adding file %s" % ( script_name, obj_filepath ) - - # And then add it to the list. - the_line = the_line + obj_filepath + " " - - # Be verbose if verbosity was requested. - if verbose_flag == True: - print "%s: %s" % ( script_name, the_line ) - - # Append a newline to the end of the line, for file.writelines(). - the_line = the_line + "\n" - - # Return the new line. - return the_line - -# ------------------------------------------------------------------------------ - -def read_template_file( template_filepath ): - - # Open the template file as read-only. - template_file = open( template_filepath, 'r' ) - - # Read all lines in the template file. - template_file_lines = template_file.readlines() - - # Close the file. - template_file.close() - - # Return the list of lines in the template file. - return template_file_lines - -# ------------------------------------------------------------------------------ - -def write_output_file( output_filepath, output_lines ): - - # Take action only if this is not a dry run. - if dry_run_flag == False: - - # Open the template file as writable. - output_file = open( output_filepath, 'w' ) - - # Write the lines. - output_file.writelines( output_lines ) - - # Close the file. - output_file.close() - -# ------------------------------------------------------------------------------ - -def read_leaf_list(): - - # Open the leaf list file. - leaf_file = open( leaf_list_path, 'r' ) - - # Read the lines in the file. - line_list = leaf_file.readlines() - - # Start with a blank list. - leaf_list = [] - - # Iterate over the lines. - for line in line_list: - - # Split the specification by colon to separate the fields. - fields = string.split( string.strip( line ), ':' ) - - # Get the individual fields of the specification. - src_exts = string.split( fields[0], ',' ) - hdr_exts = string.split( fields[1], ',' ) - - # If it's a singleton list of an empty string, make it an empty list. - if len(src_exts) == 1: - if src_exts[0] == '': - src_exts = [] - - # If it's a singleton list of an empty string, make it an empty list. - if len(hdr_exts) == 1: - if hdr_exts[0] == '': - hdr_exts = [] - - # Pack the fields into a tuple. - leaf_spec = ( src_exts, hdr_exts ) - - # Append the tuple to our list. - leaf_list.append( leaf_spec ) - - # Return the list. - return leaf_list - -# ------------------------------------------------------------------------------ - -# Begin by executing main(). -main() diff --git a/windows/build/ignore_list b/windows/build/ignore_list deleted file mode 100644 index a8230623ed..0000000000 --- a/windows/build/ignore_list +++ /dev/null @@ -1,7 +0,0 @@ -attic -broken -old -other -temp -tmp -test diff --git a/windows/build/ignore_list.windows b/windows/build/ignore_list.windows deleted file mode 100644 index 46f8b9aacc..0000000000 --- a/windows/build/ignore_list.windows +++ /dev/null @@ -1 +0,0 @@ -.git diff --git a/windows/build/leaf_list b/windows/build/leaf_list deleted file mode 100644 index 98e115e3f6..0000000000 --- a/windows/build/leaf_list +++ /dev/null @@ -1 +0,0 @@ -c:h diff --git a/windows/build/nmake-help.cmd b/windows/build/nmake-help.cmd deleted file mode 100644 index a46ce5f1a1..0000000000 --- a/windows/build/nmake-help.cmd +++ /dev/null @@ -1,72 +0,0 @@ -:: -:: -:: BLIS -:: An object-based framework for developing high-performance BLAS-like -:: libraries. -:: -:: Copyright (C) 2014, The University of Texas at Austin -:: -:: Redistribution and use in source and binary forms, with or without -:: modification, are permitted provided that the following conditions are -:: met: -:: - Redistributions of source code must retain the above copyright -:: notice, this list of conditions and the following disclaimer. -:: - Redistributions in binary form must reproduce the above copyright -:: notice, this list of conditions and the following disclaimer in the -:: documentation and/or other materials provided with the distribution. -:: - Neither the name(s) of the copyright holder(s) nor the names of its -:: contributors may be used to endorse or promote products derived -:: from this software without specific prior written permission. -:: -:: THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -:: "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -:: LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -:: A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -:: HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -:: SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -:: LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -:: DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -:: THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -:: (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -:: OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -:: -:: - -@echo off - -echo. -echo Makefile -echo. -echo Field G. Van Zee -echo. -echo nmake Makefile for building BLIS for Microsoft Windows. nmake targets -echo may be invoked after running the configure.cmd script. Valid targets are: -echo. -echo all - Invoke the lib and dll targets. -echo lib - Build BLIS as a static library. -echo dll - Build BLIS as a dynamically-linked library. -echo help - Output help and usage information. -echo clean - Invoke clean-log and clean-build targets. -echo clean-log - Remove any log files present. -echo clean-config - Remove all products of configure.cmd. Namely, remove the -echo config, include, and src directories. -echo clean-build - Remove all products of the compilation portion of the build -echo process. Namely, remove the obj, lib, and dll directories. -echo distclean - Invoke clean-log, clean-config, and clean-build targets. -echo. -echo The Makefile also recognizes configuration options corresponding to the -echo following Makefile variables: -echo. -echo VERBOSE - When defined, nmake outputs the actual commands -echo executed instead of more concise one-line progress -echo indicators. (Undefined by default.) -echo. -echo Typically, these options are specified by commenting or uncommenting the -echo corresponding lines in the Makefile. However, if the Makefile currently does -echo not define one of the options, and you wish to enable the corresponding -echo feature without editing the Makefile, you may define the variable at the -echo command line when nmake is invoked. For example, you may enable verboseness -echo while invoking the lib target as follows: -echo. -echo nmake lib VERBOSE=1 -echo. diff --git a/windows/configure.cmd b/windows/configure.cmd deleted file mode 100644 index c2ee037d7c..0000000000 --- a/windows/configure.cmd +++ /dev/null @@ -1,87 +0,0 @@ -:: -:: -:: BLIS -:: An object-based framework for developing high-performance BLAS-like -:: libraries. -:: -:: Copyright (C) 2014, The University of Texas at Austin -:: -:: Redistribution and use in source and binary forms, with or without -:: modification, are permitted provided that the following conditions are -:: met: -:: - Redistributions of source code must retain the above copyright -:: notice, this list of conditions and the following disclaimer. -:: - Redistributions in binary form must reproduce the above copyright -:: notice, this list of conditions and the following disclaimer in the -:: documentation and/or other materials provided with the distribution. -:: - Neither the name(s) of the copyright holder(s) nor the names of its -:: contributors may be used to endorse or promote products derived -:: from this software without specific prior written permission. -:: -:: THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -:: "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -:: LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -:: A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -:: HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -:: SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -:: LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -:: DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -:: THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -:: (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -:: OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -:: -:: - -@echo off - -:ENVIRONMENT - set GEN_CHECK_REV_FILE=.\build\gen-check-rev-file.py - set GATHER_SRC=.\build\gather-src-for-windows.py - set GEN_CONFIG_FILE=.\build\gen-config-file.py - set CONFIG_DEFS_TEMPL=.\build\config.mk.in - set SRC_TREE_DIR=..\frame - set TOP_BUILD_DIR=. - -:PARAMS - if "%1"=="" (goto USAGE) - if "%2"=="" (goto USAGE) - if "%3"=="" (goto USAGE) - - set ARCH=%1 - set BUILD=%2 - set CCOMPILER=%3 - -:TASK_UNIT - echo %0: Checking/updating revision file. - %GEN_CHECK_REV_FILE% -v - echo %0: Gathering source files into local flat directories. - %GATHER_SRC% %SRC_TREE_DIR% %TOP_BUILD_DIR% - echo %0: Creating configure definitions file. - %GEN_CONFIG_FILE% %TOP_BUILD_DIR% %ARCH% %BUILD% %CCOMPILER% %CONFIG_DEFS_TEMPL% - echo %0: Configuration and setup complete. You may now run nmake. - - goto END - -:USAGE - echo. - echo configure.cmd - echo. - echo A wrapper script for various configuration and setup scripts that need - echo. to be run before nmake when building BLIS for Microsoft Windows. - echo. - echo USAGE: - echo %0 [arch] [build] [cc] - echo. - echo arch -- The architecture string to build. - echo Supported values: {x86,x64} - echo build -- The kind of build. - echo Supported values: {debug,release} - echo cc -- The C compiler to use. - echo Supported values: {icl,cl} - echo. - echo examples: - echo %0 x86 debug icl - echo %0 x64 release cl - echo. - -:END diff --git a/windows/gendll.cmd b/windows/gendll.cmd deleted file mode 100644 index db0cdc1d26..0000000000 --- a/windows/gendll.cmd +++ /dev/null @@ -1,128 +0,0 @@ -@echo off -@setlocal enabledelayedexpansion - -rem -------------------------------------------------------------------- -rem Build a dll out of a set of object files specified by the -rem argument /objlist. -rem -rem The .lib file thus created is an "import" library, which one links -rem with, but the bulk of the code ends up in the associated .dll file. -rem --------------------------------------------------------------------- - -set THIS_SCRIPT=%~dp0%~nx0 - -if "%1"=="" goto USAGE -if "%2"=="" goto USAGE -if "%3"=="" goto USAGE -if "%4"=="" goto USAGE -if "%5"=="" goto USAGE - -set gd_lib_name=%1 -set gd_link=%gd_lib_name%-static.link -set LINKER=%3 -set LINKARGSFILE=%4 -set gd_def=%5 - -:PARSE_ARGS -set IMPORT= -set OBJLIST= -:ARGLOOP -if "%6"=="" goto ENDARGLOOP -if /i not "%6"=="/import" goto OBJARG -set IMPORT=!IMPORT! %7 -goto SHIFT -:OBJARG -if /i not "%6"=="/objlist" goto ENDARGLOOP -set OBJLIST=%7 -:SHIFT -shift /4 -shift /4 -goto ARGLOOP -:ENDARGLOOP - -if defined OBJLIST goto COMPILER_SETUP -echo Error: must supply /objlist -goto USAGE - -:COMPILER_SETUP -set gd_path=%2 -set gd_dll_path=%gd_path%.dll -set gd_main_c=dll_main__%gd_lib_name%.c -set gd_main_obj=dll_main__%gd_lib_name%.obj - -rem create C file for dll_main -for /F "tokens=*" %%i in ("#include ") do echo %%i >%gd_main_c% -echo. >>%gd_main_c% -echo BOOLEAN WINAPI DllMain( >>%gd_main_c% -echo HINSTANCE hDllHandle, >>%gd_main_c% -echo DWORD nReason, >>%gd_main_c% -echo LPVOID Reserved){ >>%gd_main_c% -echo. >>%gd_main_c% -echo BOOLEAN bSuccess = TRUE;>>%gd_main_c% -echo. >>%gd_main_c% -echo switch (nReason){ >>%gd_main_c% -echo case DLL_PROCESS_ATTACH: >>%gd_main_c% -echo DisableThreadLibraryCalls( hDllHandle ); >>%gd_main_c% -echo break; >>%gd_main_c% -echo case DLL_PROCESS_DETACH: >>%gd_main_c% -echo break; >>%gd_main_c% -echo. >>%gd_main_c% -echo }; >>%gd_main_c% -echo. >>%gd_main_c% -echo return bSuccess; >>%gd_main_c% -echo }; >>%gd_main_c% -echo.>>%gd_main_c% - -rem set up link file by specifying dll filepath and main object -echo /Fe%gd_dll_path% > %gd_link% -echo %gd_main_obj% >> %gd_link% - -rem add contents of linkargs file; most of the link argument action is -rem in this file -type %LINKARGSFILE% >> %gd_link% - -rem add command-line import libraries, if any -if defined IMPORT echo !IMPORT! >> %gd_link% - -rem add export specification -echo %gd_def% >> %gd_link% - -rem add contents of OBJLIST file -type %OBJLIST% >> %gd_link% - -rem create dll, import lib, and export file -%LINKER% /nologo /c /O2 /Fo%gd_main_obj% %gd_main_c% >> gendll-cl.log -%LINKER% @%gd_link% - -:CLEANUP -del /F /Q %gd_link% %gd_main_c% %gd_main_obj% gendll-cl.log -goto END - - -:USAGE -echo. -echo. gendll.cmd -echo. -echo. Generate a dynamically-linked library from a set of object files -echo. specified in objlist_file. -echo. -echo. Usage: -echo. %0 dllname dllpath linker linkargs_file symbols_file {/import importlib} /objlist objlist_file -echo. -echo. dllname -- the name of the DLL being created, with no extension. -echo. dllpath -- the path to the DLL being created, with no extension. -echo. linker -- the compiler to use to link the DLL. -echo. linkargs_file -- the path to a file containing a list of all linker -echo. arguments--link options, libraries, and library paths-- -echo. that that may be needed to successfully link the DLL -echo. being created. -echo. symbols_file -- the path to a file containing a list of symbols to -echo. export in the DLL. -echo. importlib -- the path to a .lib library that you wish to import into -echo. the DLL being created. Optional. -echo. objlist_file -- the path to a file containing the list of object files -echo. that make up the bulk of the DLL being created. -echo. - -:END -endlocal diff --git a/windows/linkargs.txt b/windows/linkargs.txt deleted file mode 100644 index 61be998da2..0000000000 --- a/windows/linkargs.txt +++ /dev/null @@ -1,11 +0,0 @@ -/nologo -/LD /MT -/LIBPATH:"C:\Program Files\Microsoft SDKs\Windows\v6.0A\Lib" -/LIBPATH:"C:\Program Files (x86)\Microsoft Visual Studio 9.0\VC\lib" -/nodefaultlib:libcmt /nodefaultlib:libc /nodefaultlib:libmmt -msvcrt.lib -/LIBPATH:"C:\Program Files (x86)\Intel\Compiler\11.1\048\lib\ia32" -/LIBPATH:"C:\Program Files (x86)\Intel\Compiler\11.1\048\mkl\ia32\lib" -mkl_intel_c.lib -mkl_sequential.lib -mkl_core.lib diff --git a/windows/linkargs64.txt b/windows/linkargs64.txt deleted file mode 100644 index 35df4bba96..0000000000 --- a/windows/linkargs64.txt +++ /dev/null @@ -1,11 +0,0 @@ -/nologo -/LD /MT -/LIBPATH:"C:\Program Files\Microsoft SDKs\Windows\v6.0A\Lib\x64" -/LIBPATH:"C:\Program Files (x86)\Microsoft Visual Studio 9.0\VC\lib\amd64" -/nodefaultlib:libcmt /nodefaultlib:libc /nodefaultlib:libmmt -msvcrt.lib -/LIBPATH:"C:\Program Files (x86)\Intel\Compiler\11.1\048\lib\intel64" -/LIBPATH:"C:\Program Files (x86)\Intel\Compiler\11.1\048\mkl\em64t\lib" -mkl_intel_lp64.lib -mkl_sequential.lib -mkl_core.lib diff --git a/windows/revision b/windows/revision deleted file mode 100644 index 87edf799f4..0000000000 --- a/windows/revision +++ /dev/null @@ -1 +0,0 @@ -unknown \ No newline at end of file diff --git a/windows/vc110.pdb b/windows/vc110.pdb deleted file mode 100644 index 39ecfdbbb4..0000000000 Binary files a/windows/vc110.pdb and /dev/null differ